diff --git a/api_v2.py b/api_v2.py index 5443e33..9abfcba 100644 --- a/api_v2.py +++ b/api_v2.py @@ -122,6 +122,9 @@ from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from pydantic import BaseModel import sqlite3 +import json +import socket + # print(sys.path) i18n = I18nAuto() @@ -168,85 +171,108 @@ def get_port_from_json_config(): return None -def get_port_from_db(): + +def import_db(db_path: str): + """ + 连接到指定的数据库文件并返回连接对象。 + + 参数: + - db_path: 数据库文件的路径 + + 返回: + - conn: 数据库连接对象 + """ + try: + # 连接到指定的数据库 + conn = sqlite3.connect(db_path, check_same_thread=False) + print(f"成功连接到数据库:{db_path}") + return conn + except sqlite3.Error as e: + print(f"连接数据库失败:{e}") + return None + +def get_port_from_db(db_path: str): """从数据库中读取端口配置""" try: - db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'live_chat.db') - db_path = os.path.normpath(db_path) - conn = sqlite3.connect(db_path) - conn.execute("PRAGMA journal_mode=WAL;") + conn = import_db(db_path) + if conn is None: + return None + cursor = conn.cursor() - + # 从key-value结构的表中查询port配置 cursor.execute("SELECT value FROM gptsovits_config WHERE key = 'port' LIMIT 1;") result = cursor.fetchone() - + cursor.close() conn.close() - + # 转换为整数返回 if result and result[0] is not None: try: return int(result[0]) except ValueError: print(f"数据库中 port 的值 {result[0]} 不是有效的整数") - + except sqlite3.Error as e: print(f"数据库操作错误: {e}") except Exception as e: print(f"读取端口配置时发生错误: {e}") - + return None -def enable_gptsovits_in_db(): - """启用数据库中的gptsovits服务""" + +def set_status(db_path: str, status_code: int): + """ + 更新数据库中 gptsovits_enable_status 的状态值 + 0 未启动(前端写入) + 1 启动中 + 2 启动成功 + 3 启动失败 + + 参数: + - db_path: 数据库文件路径 + - status_code: 服务状态码 (0, 1, 2, 3) + """ + conn = import_db(db_path) + if conn is None: + print("无法连接到数据库,无法更新状态") + return + try: - db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'live_chat.db') - db_path = os.path.normpath(db_path) - conn = sqlite3.connect(db_path) - conn.execute("PRAGMA journal_mode=WAL;") - cursor = conn.cursor() - - # 更新live_config表中的gptsovits_enable_status为1 - cursor.execute(""" - UPDATE live_config - SET value = '1' - WHERE key = 'gptsovits_enable_status' - """) - - # 检查是否有行被更新 - if cursor.rowcount > 0: - print("已成功将gptsovits_enable_status设置为1") - conn.commit() - else: - print("未找到gptsovits_enable_status记录,可能需要插入") - # 尝试插入记录(如果不存在) - cursor.execute(""" - INSERT INTO live_config (key, value, comment) - VALUES ('gptsovits_enable_status', '1', 'gptsovits是否启动,1=启动;0=关闭') - """) - conn.commit() - print("已插入gptsovits_enable_status记录并设置为1") - - cursor.close() - conn.close() - # 打印状态更新完毕信息 - print("状态更新完毕") - + conn.execute( + "UPDATE live_config SET value = ? WHERE key = 'gptsovits_enable_status';", + (str(status_code),) + ) + conn.commit() + print(f"gptsovits_enable_status 已更新为 {status_code}") except sqlite3.Error as e: print(f"更新数据库时发生错误: {e}") - except Exception as e: - print(f"启用gptsovits服务时发生错误: {e}") + finally: + conn.close() + +def check_port(port: int): + """检查端口是否被占用""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(('0.0.0.0', port)) # 尝试绑定端口 + return True # 端口可用 + except socket.error: + return False # 端口已占用 # 解析命令行参数 parser = argparse.ArgumentParser(description="GPT-SoVITS api") parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") -parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") +parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 127.0.0.1") parser.add_argument("-p", "--port", type=int, default=9880, help="default: 9880") args = parser.parse_args() config_path = args.tts_config +db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'live_chat.db') +db_path = os.path.normpath(db_path) + + # 按优先级获取端口 # 1. 从命令行参数获取 if args.port != 9880: @@ -254,7 +280,7 @@ if args.port != 9880: port_source = "命令行参数" # 2. 从数据库获取 else: - db_port = get_port_from_db() + db_port = get_port_from_db(db_path) if db_port is not None: port = db_port port_source = "数据库配置" @@ -628,17 +654,25 @@ async def set_sovits_weights(weights_path: str = None): if __name__ == "__main__": try: - # 先执行数据库更新操作 - enable_gptsovits_in_db() - + + set_status(db_path, 1) + if host == "None": # 在调用时使用 -a None 参数,可以让api监听双栈 host = None - + + port = get_port_from_db(db_path) + + if not check_port(port): + print(f"端口 {port} 已被占用,无法启动服务。") + set_status(db_path, 3) # 更新状态为启动失败 + exit(1) + set_status(db_path, 2) # 再启动服务器(这是一个阻塞调用) uvicorn.run(app=APP, host=host, port=port, workers=1) except Exception: traceback.print_exc() + set_status(db_path, 3) os.kill(os.getpid(), signal.SIGTERM) exit(0) \ No newline at end of file