diff --git a/app.py b/app.py index bea7992..6d8afc0 100644 --- a/app.py +++ b/app.py @@ -16,17 +16,17 @@ ############################################################################### # server.py -from flask import Flask, render_template,send_from_directory,request, jsonify +from flask import Flask, render_template, send_from_directory, request, jsonify from flask_sockets import Sockets import base64 import json -#import gevent -#from gevent import pywsgi -#from geventwebsocket.handler import WebSocketHandler +# import gevent +# from gevent import pywsgi +# from geventwebsocket.handler import WebSocketHandler import re import numpy as np -from threading import Thread,Event -#import multiprocessing +from threading import Thread, Event +# import multiprocessing import torch.multiprocessing as mp from aiohttp import web @@ -52,8 +52,8 @@ import sqlite3 import os app = Flask(__name__) -#sockets = Sockets(app) -nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal +# sockets = Sockets(app) +nerfreals: Dict[int, BaseReal] = {} # sessionid:BaseReal opt = None model = None avatar = None @@ -61,7 +61,8 @@ avatar = None #####webrtc############################### pcs = set() -def randN(N)->int: + +def randN(N) -> int: '''生成长度为 N的随机数 ''' min = pow(10, N - 1) max = pow(10, N) @@ -101,172 +102,116 @@ def set_enable_status(db_path): finally: conn.close() -def build_nerfreal(sessionid:int)->BaseReal: - opt.sessionid=sessionid - if opt.model == 'wav2lip': - from lipreal import LipReal - nerfreal = LipReal(opt,model,avatar) - return nerfreal - -#@app.route('/offer', methods=['POST']) -async def offer(request): - logger.info("开始处理offer请求") - # 接收请求参数 +# 用于更新数据库中的 sessionid +def update_sessionid_in_db(sessionid): try: - params = await request.json() - logger.debug("接收到的请求参数: %s", params) + # 打开数据库连接 + conn = sqlite3.connect(db_path, check_same_thread=False) + cursor = conn.cursor() + + # 更新 livetalking_sessionid 字段 + cursor.execute("UPDATE live_config SET value = ? WHERE key = 'livetalking_sessionid';", (str(sessionid),)) + conn.commit() + logger.info(f"Successfully updated livetalking_sessionid to {sessionid}") except Exception as e: - logger.error("解析请求参数失败: %s", str(e)) - return web.Response( - content_type="application/json", - text=json.dumps({"code": -2, "msg": "解析参数失败"}) - ) + logger.exception("Error updating livetalking_sessionid in database:") + finally: + if conn: + conn.close() - # 验证必要参数 - if "sdp" not in params or "type" not in params: - logger.warning("请求参数缺少sdp或type字段: %s", params) - return web.Response( - content_type="application/json", - text=json.dumps({"code": -3, "msg": "缺少必要参数"}) - ) - # 创建RTCSessionDescription - try: - offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) - logger.debug("成功创建RTCSessionDescription,类型: %s", params["type"]) - except Exception as e: - logger.error("创建RTCSessionDescription失败: %s", str(e)) - return web.Response( - content_type="application/json", - text=json.dumps({"code": -4, "msg": "创建offer失败"}) - ) +def build_nerfreal(sessionid: int) -> BaseReal: + opt.sessionid = sessionid + if opt.model == 'wav2lip': + from lipreal import LipReal + nerfreal = LipReal(opt, model, avatar) + return nerfreal - # 生成会话ID并记录 - sessionid = randN(6) - nerfreals[sessionid] = None - logger.info('生成新会话,sessionid=%d, 当前会话数量=%d', sessionid, len(nerfreals)) - # 构建nerfreal - try: - logger.debug("开始构建nerfreal,sessionid=%d", sessionid) - nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal, sessionid) - nerfreals[sessionid] = nerfreal - logger.info("成功构建nerfreal,sessionid=%d", sessionid) - except Exception as e: - logger.error("构建nerfreal失败,sessionid=%d: %s", sessionid, str(e)) - del nerfreals[sessionid] - return web.Response( - content_type="application/json", - text=json.dumps({"code": -5, "msg": "构建会话失败"}) - ) +# @app.route('/offer', methods=['POST']) +async def offer(request): + params = await request.json() + offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) + + # if len(nerfreals) >= opt.max_session: + # logger.info('reach max session') + # return web.Response( + # content_type="application/json", + # text=json.dumps( + # {"code": -1, "msg": "reach max session"} + # ), + # ) + sessionid = randN(6) # len(nerfreals) + nerfreals[sessionid] = None + logger.info('sessionid=%d, session num=%d', sessionid, len(nerfreals)) + update_sessionid_in_db(sessionid) - # 配置ICE服务器并创建RTCPeerConnection - try: - ice_server = RTCIceServer(urls='stun:stun.miwifi.com:3478') - logger.debug("使用ICE服务器: %s", ice_server.urls) + nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal, sessionid) + nerfreals[sessionid] = nerfreal - pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=[ice_server])) - pcs.add(pc) - logger.info("创建RTCPeerConnection成功,sessionid=%d", sessionid) - except Exception as e: - logger.error("创建RTCPeerConnection失败,sessionid=%d: %s", sessionid, str(e)) - del nerfreals[sessionid] - return web.Response( - content_type="application/json", - text=json.dumps({"code": -6, "msg": "创建连接失败"}) - ) + # ice_server = RTCIceServer(urls='stun:stun.l.google.com:19302') + ice_server = RTCIceServer(urls='stun:stun.miwifi.com:3478') + pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=[ice_server])) + pcs.add(pc) - # 连接状态变化处理 @pc.on("connectionstatechange") async def on_connectionstatechange(): - logger.info("sessionid=%d, 连接状态变更为: %s", sessionid, pc.connectionState) + logger.info("Connection state is %s" % pc.connectionState) if pc.connectionState == "failed": - logger.warning("sessionid=%d, 连接失败,关闭连接", sessionid) await pc.close() pcs.discard(pc) del nerfreals[sessionid] if pc.connectionState == "closed": - logger.info("sessionid=%d, 连接已关闭", sessionid) pcs.discard(pc) del nerfreals[sessionid] gc.collect() - logger.debug("sessionid=%d, 已清理资源", sessionid) - # 添加媒体轨道 - try: - player = HumanPlayer(nerfreals[sessionid]) - audio_sender = pc.addTrack(player.audio) - video_sender = pc.addTrack(player.video) - logger.debug("sessionid=%d, 已添加音频和视频轨道", sessionid) - except Exception as e: - logger.error("sessionid=%d, 添加媒体轨道失败: %s", sessionid, str(e)) - return web.Response( - content_type="application/json", - text=json.dumps({"code": -7, "msg": "添加媒体轨道失败"}) - ) + player = HumanPlayer(nerfreals[sessionid]) + audio_sender = pc.addTrack(player.audio) + video_sender = pc.addTrack(player.video) + capabilities = RTCRtpSender.getCapabilities("video") + preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs)) + preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs)) + preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs)) + transceiver = pc.getTransceivers()[1] + transceiver.setCodecPreferences(preferences) - # 配置编解码器偏好 - try: - capabilities = RTCRtpSender.getCapabilities("video") - preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs)) - preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs)) - preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs)) - logger.debug("sessionid=%d, 编解码器偏好: %s", sessionid, [p.name for p in preferences]) - - transceiver = pc.getTransceivers()[1] - transceiver.setCodecPreferences(preferences) - logger.debug("sessionid=%d, 已设置编解码器偏好", sessionid) - except Exception as e: - logger.error("sessionid=%d, 配置编解码器失败: %s", sessionid, str(e)) + await pc.setRemoteDescription(offer) - # 处理offer并创建answer - try: - await pc.setRemoteDescription(offer) - logger.debug("sessionid=%d, 已设置远程描述", sessionid) + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - logger.debug("sessionid=%d, 已创建并设置本地answer", sessionid) - except Exception as e: - logger.error("sessionid=%d, 处理offer创建answer失败: %s", sessionid, str(e)) - return web.Response( - content_type="application/json", - text=json.dumps({"code": -8, "msg": "处理offer失败"}) - ) + # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) - # 返回结果 - response_data = { - "sdp": pc.localDescription.sdp, - "type": pc.localDescription.type, - "sessionid": sessionid - } - logger.info("sessionid=%d, offer请求处理完成,返回响应", sessionid) return web.Response( content_type="application/json", - text=json.dumps(response_data) + text=json.dumps( + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid": sessionid} + ), ) + async def human(request): try: params = await request.json() - sessionid = params.get('sessionid',0) + sessionid = params.get('sessionid', 0) if params.get('interrupt'): nerfreals[sessionid].flush_talk() - if params['type']=='echo': + if params['type'] == 'echo': nerfreals[sessionid].speaking = True nerfreals[sessionid].put_msg_txt(params['text']) - elif params['type']=='chat': - asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'],nerfreals[sessionid]) - #nerfreals[sessionid].put_msg_txt(res) + elif params['type'] == 'chat': + asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'], nerfreals[sessionid]) + # nerfreals[sessionid].put_msg_txt(res) return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "msg":"ok"} + {"code": 0, "msg": "ok"} ), ) except Exception as e: @@ -278,17 +223,18 @@ async def human(request): ), ) + async def interrupt_talk(request): try: params = await request.json() - sessionid = params.get('sessionid',0) + sessionid = params.get('sessionid', 0) nerfreals[sessionid].flush_talk() - + return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "msg":"ok"} + {"code": 0, "msg": "ok"} ), ) except Exception as e: @@ -300,19 +246,20 @@ async def interrupt_talk(request): ), ) + async def humanaudio(request): try: - form= await request.post() - sessionid = int(form.get('sessionid',0)) + form = await request.post() + sessionid = int(form.get('sessionid', 0)) fileobj = form["file"] - filename=fileobj.filename - filebytes=fileobj.file.read() + filename = fileobj.filename + filebytes = fileobj.file.read() nerfreals[sessionid].put_audio_file(filebytes) return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "msg":"ok"} + {"code": 0, "msg": "ok"} ), ) except Exception as e: @@ -324,17 +271,18 @@ async def humanaudio(request): ), ) + async def set_audiotype(request): try: params = await request.json() - sessionid = params.get('sessionid',0) - nerfreals[sessionid].set_custom_state(params['audiotype'],params['reinit']) + sessionid = params.get('sessionid', 0) + nerfreals[sessionid].set_custom_state(params['audiotype'], params['reinit']) return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "msg":"ok"} + {"code": 0, "msg": "ok"} ), ) except Exception as e: @@ -346,20 +294,21 @@ async def set_audiotype(request): ), ) + async def record(request): try: params = await request.json() - sessionid = params.get('sessionid',0) - if params['type']=='start_record': + sessionid = params.get('sessionid', 0) + if params['type'] == 'start_record': # nerfreals[sessionid].put_msg_txt(params['text']) nerfreals[sessionid].start_recording() - elif params['type']=='end_record': + elif params['type'] == 'end_record': nerfreals[sessionid].stop_recording() return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "msg":"ok"} + {"code": 0, "msg": "ok"} ), ) except Exception as e: @@ -371,10 +320,11 @@ async def record(request): ), ) + async def is_speaking(request): params = await request.json() - sessionid = params.get('sessionid',0) + sessionid = params.get('sessionid', 0) return web.Response( content_type="application/json", text=json.dumps( @@ -395,16 +345,17 @@ async def on_shutdown(app): print("[INFO] 已关闭数据库连接") -async def post(url,data): +async def post(url, data): try: async with aiohttp.ClientSession() as session: - async with session.post(url,data=data) as response: + async with session.post(url, data=data) as response: return await response.text() except aiohttp.ClientError as e: logger.info(f'Error: {e}') -async def run(push_url,sessionid): - nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid) + +async def run(push_url, sessionid): + nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal, sessionid) nerfreals[sessionid] = nerfreal pc = RTCPeerConnection() @@ -422,10 +373,12 @@ async def run(push_url,sessionid): video_sender = pc.addTrack(player.video) await pc.setLocalDescription(await pc.createOffer()) - answer = await post(push_url,pc.localDescription.sdp) - await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer')) + answer = await post(push_url, pc.localDescription.sdp) + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer')) + + ########################################## - + if __name__ == '__main__': try: db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'live_chat.db') @@ -434,7 +387,7 @@ if __name__ == '__main__': mp.set_start_method('spawn') parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, default='config/config.json', help="配置文件路径") + parser.add_argument('--config', type=str, default='../config.json', help="配置文件路径") # audio FPS parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50") @@ -446,9 +399,9 @@ if __name__ == '__main__': parser.add_argument('--W', type=int, default=450, help="GUI width") parser.add_argument('--H', type=int, default=450, help="GUI height") - #musetalk opt + # musetalk opt parser.add_argument('--avatar_id', type=str, default='model_4_1', help="define which avatar in data/avatars") - #parser.add_argument('--bbox_shift', type=int, default=5) + # parser.add_argument('--bbox_shift', type=int, default=5) parser.add_argument('--batch_size', type=int, default=16, help="infer batch") parser.add_argument('--customvideo_config', type=str, default='', help="custom action json") @@ -460,17 +413,17 @@ if __name__ == '__main__': # parser.add_argument('--CHARACTER', type=str, default='test') # parser.add_argument('--EMOTION', type=str, default='default') - parser.add_argument('--model', type=str, default='wav2lip') #musetalk wav2lip ultralight + parser.add_argument('--model', type=str, default='wav2lip') # musetalk wav2lip ultralight parser.add_argument('--transport', type=str, default='webrtc') #webrtc rtcpush virtualcam parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream - parser.add_argument('--max_session', type=int, default=1) #multi session count + parser.add_argument('--max_session', type=int, default=1) # multi session count parser.add_argument('--listenport', type=int, default=8010, help="web listen port") opt = parser.parse_args() - #app.config.from_object(opt) - #print(app.config) + # app.config.from_object(opt) + # print(app.config) # —— 新增:数据库字段名到 opt 属性名的映射 —— name_map = { 'listen_port': 'listenport', @@ -505,35 +458,36 @@ if __name__ == '__main__': logger.warning(f"加载配置文件时出错:{e},将全部使用命令行/默认参数") opt.customopt = [] - if opt.customvideo_config!='': - with open(opt.customvideo_config,'r') as file: + if opt.customvideo_config != '': + with open(opt.customvideo_config, 'r') as file: opt.customopt = json.load(file) - # if opt.model == 'ernerf': + # if opt.model == 'ernerf': # from nerfreal import NeRFReal,load_model,load_avatar # model = load_model(opt) - # avatar = load_avatar(opt) + # avatar = load_avatar(opt) if opt.model == 'wav2lip': from lipreal import LipReal,load_model,load_avatar,warm_up logger.info(opt) model = load_model("./models/wav2lip384.pth") avatar = load_avatar(opt.avatar_id) - warm_up(opt.batch_size,model,384) + warm_up(opt.batch_size, model, 384) elif opt.model == 'ultralight': - from lightreal import LightReal,load_model,load_avatar,warm_up + from lightreal import LightReal, load_model, load_avatar, warm_up + logger.info(opt) model = load_model(opt) avatar = load_avatar(opt.avatar_id) - warm_up(opt.batch_size,avatar,160) + warm_up(opt.batch_size, avatar, 160) - if opt.transport=='virtualcam': + if opt.transport == 'virtualcam': thread_quit = Event() nerfreals[0] = build_nerfreal(0) - rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,)) + rendthrd = Thread(target=nerfreals[0].render, args=(thread_quit,)) rendthrd.start() ############################################################################# - appasync = web.Application(client_max_size=1024**2*100) + appasync = web.Application(client_max_size=1024 ** 2 * 100) appasync.on_shutdown.append(on_shutdown) appasync.router.add_post("/offer", offer) appasync.router.add_post("/human", human) @@ -542,16 +496,16 @@ if __name__ == '__main__': appasync.router.add_post("/record", record) appasync.router.add_post("/interrupt_talk", interrupt_talk) appasync.router.add_post("/is_speaking", is_speaking) - appasync.router.add_static('/',path='web') + appasync.router.add_static('/', path='web') # Configure default CORS settings. cors = aiohttp_cors.setup(appasync, defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - ) - }) + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + ) + }) # Configure CORS on all routes. for route in list(appasync.router.routes()): cors.add(route) @@ -565,21 +519,26 @@ if __name__ == '__main__': logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://:'+str(opt.listenport)+'/dashboard.html') # 服务开启,将数据库中enable_status 对应的 value 字段改为 '1' set_enable_status(db_path) + + def run_server(runner): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(runner.setup()) site = web.TCPSite(runner, '0.0.0.0', opt.listenport) loop.run_until_complete(site.start()) - if opt.transport=='rtcpush': + if opt.transport == 'rtcpush': for k in range(opt.max_session): push_url = opt.push_url - if k!=0: - push_url = opt.push_url+str(k) - loop.run_until_complete(run(push_url,k)) + if k != 0: + push_url = opt.push_url + str(k) + loop.run_until_complete(run(push_url, k)) loop.run_forever() + + run_server(web.AppRunner(appasync)) except Exception: import traceback - traceback.print_exc() # 打印完整的错误堆栈 + + traceback.print_exc() # 打印完整的错误堆栈 input("发生异常,按回车键退出…") # 等待用户按回车再退出 \ No newline at end of file