You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

834 lines
28 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
# WebAPI文档
` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
## 执行参数:
`-a` - `绑定地址, 默认"0.0.0.0"; 传入 None 可监听双栈`
`-p` - `绑定端口, 默认9880`
`-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"`
## 调用:
### 推理
endpoint: `/tts`
GET:
```
http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
```
POST:
```json
{
"text": "", # str.(required) text to be synthesized
"text_lang": "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval": 0.3, # float. to control the interval of the audio fragment.
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac" ("ogg" requires streaming_mode).
"streaming_mode": False, # bool. whether to return a streaming response.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3.
}
```
RESP:
成功: 直接返回 media_type 指定格式的音频流(默认 wav), http code 200
失败: 返回包含错误信息的 json, http code 400
### 命令控制
endpoint: `/control`
command:
"restart": 重新运行
"exit": 结束运行
GET:
```
http://127.0.0.1:9880/control?command=restart
```
POST:
```json
{
"command": "restart"
}
```
RESP: 无
### 切换GPT模型
endpoint: `/set_gpt_weights`
GET:
```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
```
RESP:
成功: 返回"success", http code 200
失败: 返回包含错误信息的 json, http code 400
### 切换Sovits模型
endpoint: `/set_sovits_weights`
GET:
```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
```
RESP:
成功: 返回"success", http code 200
失败: 返回包含错误信息的 json, http code 400
"""
import os
import sys
import traceback
from typing import Generator
# Make the current working directory and its GPT_SoVITS sub-folder importable.
# These appends must run before the project imports below (tools.*, GPT_SoVITS.*).
# NOTE(review): this relies on the process cwd, not on this file's directory.
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
import argparse
import subprocess
import wave
import signal
import numpy as np
import soundfile as sf
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from pydantic import BaseModel
import sqlite3
import json
import socket
import nltk
# Directory containing this script (may be relative, or empty, depending on how the script is invoked)
base_dir = os.path.dirname(__file__)
# Add the "nltk_data" folder next to this script to NLTK's data search path
nltk.data.path.append(os.path.join(base_dir, "nltk_data"))
# print(sys.path)
# NOTE(review): i18n is not referenced anywhere else in this file
i18n = I18nAuto()
# Names of the available text segmentation methods; check_params validates text_split_method against them
cut_method_names = get_cut_method_names()
def get_port_from_json_config(config_path=None):
    """Read the service port from a JSON config file.

    The port is taken from the top-level ``"port"`` key or, failing that, from
    the nested ``{"gptsovits": {"port": ...}}`` entry. Integers and decimal
    strings (surrounding whitespace allowed) are accepted; booleans are not.

    Args:
        config_path: Path of the JSON file. ``None`` (the default) uses
            ``config.json`` in the parent directory of this script.

    Returns:
        The port as an ``int`` in the range 1-65535. Returns ``None`` (after
        printing the reason) when any of these hold:
        - the file is missing, unreadable or not valid JSON;
        - the file holds no port;
        - the value is not a usable port number.
    """
    if config_path is None:
        # Default location: config.json in the parent directory of this script
        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'config.json')
    config_path = os.path.normpath(config_path)
    if not os.path.exists(config_path):
        print(f"配置文件不存在: {config_path}")
        return None

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"配置文件解析错误: {e}")
        return None
    except (OSError, ValueError) as e:
        # OSError: file vanished / permission denied; ValueError: undecodable text
        print(f"读取配置文件时发生错误: {e}")
        return None

    # Look up the port; only dict containers are searched, so a JSON list root or
    # a non-object "gptsovits" value is reported as "not found".
    nested = config_data.get('gptsovits') if isinstance(config_data, dict) else None
    if isinstance(config_data, dict) and 'port' in config_data:
        port_value = config_data['port']
    elif isinstance(nested, dict) and 'port' in nested:
        port_value = nested['port']
    else:
        print("配置文件中未找到端口配置")
        return None

    # bool is a subclass of int: reject it explicitly so `true` is not read as port 1.
    # isdecimal (unlike isdigit) only accepts characters int() can actually parse.
    if isinstance(port_value, bool):
        port = None
    elif isinstance(port_value, int):
        port = port_value
    elif isinstance(port_value, str) and port_value.strip().isdecimal():
        port = int(port_value.strip())
    else:
        port = None

    if port is None:
        print(f"配置文件中的端口值 {port_value} 不是有效的整数")
        return None
    if not 1 <= port <= 65535:
        print(f"配置文件中的端口值 {port_value} 超出有效范围 (1-65535)")
        return None
    return port
def import_db(db_path: str):
    """
    Open a SQLite connection to ``db_path``.

    The connection is opened with ``check_same_thread=False``, so it may be used
    from a thread other than the one that created it.

    Parameters:
    - db_path: path of the database file

    Returns:
    - the ``sqlite3.Connection`` on success;
    - ``None`` if ``sqlite3.connect`` raised ``sqlite3.Error`` (the error is printed).
    """
    connection = None
    try:
        connection = sqlite3.connect(db_path, check_same_thread=False)
    except sqlite3.Error as err:
        print(f"连接数据库失败:{err}")
    else:
        print(f"成功连接到数据库:{db_path}")
    return connection
def get_port_from_db(db_path: str):
    """Read the service port from the ``gptsovits_config`` key/value table.

    Args:
        db_path: path of the SQLite database file.

    Returns:
        The value stored under key ``'port'`` converted to ``int``. Returns
        ``None`` (after printing the reason, where applicable) when any of these hold:
        - the database cannot be opened;
        - the query fails, e.g. the table does not exist;
        - the row is absent or NULL;
        - the value is not an integer.
    """
    conn = import_db(db_path)
    if conn is None:
        return None
    try:
        row = conn.execute(
            "SELECT value FROM gptsovits_config WHERE key = 'port' LIMIT 1;"
        ).fetchone()
    except sqlite3.Error as e:
        print(f"数据库操作错误: {e}")
        return None
    except Exception as e:
        print(f"读取端口配置时发生错误: {e}")
        return None
    finally:
        # Always release the connection, including when the query raises
        conn.close()

    if row is None or row[0] is None:
        return None
    try:
        return int(row[0])
    except (TypeError, ValueError):
        print(f"数据库中 port 的值 {row[0]} 不是有效的整数")
        return None
def set_status(db_path: str, status_code: int):
    """
    Store the service status in ``live_config`` under key ``gptsovits_enable_status``.

    Status codes:
        0 not started (written by the frontend)
        1 starting
        2 started successfully
        3 failed to start

    Parameters:
    - db_path: path of the SQLite database file
    - status_code: status to store (0, 1, 2, 3); it is written as text

    Only an UPDATE is issued, so the key row must already exist. When no row
    matches, a warning is printed instead of the success message.
    Database errors are printed, not raised.
    """
    conn = import_db(db_path)
    if conn is None:
        print("无法连接到数据库,无法更新状态")
        return
    try:
        cursor = conn.execute(
            "UPDATE live_config SET value = ? WHERE key = 'gptsovits_enable_status';",
            (str(status_code),)
        )
        conn.commit()
        if cursor.rowcount == 0:
            # Nothing matched: the key row is missing, so the status was NOT stored
            print("live_config 中不存在 gptsovits_enable_status 记录,状态未更新")
        else:
            print(f"gptsovits_enable_status 已更新为 {status_code}")
    except sqlite3.Error as e:
        print(f"更新数据库时发生错误: {e}")
    finally:
        conn.close()
def check_port(port: int, host: str = '0.0.0.0'):
    """Return True if an IPv4 TCP socket can currently bind ``(host, port)``.

    A temporary socket is bound and closed immediately. The result is only a
    snapshot: another process may take the port right afterwards.

    Args:
        port: TCP port to test.
        host: IPv4 address to bind. Defaults to all interfaces ('0.0.0.0');
            ``None`` or "" also mean all interfaces.

    Returns:
        True if the bind succeeded (port looks free), False if bind raised OSError.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host or '0.0.0.0', port))
        except OSError:
            # Address already in use, or not bindable on this host
            return False
        return True
# Parse command-line arguments
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
# default=None (not 9880) so that an explicit "-p 9880" can be told apart from "not given"
parser.add_argument("-p", "--port", type=int, default=None, help="default: 9880")
args = parser.parse_args()
config_path = args.tts_config
# SQLite database file in the parent directory of this script
db_path = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'live_chat.db'))

# Resolve the port by priority; each source is consulted only if the previous one gave nothing:
# 1. command-line argument  2. database  3. config.json  4. default 9880
port = args.port
port_source = "命令行参数"
if port is None:
    port = get_port_from_db(db_path)
    port_source = "数据库配置"
if port is None:
    port = get_port_from_json_config()
    port_source = "配置文件(config.json)"
if port is None:
    port = 9880
    port_source = "默认值"
print(f"使用的端口: {port} (来源: {port_source})")
host = args.bind_addr
# Saved so that /control?command=restart can re-exec the process with the same arguments
argv = sys.argv
if config_path in [None, ""]:
    # Same default as the -c option (GPT_SoVITS with an underscore, the actual folder name)
    config_path = "GPT_SoVITS/configs/tts_infer.yaml"
tts_config = TTS_Config(config_path)
print(tts_config)
tts_pipeline = TTS(tts_config)
APP = FastAPI()
class TTS_Request(BaseModel):
    """JSON body accepted by POST /tts.

    The fields mirror the query parameters of GET /tts. The model is converted to
    a dict and passed to tts_handle. Note that text_split_method defaults to "cut5"
    here, while GET /tts defaults to "cut0".
    """
    text: str = None  # text to synthesize; required (check_params rejects None/"")
    text_lang: str = None  # language of `text`; required, lower-cased value must be in tts_config.languages
    ref_audio_path: str = None  # reference audio path; required
    aux_ref_audio_paths: list = None  # optional auxiliary reference audio paths (multi-speaker tone fusion per module docs)
    prompt_lang: str = None  # language of prompt_text; required, lower-cased value must be in tts_config.languages
    prompt_text: str = ""  # optional prompt text for the reference audio
    top_k: int = 5  # top-k sampling
    top_p: float = 1  # top-p sampling
    temperature: float = 1  # sampling temperature
    text_split_method: str = "cut5"  # must be one of cut_method_names
    batch_size: int = 1  # inference batch size
    batch_threshold: float = 0.75  # threshold for batch splitting
    split_bucket: bool = True  # whether to split the batch into multiple buckets
    speed_factor: float = 1.0  # speed of the synthesized audio
    fragment_interval: float = 0.3  # interval between audio fragments
    seed: int = -1  # random seed (module docs: for reproducibility)
    media_type: str = "wav"  # one of "wav", "raw", "ogg", "aac"; "ogg" is only accepted with streaming_mode
    streaming_mode: bool = False  # return a streaming response
    parallel_infer: bool = True  # whether to use parallel inference
    repetition_penalty: float = 1.35  # repetition penalty for the T2S model
    sample_steps: int = 32  # number of sampling steps (module docs: VITS model V3)
    super_sampling: bool = False  # super-sampling of the audio (module docs: VITS model V3)
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` as single-channel OGG at ``rate`` Hz into ``io_buffer``; return that buffer."""
    ogg_writer = sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg")
    with ogg_writer:
        ogg_writer.write(data)
    return io_buffer
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Write the raw in-memory bytes of ``data`` into ``io_buffer`` and return it (``rate`` is unused)."""
    raw_bytes = data.tobytes()
    io_buffer.write(raw_bytes)
    return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Write ``data`` as a WAV file at ``rate`` Hz into ``io_buffer`` and return that buffer.

    The caller's buffer is written to, consistent with the other ``pack_*``
    helpers. The sample format is soundfile's default subtype for WAV.
    """
    sf.write(io_buffer, data, rate, format="wav")
    return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode mono PCM ``data`` to AAC with ffmpeg and write the stream into ``io_buffer``.

    ffmpeg is told the input is 16-bit signed little-endian PCM, so ``data`` is
    converted to that layout first:
    - floating-point samples are assumed to be normalized to [-1, 1], and are
      clipped and scaled to int16;
    - other dtypes are cast to little-endian int16.

    The output is an ADTS AAC stream at 192 kbit/s.

    Returns:
        ``io_buffer``.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        RuntimeError: if ffmpeg exits with a non-zero status (its stderr is included).
    """
    samples = np.asarray(data)
    if np.issubdtype(samples.dtype, np.floating):
        # Float audio is assumed to be in [-1, 1]; map it onto the int16 range
        samples = np.clip(samples, -1.0, 1.0) * 32767.0
    pcm = samples.astype("<i2", copy=False)

    result = subprocess.run(
        [
            "ffmpeg",
            "-f",
            "s16le",  # input: 16-bit signed little-endian PCM
            "-ar",
            str(rate),  # sample rate
            "-ac",
            "1",  # mono
            "-i",
            "pipe:0",  # read input from stdin
            "-c:a",
            "aac",  # AAC encoder
            "-b:a",
            "192k",  # bitrate
            "-vn",  # no video
            "-f",
            "adts",  # ADTS AAC stream format
            "pipe:1",  # write output to stdout
        ],
        input=pcm.tobytes(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        # Surface the failure instead of silently returning empty audio
        detail = result.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(f"ffmpeg aac encoding failed (exit {result.returncode}): {detail}")
    io_buffer.write(result.stdout)
    return io_buffer
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    """Encode ``data`` as ``media_type`` and return the buffer rewound to offset 0.

    "ogg", "aac" and "wav" use their dedicated packers; any other value falls
    back to writing the raw sample bytes.
    """
    packer = pack_raw
    for name, candidate in (("ogg", pack_ogg), ("aac", pack_aac), ("wav", pack_wav)):
        if media_type == name:
            packer = candidate
            break
    io_buffer = packer(io_buffer, data, rate)
    io_buffer.seek(0)
    return io_buffer
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
    """Return the bytes of a WAV file holding ``frame_input`` (header first).

    Meant to be the first chunk of a streamed WAV response. Later chunks should
    be sent without a header; repeating it inside the stream causes audible
    artifacts at the start of each chunk.
    """
    buffer = BytesIO()
    writer = wave.open(buffer, "wb")
    try:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)
    finally:
        writer.close()
    return buffer.getvalue()
def handle_control(command: str):
    """Act on a /control command.

    Commands:
    - "restart": replace this process with a new interpreter run using the saved
      ``argv`` (os.execl does not return on success).
    - "exit": send SIGTERM to this process, then call ``exit(0)``.
    - any other value: do nothing and return None.
    """
    if command not in ("restart", "exit"):
        return None
    if command == "restart":
        os.execl(sys.executable, sys.executable, *argv)
    else:
        os.kill(os.getpid(), signal.SIGTERM)
        exit(0)
def check_params(req: dict):
    """Validate a TTS request dict.

    Returns None when the request is acceptable, otherwise a 400 JSONResponse
    describing the first problem found. Checks run in this order:
    1. ref_audio_path and text are present;
    2. text_lang, then prompt_lang, are present and their lower-cased value is
       in tts_config.languages;
    3. media_type is wav/raw/ogg/aac, and ogg is only allowed with streaming_mode;
    4. text_split_method is in cut_method_names.
    """
    def bad_request(message):
        return JSONResponse(status_code=400, content={"message": message})

    def is_blank(value):
        return value in [None, ""]

    if is_blank(req.get("ref_audio_path", "")):
        return bad_request("ref_audio_path is required")
    if is_blank(req.get("text", "")):
        return bad_request("text is required")

    for field in ("text_lang", "prompt_lang"):
        lang = req.get(field, "")
        if is_blank(lang):
            return bad_request(f"{field} is required")
        if lang.lower() not in tts_config.languages:
            return bad_request(f"{field}: {lang} is not supported in version {tts_config.version}")

    media_type = req.get("media_type", "wav")
    if media_type not in ["wav", "raw", "ogg", "aac"]:
        return bad_request(f"media_type: {media_type} is not supported")
    if media_type == "ogg" and not req.get("streaming_mode", False):
        return bad_request("ogg format is not supported in non-streaming mode")

    text_split_method = req.get("text_split_method", "cut5")
    if text_split_method not in cut_method_names:
        return bad_request(f"text_split_method:{text_split_method} is not supported")
    return None
# 分贝调整策略
def _to_float32_mono(x: np.ndarray) -> np.ndarray:
if x.dtype == np.int16:
return x.astype(np.float32) / 32768.0
return x.astype(np.float32)
def _peak_dbfs(xf: np.ndarray) -> float:
if xf.size == 0:
return -float("inf")
peak = float(np.max(np.abs(xf)))
if peak <= 1e-12:
return -float("inf")
return 20.0 * np.log10(peak)
def _rms_dbfs(xf: np.ndarray) -> float:
if xf.size == 0:
return -float("inf")
rms = float(np.sqrt(np.mean(xf * xf)))
if rms <= 1e-12:
return -float("inf")
return 20.0 * np.log10(rms)
def _apply_gain_linear(xf: np.ndarray, gain_db: float) -> np.ndarray:
gain = 10.0 ** (gain_db / 20.0)
return xf * gain
def _limiter_peak(xf: np.ndarray, thresh_db: float = -1.0, soft: bool = True):
"""
简易峰值限幅器
thresh_db 为阈值,默认 -1 dBFS
soft 为 True 时使用软限幅tanh听感更顺滑
返回 (处理后波形, 是否触发限幅)
"""
if xf.size == 0:
return xf, False
thresh_lin = 10.0 ** (thresh_db / 20.0)
peak = float(np.max(np.abs(xf)))
if peak <= thresh_lin or peak <= 1e-12:
return xf, False
if soft:
k = 2.0
out = np.tanh(k * xf / peak) * thresh_lin
return out.astype(np.float32), True
scale = thresh_lin / peak
out = xf * scale
return out.astype(np.float32), True
class DynamicGainState:
    """Stateful, chunk-by-chunk automatic gain control followed by a peak limiter.

    For every chunk passed to ``compute_chunk``:

    1. The "ideal" gain puts the chunk's peak at ``target_peak_db``. For a
       silent chunk (no measurable peak) it is ``max_boost_db``.
    2. If the chunk's RMS is below ``min_rms_gate_db`` (silent chunks included),
       the ideal gain is capped at ``quiet_boost_cap_db``. Quiet chunks are
       therefore not boosted by the full ``max_boost_db``.
    3. The ideal gain is clamped to ``[-max_cut_db, max_boost_db]``.
    4. The applied gain moves from the previous chunk's gain towards the ideal
       one by the fraction ``attack_fast`` when rising, or ``release_slow``
       when falling.
    5. ``_limiter_peak`` runs with ``limiter_thresh_db`` / ``limiter_soft``.

    ``prev_gain_db`` carries the gain between calls. One instance therefore
    smooths across consecutive chunks, and across requests if it is shared.

    NOTE(review): in compressor terminology "attack" usually governs gain
    reduction; here attack_fast is used when the gain rises. Confirm this is intended.
    """

    def __init__(
        self,
        target_peak_db: float = -1.0,
        max_boost_db: float = 18.0,
        max_cut_db: float = 24.0,
        min_rms_gate_db: float = -45.0,
        quiet_boost_cap_db: float = 6.0,
        attack_fast: float = 0.25,
        release_slow: float = 0.08,
        limiter_thresh_db: float = -1.0,
        limiter_soft: bool = True,
    ):
        self.target_peak_db = target_peak_db  # desired chunk peak, dBFS
        self.max_boost_db = max_boost_db  # largest allowed gain increase, dB
        self.max_cut_db = max_cut_db  # largest allowed gain reduction, dB
        self.min_rms_gate_db = min_rms_gate_db  # chunks with RMS below this (dBFS) count as quiet
        self.quiet_boost_cap_db = quiet_boost_cap_db  # boost cap for quiet chunks, dB
        self.attack_fast = attack_fast  # smoothing fraction used when the gain rises
        self.release_slow = release_slow  # smoothing fraction used when the gain falls
        self.limiter_thresh_db = limiter_thresh_db  # limiter threshold, dBFS
        self.limiter_soft = limiter_soft  # soft (tanh) vs. hard (scaling) limiting
        self.prev_gain_db = 0.0  # gain applied to the previous chunk, dB; starts at unity

    def compute_chunk(self, x: np.ndarray):
        """Apply smoothed gain and peak limiting to one audio chunk.

        Args:
            x: audio samples, converted to float with ``_to_float32_mono``.

        Returns:
            ``(processed, info)``:
            - ``processed`` is the float32 result;
            - ``info`` holds the input peak/RMS (dBFS), the ideal and applied
              gain (dB), whether the limiter fired, and the output peak/RMS (dBFS).
        """
        xf = _to_float32_mono(x)
        peak_db = _peak_dbfs(xf)
        rms_db = _rms_dbfs(xf)
        if peak_db == -float("inf"):
            # Silent chunk: no peak to aim at, so request the maximum boost (gated below)
            ideal_gain_db = self.max_boost_db
        else:
            ideal_gain_db = self.target_peak_db - peak_db
        # Quiet-chunk gate. A silent chunk has rms_db == -inf, which is below any
        # finite gate. It is therefore capped as well, instead of pulling the
        # smoothed gain towards max_boost_db during pauses.
        if rms_db < self.min_rms_gate_db:
            ideal_gain_db = min(ideal_gain_db, self.quiet_boost_cap_db)
        ideal_gain_db = max(-self.max_cut_db, min(self.max_boost_db, ideal_gain_db))

        # One-pole smoothing towards the ideal gain
        alpha = self.attack_fast if ideal_gain_db > self.prev_gain_db else self.release_slow
        gain_db = self.prev_gain_db + alpha * (ideal_gain_db - self.prev_gain_db)
        self.prev_gain_db = gain_db

        y = _apply_gain_linear(xf, gain_db)
        y, limited = _limiter_peak(y, self.limiter_thresh_db, soft=self.limiter_soft)
        info = {
            "peak_db": peak_db,
            "rms_db": rms_db,
            "ideal_gain_db": ideal_gain_db,
            "applied_gain_db": gain_db,
            "limited": limited,
            "post_peak_db": _peak_dbfs(y),
            "post_rms_db": _rms_dbfs(y),
        }
        return y.astype(np.float32), info
# A single gain controller for the whole process. tts_handle calls
# dyn_state.compute_chunk for every chunk of every request, so the smoothed
# gain carries over between chunks and between requests.
dyn_state = DynamicGainState(
    limiter_soft=True,  # tanh-based soft limiting
    limiter_thresh_db=-1.0,  # limiter threshold, dBFS
    release_slow=0.08,  # smoothing fraction used when the gain falls
    attack_fast=0.25,  # smoothing fraction used when the gain rises
    quiet_boost_cap_db=6.0,  # boost cap for chunks below the RMS gate, dB
    min_rms_gate_db=-45.0,  # RMS (dBFS) under which a chunk counts as quiet
    max_cut_db=24.0,  # largest gain reduction, dB
    max_boost_db=18.0,  # largest gain increase, dB
    target_peak_db=-1.0,  # peak level the gain aims for, dBFS
)
def _gain_adjust_to_int16(samples: np.ndarray) -> np.ndarray:
    """Run one audio chunk through the shared dyn_state, log its loudness, and return int16 PCM."""
    processed, info = dyn_state.compute_chunk(samples)
    print(
        f"[响度] 原峰值 {info['peak_db']:.2f} dBFS | 原RMS {info['rms_db']:.2f} dBFS | "
        f"理想增益 {info['ideal_gain_db']:.2f} dB | 实际增益 {info['applied_gain_db']:.2f} dB | "
        f"限幅 {info['limited']} | 处理后峰值 {info['post_peak_db']:.2f} dBFS | 处理后RMS {info['post_rms_db']:.2f} dBFS"
    )
    # compute_chunk returns float32 on a [-1, 1] full scale. The streamed WAV
    # header (sample_width=2) and the raw/aac packers expect 16-bit PCM, so
    # convert back before packing.
    return (np.clip(processed, -1.0, 1.0) * 32767.0).astype(np.int16)


async def tts_handle(req: dict):
    """
    Text to speech handler.

    Validates ``req`` with check_params and runs ``tts_pipeline.run(req)``.
    Every produced chunk goes through the shared ``dyn_state`` gain controller
    and is converted to int16. The result is returned encoded as ``media_type``.

    Args:
        req (dict):
            {
                "text": "",                  # str.(required) text to be synthesized
                "text_lang": "",             # str.(required) language of the text to be synthesized
                "ref_audio_path": "",        # str.(required) reference audio path
                "aux_ref_audio_paths": [],   # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
                "prompt_text": "",           # str.(optional) prompt text for the reference audio
                "prompt_lang": "",           # str.(required) language of the prompt text for the reference audio
                "top_k": 5,                  # int. top k sampling
                "top_p": 1,                  # float. top p sampling
                "temperature": 1,            # float. temperature for sampling
                "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
                "batch_size": 1,             # int. batch size for inference
                "batch_threshold": 0.75,     # float. threshold for batch splitting.
                "split_bucket": True,        # bool. whether to split the batch into multiple buckets.
                "speed_factor": 1.0,         # float. control the speed of the synthesized audio.
                "fragment_interval": 0.3,    # float. to control the interval of the audio fragment.
                "seed": -1,                  # int. random seed for reproducibility.
                "media_type": "wav",         # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
                "streaming_mode": False,     # bool. whether to return a streaming response.
                "return_fragment": False,    # bool.(optional) forced to True when streaming_mode is set.
                "parallel_infer": True,      # bool.(optional) whether to use parallel inference.
                "repetition_penalty": 1.35,  # float.(optional) repetition penalty for T2S model.
                "sample_steps": 32,          # int. number of sampling steps for VITS model V3.
                "super_sampling": False,     # bool. whether to use super-sampling for audio when using VITS model V3.
            }
            ``req`` is mutated: "return_fragment" is set to True when streaming.

    Returns:
        One of the following:
        - StreamingResponse when streaming_mode is true;
        - Response with the whole encoded audio otherwise;
        - a 400 JSONResponse when validation fails or an exception is raised
          inside this handler. Errors raised while a stream is being consumed
          are not caught here.
    """
    streaming_mode = req.get("streaming_mode", False)
    return_fragment = req.get("return_fragment", False)
    media_type = req.get("media_type", "wav")
    check_res = check_params(req)
    if check_res is not None:
        return check_res
    if streaming_mode or return_fragment:
        req["return_fragment"] = True
    try:
        tts_generator = tts_pipeline.run(req)
        if streaming_mode:
            def streaming_generator(tts_generator: Generator, media_type: str):
                is_first_chunk = True
                for sr, chunk in tts_generator:
                    pcm = _gain_adjust_to_int16(chunk)
                    if is_first_chunk and media_type == "wav":
                        # Send the 16-bit WAV header once; later chunks are bare PCM
                        yield wave_header_chunk(sample_rate=sr)
                        media_type = "raw"
                        is_first_chunk = False
                    yield pack_audio(BytesIO(), pcm, sr, media_type).getvalue()

            return StreamingResponse(
                streaming_generator(tts_generator, media_type),
                media_type=f"audio/{media_type}",
            )

        sr, audio_data = next(tts_generator)
        # Pack the gain-adjusted audio (the processed result is what gets logged above)
        pcm = _gain_adjust_to_int16(audio_data)
        audio_bytes = pack_audio(BytesIO(), pcm, sr, media_type).getvalue()
        return Response(audio_bytes, media_type=f"audio/{media_type}")
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
@APP.get("/control")
async def control(command: str = None):
    """GET /control?command=...: forward "restart" or "exit" to handle_control.

    Returns a 400 JSON error when the command is missing or is not one of
    "restart" / "exit". Previously an unknown command produced a 200 response
    with a null body.
    """
    if command is None:
        return JSONResponse(status_code=400, content={"message": "command is required"})
    if command not in ("restart", "exit"):
        return JSONResponse(status_code=400, content={"message": f"command: {command} is not supported"})
    handle_control(command)
@APP.get("/tts")
async def tts_get_endpoint(
    text: str = None,
    text_lang: str = None,
    ref_audio_path: str = None,
    # NOTE(review): FastAPI interprets a bare `list` parameter as request-body data,
    # not a query parameter, unless it is declared with Query(). Confirm that GET
    # clients can actually set this.
    aux_ref_audio_paths: list = None,
    prompt_lang: str = None,
    prompt_text: str = "",
    top_k: int = 5,
    top_p: float = 1,
    temperature: float = 1,
    text_split_method: str = "cut0",
    batch_size: int = 1,
    batch_threshold: float = 0.75,
    split_bucket: bool = True,
    speed_factor: float = 1.0,
    fragment_interval: float = 0.3,
    seed: int = -1,
    media_type: str = "wav",
    streaming_mode: bool = False,
    parallel_infer: bool = True,
    repetition_penalty: float = 1.35,
    sample_steps: int = 32,
    super_sampling: bool = False,
):
    """GET /tts: collect the query parameters into a request dict and delegate to tts_handle.

    text_lang / prompt_lang are lower-cased when given. Missing values are
    passed through unchanged, so check_params answers with a 400 "... is
    required" instead of the handler crashing with an AttributeError (500).
    """
    req = {
        "text": text,
        "text_lang": text_lang.lower() if text_lang else text_lang,
        "ref_audio_path": ref_audio_path,
        "aux_ref_audio_paths": aux_ref_audio_paths,
        "prompt_text": prompt_text,
        "prompt_lang": prompt_lang.lower() if prompt_lang else prompt_lang,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": text_split_method,
        "batch_size": int(batch_size),
        "batch_threshold": float(batch_threshold),
        "speed_factor": float(speed_factor),
        "split_bucket": split_bucket,
        "fragment_interval": fragment_interval,
        "seed": seed,
        "media_type": media_type,
        "streaming_mode": streaming_mode,
        "parallel_infer": parallel_infer,
        "repetition_penalty": float(repetition_penalty),
        "sample_steps": int(sample_steps),
        "super_sampling": super_sampling,
    }
    return await tts_handle(req)
@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
    """POST /tts: convert the JSON body to a dict and delegate to tts_handle.

    text_lang / prompt_lang are lower-cased when present, as GET /tts does, so
    both endpoints hand the pipeline the same language codes.
    """
    req = request.dict()
    for key in ("text_lang", "prompt_lang"):
        if req.get(key):
            req[key] = req[key].lower()
    return await tts_handle(req)
@APP.get("/set_refer_audio")
async def set_refer_aduio(refer_audio_path: str = None):
    """GET /set_refer_audio: set the pipeline's reference audio to ``refer_audio_path``.

    Returns 400 when the path is missing or tts_pipeline.set_ref_audio raises;
    otherwise returns 200 {"message": "success"}. The path check matches
    /set_gpt_weights and /set_sovits_weights.
    """
    if refer_audio_path in ["", None]:
        return JSONResponse(status_code=400, content={"message": "refer audio path is required"})
    try:
        tts_pipeline.set_ref_audio(refer_audio_path)
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
    return JSONResponse(status_code=200, content={"message": "success"})
# @APP.post("/set_refer_audio")
# async def set_refer_aduio_post(audio_file: UploadFile = File(...)):
# try:
# # 检查文件类型,确保是音频文件
# if not audio_file.content_type.startswith("audio/"):
# return JSONResponse(status_code=400, content={"message": "file type is not supported"})
# os.makedirs("uploaded_audio", exist_ok=True)
# save_path = os.path.join("uploaded_audio", audio_file.filename)
# # 保存音频文件到服务器上的一个目录
# with open(save_path , "wb") as buffer:
# buffer.write(await audio_file.read())
# tts_pipeline.set_ref_audio(save_path)
# except Exception as e:
# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
# return JSONResponse(status_code=200, content={"message": "success"})
@APP.get("/set_gpt_weights")
async def set_gpt_weights(weights_path: str = None):
    """GET /set_gpt_weights: load GPT weights from ``weights_path`` via tts_pipeline.init_t2s_weights.

    Returns 400 if the path is missing or loading raises; otherwise returns
    200 {"message": "success"}.
    """
    if weights_path in ["", None]:
        return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
    try:
        tts_pipeline.init_t2s_weights(weights_path)
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
    return JSONResponse(status_code=200, content={"message": "success"})
@APP.get("/set_sovits_weights")
async def set_sovits_weights(weights_path: str = None):
    """GET /set_sovits_weights: load SoVITS weights from ``weights_path`` via tts_pipeline.init_vits_weights.

    Returns 400 if the path is missing or loading raises; otherwise returns
    200 {"message": "success"}.
    """
    if weights_path in ["", None]:
        return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
    try:
        tts_pipeline.init_vits_weights(weights_path)
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
    return JSONResponse(status_code=200, content={"message": "success"})
if __name__ == "__main__":
    try:
        # Status 1: server starting
        set_status(db_path, 1)
        if host == "None":  # passing "-a None" lets the API listen dual-stack
            host = None
        # `port` was already resolved at module level (CLI > database > config.json > 9880).
        # Re-reading it from the database here would discard the CLI / config.json /
        # default choices, and would hand None to check_port when the DB has no port.
        if not check_port(port):
            print(f"端口 {port} 已被占用,无法启动服务。")
            # Port in use: status 3 (failed to start)
            set_status(db_path, 3)
            sys.exit(1)
        # Status 2 (started) is written just before uvicorn.run, which blocks until shutdown
        set_status(db_path, 2)
        uvicorn.run(app=APP, host=host, port=port, workers=1)
    except Exception:
        traceback.print_exc()
        # Any startup error: status 3 (failed to start)
        set_status(db_path, 3)
        os.kill(os.getpid(), signal.SIGTERM)
        sys.exit(0)