###############################################################################
#  Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
#  email: lipku@foxmail.com
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
###############################################################################

# server.py
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
import base64
import json
#import gevent
#from gevent import pywsgi
#from geventwebsocket.handler import WebSocketHandler
import re
import numpy as np
from threading import Thread,Event
#import multiprocessing
import torch.multiprocessing as mp

from aiohttp import web
import aiohttp
import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceServer, RTCConfiguration
from aiortc.rtcrtpsender import RTCRtpSender
from webrtc import HumanPlayer
from basereal import BaseReal
from llm import llm_response

import argparse
import random
import shutil
import asyncio
import torch
from typing import Dict
from logger import logger
import gc
import json
# NOTE: this rebinds ``mp`` from torch.multiprocessing to the stdlib module.
import multiprocessing as mp
import sqlite3
import os

app = Flask(__name__)
#sockets = Sockets(app)

# Active sessions: session id -> renderer (None while the renderer is still being built).
nerfreals:Dict[int, BaseReal] = {}
# Populated by the ``__main__`` block: parsed options, loaded model and avatar.
opt = None
model = None
avatar = None

#####webrtc###############################
# Every live RTCPeerConnection, so shutdown can close them all.
pcs = set()


def randN(N) -> int:
    """Return a uniformly random integer with exactly ``N`` decimal digits.

    For example ``randN(6)`` is in ``[100000, 999999]``.

    Raises:
        ValueError: if ``N`` is less than 1 (no integer has zero digits).
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N!r}")
    # Integer powers keep both bounds ints; avoid shadowing the min/max builtins.
    lowest = 10 ** (N - 1)
    highest = 10 ** N - 1
    return random.randint(lowest, highest)


# Global database connection and config dict (filled by load_db_config).
conn = None
config_dict = {}


def _json_response(payload: dict) -> web.Response:
    """Serialize *payload* with ``json.dumps`` and wrap it in an ``application/json`` response."""
    return web.Response(content_type="application/json", text=json.dumps(payload))


def load_db_config(db_path):
    """Open the shared SQLite connection and cache the ``livetalking_config`` table.

    Sets the module globals ``conn`` (rows returned as :class:`sqlite3.Row`)
    and ``config_dict`` (every ``key -> value`` row of ``livetalking_config``).

    Raises:
        sqlite3.Error: if the database cannot be opened or the query fails
            (e.g. the table does not exist).
    """
    global conn, config_dict
    # check_same_thread=False so the connection may be reused across threads.
    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    # Tuning: WAL journal (readers are not blocked by writers), relaxed fsync, bigger page cache.
    conn.execute("PRAGMA journal_mode = WAL;")
    conn.execute("PRAGMA synchronous = NORMAL;")
    conn.execute("PRAGMA cache_size = 10000;")
    # Load every key/value pair once at startup.
    cursor = conn.execute("SELECT key, value FROM livetalking_config;")
    config_dict = {row["key"]: row["value"] for row in cursor.fetchall()}


def set_enable_status(db_path):
    """Mark the service as enabled: set ``value = '1'`` for its status key in ``live_config``.

    Uses a short-lived connection that is always closed, independent of the
    global ``conn``.
    """
    # NOTE(review): the key is spelled 'livetlking_...' and the table ('live_config')
    # differs from the one load_db_config reads ('livetalking_config'); confirm
    # both against the actual schema before renaming anything.
    db = sqlite3.connect(db_path, check_same_thread=False)
    try:
        db.execute(
            "UPDATE live_config SET value = '1' WHERE key = 'livetlking_enable_status';"
        )
        db.commit()
        print("livetlking_enable_status 对应的 value 字段改为 '1'")
    finally:
        db.close()


def build_nerfreal(sessionid: int) -> BaseReal:
    """Construct the avatar renderer for ``opt.model`` and session ``sessionid``.

    Records ``sessionid`` on the shared ``opt`` object before constructing.
    NOTE(review): ``opt`` is a module global and this runs in executor threads,
    so concurrent builds race on ``opt.sessionid``.

    Raises:
        ValueError: if ``opt.model`` is neither ``'wav2lip'`` nor ``'ultralight'``.
    """
    opt.sessionid = sessionid
    if opt.model == 'wav2lip':
        from lipreal import LipReal
        return LipReal(opt, model, avatar)
    if opt.model == 'ultralight':
        # The __main__ block loads model/avatar for this backend, so build it here too.
        from lightreal import LightReal
        return LightReal(opt, model, avatar)
    raise ValueError(f"unsupported model: {opt.model!r}")


async def _close_session(pc, sessionid: int) -> None:
    """Best-effort teardown of a half-built offer session: close *pc*, drop its bookkeeping."""
    try:
        await pc.close()
    except Exception:
        logger.exception("sessionid=%d, 关闭连接失败", sessionid)
    pcs.discard(pc)
    nerfreals.pop(sessionid, None)


#@app.route('/offer', methods=['POST'])
async def offer(request):
    """Handle ``POST /offer``: answer a browser's WebRTC SDP offer with a new avatar session.

    Expects a JSON object with ``sdp`` and ``type``. Builds a renderer via
    ``build_nerfreal`` under a fresh 6-digit session id and attaches its
    audio/video tracks to a new ``RTCPeerConnection``. Returns
    ``{"sdp", "type", "sessionid"}`` on success.

    On failure it returns ``{"code": <negative>, "msg": ...}`` and releases
    whatever was already allocated for the session.
    """
    logger.info("开始处理offer请求")

    # Parse the request body.
    try:
        params = await request.json()
        logger.debug("接收到的请求参数: %s", params)
    except Exception as e:
        logger.error("解析请求参数失败: %s", str(e))
        return _json_response({"code": -2, "msg": "解析参数失败"})

    # Validate required fields (a non-object JSON body counts as missing them).
    if not isinstance(params, dict) or "sdp" not in params or "type" not in params:
        logger.warning("请求参数缺少sdp或type字段: %s", params)
        return _json_response({"code": -3, "msg": "缺少必要参数"})

    try:
        remote_offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
        logger.debug("成功创建RTCSessionDescription,类型: %s", params["type"])
    except Exception as e:
        logger.error("创建RTCSessionDescription失败: %s", str(e))
        return _json_response({"code": -4, "msg": "创建offer失败"})

    # Draw ids until one is free so a collision can never clobber a live session.
    sessionid = randN(6)
    while sessionid in nerfreals:
        sessionid = randN(6)
    nerfreals[sessionid] = None  # reserve the id while the renderer is built
    logger.info('生成新会话,sessionid=%d, 当前会话数量=%d', sessionid, len(nerfreals))

    try:
        logger.debug("开始构建nerfreal,sessionid=%d", sessionid)
        # Run the build in the default executor so the event loop is not blocked meanwhile.
        nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal, sessionid)
        nerfreals[sessionid] = nerfreal
        logger.info("成功构建nerfreal,sessionid=%d", sessionid)
    except Exception as e:
        logger.error("构建nerfreal失败,sessionid=%d: %s", sessionid, str(e))
        nerfreals.pop(sessionid, None)
        return _json_response({"code": -5, "msg": "构建会话失败"})

    # Configure the ICE server and create the peer connection.
    try:
        ice_server = RTCIceServer(urls='stun:stun.miwifi.com:3478')
        logger.debug("使用ICE服务器: %s", ice_server.urls)
        pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=[ice_server]))
        pcs.add(pc)
        logger.info("创建RTCPeerConnection成功,sessionid=%d", sessionid)
    except Exception as e:
        logger.error("创建RTCPeerConnection失败,sessionid=%d: %s", sessionid, str(e))
        nerfreals.pop(sessionid, None)
        return _json_response({"code": -6, "msg": "创建连接失败"})

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("sessionid=%d, 连接状态变更为: %s", sessionid, pc.connectionState)
        if pc.connectionState == "failed":
            logger.warning("sessionid=%d, 连接失败,关闭连接", sessionid)
            await pc.close()
            pcs.discard(pc)
            nerfreals.pop(sessionid, None)
        # Not elif: after close() above the state is "closed" as well, and a separate
        # "closed" event also arrives -- pop(..., None) keeps the cleanup idempotent.
        if pc.connectionState == "closed":
            logger.info("sessionid=%d, 连接已关闭", sessionid)
            pcs.discard(pc)
            nerfreals.pop(sessionid, None)
            gc.collect()
            logger.debug("sessionid=%d, 已清理资源", sessionid)

    # Attach the renderer's media tracks (audio first, then video).
    try:
        player = HumanPlayer(nerfreal)
        pc.addTrack(player.audio)
        pc.addTrack(player.video)
        logger.debug("sessionid=%d, 已添加音频和视频轨道", sessionid)
    except Exception as e:
        logger.error("sessionid=%d, 添加媒体轨道失败: %s", sessionid, str(e))
        await _close_session(pc, sessionid)
        return _json_response({"code": -7, "msg": "添加媒体轨道失败"})

    # Video codec preference: all H264 entries, then VP8, then rtx.
    try:
        capabilities = RTCRtpSender.getCapabilities("video")
        preferences = [
            codec
            for wanted in ("H264", "VP8", "rtx")
            for codec in capabilities.codecs
            if codec.name == wanted
        ]
        logger.debug("sessionid=%d, 编解码器偏好: %s", sessionid, [p.name for p in preferences])
        # Index 1 is the video transceiver because the audio track was added first.
        transceiver = pc.getTransceivers()[1]
        transceiver.setCodecPreferences(preferences)
        logger.debug("sessionid=%d, 已设置编解码器偏好", sessionid)
    except Exception as e:
        # Non-fatal: negotiation proceeds without explicit codec preferences.
        logger.error("sessionid=%d, 配置编解码器失败: %s", sessionid, str(e))

    # Apply the remote offer and produce our answer.
    try:
        await pc.setRemoteDescription(remote_offer)
        logger.debug("sessionid=%d, 已设置远程描述", sessionid)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        logger.debug("sessionid=%d, 已创建并设置本地answer", sessionid)
    except Exception as e:
        logger.error("sessionid=%d, 处理offer创建answer失败: %s", sessionid, str(e))
        await _close_session(pc, sessionid)
        return _json_response({"code": -8, "msg": "处理offer失败"})

    response_data = {
        "sdp": pc.localDescription.sdp,
        "type": pc.localDescription.type,
        "sessionid": sessionid,
    }
    logger.info("sessionid=%d, offer请求处理完成,返回响应", sessionid)
    return _json_response(response_data)


async def human(request):
    """Handle ``POST /human``: send text to a session.

    Body: ``sessionid`` (default 0), ``type`` (``'echo'`` speaks ``text``
    verbatim; ``'chat'`` hands ``text`` to ``llm_response`` in an executor),
    and optional ``interrupt`` (flush current speech first).
    Returns ``{"code": 0}`` or ``{"code": -1, "msg": <error>}``.
    """
    try:
        params = await request.json()
        sessionid = int(params.get('sessionid', 0))
        if params.get('interrupt'):
            nerfreals[sessionid].flush_talk()

        if params['type'] == 'echo':
            nerfreals[sessionid].speaking = True
            nerfreals[sessionid].put_msg_txt(params['text'])
        elif params['type'] == 'chat':
            # Fire-and-forget: llm_response runs off the event loop and is not awaited.
            asyncio.get_running_loop().run_in_executor(None, llm_response, params['text'], nerfreals[sessionid])

        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def interrupt_talk(request):
    """Handle ``POST /interrupt_talk``: call ``flush_talk()`` on the session's renderer."""
    try:
        params = await request.json()
        sessionid = int(params.get('sessionid', 0))
        nerfreals[sessionid].flush_talk()
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def humanaudio(request):
    """Handle ``POST /humanaudio`` (multipart): pass the uploaded ``file`` bytes to ``put_audio_file``."""
    try:
        form = await request.post()
        sessionid = int(form.get('sessionid', 0))
        fileobj = form["file"]
        nerfreals[sessionid].put_audio_file(fileobj.file.read())
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def set_audiotype(request):
    """Handle ``POST /set_audiotype``: forward ``audiotype`` and ``reinit`` to ``set_custom_state``."""
    try:
        params = await request.json()
        sessionid = int(params.get('sessionid', 0))
        nerfreals[sessionid].set_custom_state(params['audiotype'], params['reinit'])
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def record(request):
    """Handle ``POST /record``: ``type`` ``'start_record'`` / ``'end_record'`` toggles recording."""
    try:
        params = await request.json()
        sessionid = int(params.get('sessionid', 0))
        if params['type'] == 'start_record':
            nerfreals[sessionid].start_recording()
        elif params['type'] == 'end_record':
            nerfreals[sessionid].stop_recording()
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def is_speaking(request):
    """Handle ``POST /is_speaking``: return ``{"code": 0, "data": <is_speaking()>}``.

    Errors (bad body, unknown session) return ``{"code": -1, "msg": ...}``
    like the other handlers instead of an HTTP 500.
    """
    try:
        params = await request.json()
        sessionid = int(params.get('sessionid', 0))
        return _json_response({"code": 0, "data": nerfreals[sessionid].is_speaking()})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def on_shutdown(app):
    """aiohttp shutdown hook: close every peer connection, then the shared DB connection."""
    global conn
    # The tuple of close() coroutines is built before any of them runs, so
    # handlers discarding from ``pcs`` cannot disturb this iteration.
    await asyncio.gather(*[pc.close() for pc in pcs])
    pcs.clear()

    # Close the database connection (and forget it so a second call is a no-op).
    if conn is not None:
        conn.close()
        conn = None
        print("[INFO] 已关闭数据库连接")


async def post(url, data):
    """POST *data* to *url* and return the response body text, or ``None`` on a client error."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                return await response.text()
    except aiohttp.ClientError as e:
        logger.error('Error: %s', e)
        return None


async def run(push_url, sessionid):
    """Push session ``sessionid`` to a WHIP-style endpoint at ``push_url`` (``rtcpush`` transport).

    Builds the renderer, creates an SDP offer, POSTs it to ``push_url`` and
    applies the returned body as the SDP answer.

    Raises:
        RuntimeError: if no answer could be fetched from ``push_url``.
    """
    nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal, sessionid)
    nerfreals[sessionid] = nerfreal

    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s", pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreal)
    pc.addTrack(player.audio)
    pc.addTrack(player.video)

    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url, pc.localDescription.sdp)
    if answer is None:
        raise RuntimeError(f"no SDP answer received from {push_url}")
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))


##########################################
if __name__ == '__main__':
    try:
        # The SQLite database sits two directories above this file.
        db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'live_chat.db')
        load_db_config(db_path)

        mp.set_start_method('spawn')
        parser = argparse.ArgumentParser()
        parser.add_argument('--config', type=str, default='config/config.json', help="配置文件路径")
        # audio FPS
        parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
        # sliding window left-middle-right length (unit: 20ms)
        parser.add_argument('-l', type=int, default=10)
        parser.add_argument('-m', type=int, default=8)
        parser.add_argument('-r', type=int, default=10)

        parser.add_argument('--W', type=int, default=450, help="GUI width")
        parser.add_argument('--H', type=int, default=450, help="GUI height")

        # musetalk opt
        parser.add_argument('--avatar_id', type=str, default='model_4_1', help="define which avatar in data/avatars")
        parser.add_argument('--batch_size', type=int, default=16, help="infer batch")

        parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")

        parser.add_argument('--tts', type=str, default='gpt-sovits', help="tts service type")  # xtts gpt-sovits cosyvoice
        parser.add_argument('--REF_FILE', type=str, default="input/gentle_girl.wav")
        parser.add_argument('--REF_TEXT', type=str, default="刚进直播间的宝子们,左上角先点个关注,点亮咱们家的粉丝灯牌!我是你们的主播陈婉婉,今天给大家准备了超级重磅的福利")
        parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880')  # http://localhost:9000

        parser.add_argument('--model', type=str, default='wav2lip')  # musetalk wav2lip ultralight

        parser.add_argument('--transport', type=str, default='webrtc')  # webrtc rtcpush virtualcam
        parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream')  # rtmp://localhost/live/livestream

        parser.add_argument('--max_session', type=int, default=1)  # multi session count
        parser.add_argument('--listenport', type=int, default=8010, help="web listen port")

        opt = parser.parse_args()

        # Effective precedence: command line > database > JSON config file > parser default.
        # "Not given on the command line" is detected as "still equal to the default",
        # so explicitly passing the default value is indistinguishable from omitting it.

        # Database key -> argparse destination name.
        name_map = {
            'listen_port': 'listenport',
            'avatar_id': 'avatar_id',
            'ref_file': 'REF_FILE',
            'ref_text': 'REF_TEXT',
            'tts_server': 'TTS_SERVER',
        }
        for db_key, db_val in config_dict.items():
            arg_name = name_map.get(db_key)
            # Skip unrelated keys and NULL values (str(None) would become 'None').
            if arg_name is None or db_val is None:
                continue
            default = parser.get_default(arg_name)
            if getattr(opt, arg_name) != default:
                continue  # explicitly set on the command line
            try:
                # Coerce to the option's type (e.g. '8010' -> 8010 for listenport).
                setattr(opt, arg_name, type(default)(db_val))
            except (TypeError, ValueError):
                setattr(opt, arg_name, db_val)

        try:
            with open(opt.config, 'r', encoding='utf-8') as f:
                cfg = json.load(f)
            for key, val in cfg.items():
                # Only fill options still at their parser default. Keys the parser does
                # not know compare None == None and are therefore added to opt as well.
                if getattr(opt, key, None) == parser.get_default(key):
                    setattr(opt, key, val)
        except FileNotFoundError:
            logger.warning(f"配置文件未找到:{opt.config},将全部使用命令行/默认参数")
        except Exception as e:
            logger.warning(f"加载配置文件时出错:{e},将全部使用命令行/默认参数")

        opt.customopt = []
        if opt.customvideo_config != '':
            with open(opt.customvideo_config, 'r', encoding='utf-8') as file:
                opt.customopt = json.load(file)

        # Load the model/avatar once at module level; build_nerfreal reads these globals.
        if opt.model == 'wav2lip':
            from lipreal import LipReal, load_model, load_avatar, warm_up
            logger.info(opt)
            model = load_model("./models/wav2lip384.pth")
            avatar = load_avatar(opt.avatar_id)
            warm_up(opt.batch_size, model, 384)
        elif opt.model == 'ultralight':
            from lightreal import LightReal, load_model, load_avatar, warm_up
            logger.info(opt)
            model = load_model(opt)
            avatar = load_avatar(opt.avatar_id)
            warm_up(opt.batch_size, avatar, 160)

        if opt.transport == 'virtualcam':
            # A single fixed session 0, rendered on a background thread.
            thread_quit = Event()
            nerfreals[0] = build_nerfreal(0)
            rendthrd = Thread(target=nerfreals[0].render, args=(thread_quit,))
            rendthrd.start()

        #############################################################################
        appasync = web.Application(client_max_size=1024**2*100)
        appasync.on_shutdown.append(on_shutdown)
        appasync.router.add_post("/offer", offer)
        appasync.router.add_post("/human", human)
        appasync.router.add_post("/humanaudio", humanaudio)
        appasync.router.add_post("/set_audiotype", set_audiotype)
        appasync.router.add_post("/record", record)
        appasync.router.add_post("/interrupt_talk", interrupt_talk)
        appasync.router.add_post("/is_speaking", is_speaking)
        appasync.router.add_static('/', path='web')

        # Configure default CORS settings.
        cors = aiohttp_cors.setup(appasync, defaults={
            "*": aiohttp_cors.ResourceOptions(
                allow_credentials=True,
                expose_headers="*",
                allow_headers="*",
            )
        })
        # Configure CORS on all routes.
        for route in list(appasync.router.routes()):
            cors.add(route)

        pagename = 'webrtcapi.html'
        if opt.transport == 'rtmp':
            pagename = 'echoapi.html'
        elif opt.transport == 'rtcpush':
            pagename = 'rtcpushapi.html'
        logger.info('start http server; http://:'+str(opt.listenport)+'/'+pagename)
        logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://:'+str(opt.listenport)+'/dashboard.html')

        # Service is starting: set enable_status to '1' in the database.
        set_enable_status(db_path)

        def run_server(runner):
            """Start the aiohttp site on a fresh event loop and serve forever."""
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(runner.setup())
            site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
            loop.run_until_complete(site.start())
            if opt.transport == 'rtcpush':
                # Session k pushes to push_url + str(k) (session 0 uses push_url as-is).
                for k in range(opt.max_session):
                    push_url = opt.push_url
                    if k != 0:
                        push_url = opt.push_url + str(k)
                    loop.run_until_complete(run(push_url, k))
            loop.run_forever()

        run_server(web.AppRunner(appasync))
    except Exception:
        import traceback
        traceback.print_exc()  # print the full stack trace
        input("发生异常,按回车键退出…")  # wait for Enter before exiting