###############################################################################
#  Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
#  email: lipku@foxmail.com
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
###############################################################################

# server.py
from flask import Flask, render_template, send_from_directory, request, jsonify
from flask_sockets import Sockets
import base64
import json
# import gevent
# from gevent import pywsgi
# from geventwebsocket.handler import WebSocketHandler
import re
import numpy as np
from threading import Thread, Event
# import multiprocessing
import torch.multiprocessing as mp

from aiohttp import web
import aiohttp
import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceServer, RTCConfiguration
from aiortc.rtcrtpsender import RTCRtpSender
from webrtc import HumanPlayer
from basereal import BaseReal
from llm import llm_response

import argparse
import random
import shutil
import asyncio
import torch
from typing import Dict
from logger import logger
import gc
import json
# NOTE: this rebinds ``mp`` (imported above from torch.multiprocessing) to the
# standard-library module; every later ``mp.`` reference uses this one.
import multiprocessing as mp
import sqlite3
import os
import socket

app = Flask(__name__)
# sockets = Sockets(app)

# Live sessions: sessionid -> renderer.  ``offer`` stores ``None`` as a
# placeholder while a session's renderer is still being built.
nerfreals: Dict[int, BaseReal] = {}

# Populated by the ``__main__`` section at startup.
opt = None
model = None
avatar = None

#####webrtc###############################
# Active RTCPeerConnection objects; closed by the application's shutdown hook.
pcs = set()


def randN(N) -> int:
    """Return a random integer with exactly ``N`` decimal digits.

    For example ``randN(6)`` (used to mint WebRTC session ids) returns a value
    in ``[100000, 999999]``.

    Raises:
        ValueError: if ``N`` is smaller than 1 (no integer has zero digits).
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    # Named bounds instead of ``min``/``max`` so the builtins aren't shadowed.
    lower = 10 ** (N - 1)
    upper = 10 ** N - 1
    return random.randint(lower, upper)
# Global database connection and config dictionary (filled by load_db_config).
conn = None
config_dict = {}


def load_db_config(db_path):
    """Open the shared SQLite connection and load every config key/value pair.

    Sets the module globals ``conn`` (closed in ``on_shutdown``) and
    ``config_dict`` (``key -> value`` rows of the ``livetalking_config`` table).
    """
    global conn, config_dict
    # check_same_thread=False lets the connection be reused across threads.
    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    # Tuning: WAL journal so reads are not blocked, relaxed sync, bigger cache.
    conn.execute("PRAGMA journal_mode = WAL;")
    conn.execute("PRAGMA synchronous = NORMAL;")
    conn.execute("PRAGMA cache_size = 10000;")
    # Load all key/value pairs in a single query.
    cursor = conn.execute("SELECT key, value FROM livetalking_config;")
    config_dict = {row["key"]: row["value"] for row in cursor.fetchall()}


def set_status(db_path, status_code: int):
    """Write ``status_code`` into ``live_config`` row ``livetlking_enable_status``.

    Status codes:
        0 - not started (written by the frontend)
        1 - starting
        2 - started successfully
        3 - failed to start

    Opens its own short-lived connection, so it can be called before
    ``load_db_config`` and from the failure path.  The key's spelling
    ("livetlking") is the literal row key and must match the database.
    """
    status_conn = sqlite3.connect(db_path, check_same_thread=False)
    try:
        status_conn.execute(
            "UPDATE live_config SET value = ? WHERE key = 'livetlking_enable_status';",
            (str(status_code),),
        )
        status_conn.commit()
        print(f"livetlking_enable_status 已更新为 {status_code}")
    finally:
        status_conn.close()


def update_sessionid_in_db(sessionid):
    """Persist ``sessionid`` into ``live_config`` row ``livetalking_sessionid``.

    Best-effort: any error is logged and swallowed so a database problem does
    not abort WebRTC negotiation.  Reads the module-level ``db_path`` that the
    ``__main__`` section assigns.
    """
    # Bound before ``try`` so ``finally`` never touches an unassigned name when
    # ``connect`` itself fails.
    db_conn = None
    try:
        db_conn = sqlite3.connect(db_path, check_same_thread=False)
        db_conn.execute(
            "UPDATE live_config SET value = ? WHERE key = 'livetalking_sessionid';",
            (str(sessionid),),
        )
        db_conn.commit()
        logger.info(f"Successfully updated livetalking_sessionid to {sessionid}")
    except Exception:
        logger.exception("Error updating livetalking_sessionid in database:")
    finally:
        if db_conn is not None:
            db_conn.close()


def check_port(port: int):
    """Return True if ``port`` can be bound on 0.0.0.0 (i.e. it is free), else False."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(('0.0.0.0', port))
            return True
        except OSError:  # socket.error is an alias of OSError
            return False


def build_nerfreal(sessionid: int) -> "BaseReal":
    """Create the renderer for one session from the globally loaded model/avatar.

    Stores ``sessionid`` on the shared ``opt`` namespace, then instantiates
    the renderer class matching ``opt.model``.

    Raises:
        ValueError: if ``opt.model`` is not a supported model name.
    """
    opt.sessionid = sessionid
    if opt.model == 'wav2lip':
        from lipreal import LipReal
        return LipReal(opt, model, avatar)
    if opt.model == 'ultralight':
        # The __main__ loader accepts 'ultralight', so sessions must be
        # buildable for it too (previously this fell through unbound).
        from lightreal import LightReal
        return LightReal(opt, model, avatar)
    raise ValueError(f"unsupported model: {opt.model}")


def _json_response(payload):
    """Wrap ``payload`` in an ``application/json`` aiohttp response."""
    return web.Response(content_type="application/json", text=json.dumps(payload))


def _session_id(params):
    """Return ``params['sessionid']`` as an int (0 when absent).

    Clients may send the id as a string; ``humanaudio`` always cast its form
    value, and the JSON handlers now do the same.
    """
    return int(params.get('sessionid', 0))


async def offer(request):
    """Negotiate a new WebRTC session from the browser's SDP offer.

    Request JSON: ``{"sdp": ..., "type": ...}``.  Allocates a fresh 6-digit
    session id, builds its renderer in a worker thread, attaches the
    renderer's audio/video tracks and responds with
    ``{"sdp", "type", "sessionid"}`` describing the local answer.
    """
    params = await request.json()
    remote_offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    # A max-session limit (opt.max_session, answering {"code": -1,
    # "msg": "reach max session"}) is currently disabled.

    sessionid = randN(6)
    # Re-draw on the (unlikely) collision with a live session.
    while sessionid in nerfreals:
        sessionid = randN(6)
    nerfreals[sessionid] = None  # reserve the id while the renderer is built
    logger.info('sessionid=%d, session num=%d', sessionid, len(nerfreals))
    update_sessionid_in_db(sessionid)
    try:
        nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal, sessionid)
    except Exception:
        # Don't leave a dangling None placeholder behind on a failed build.
        nerfreals.pop(sessionid, None)
        raise
    nerfreals[sessionid] = nerfreal

    # ice_server = RTCIceServer(urls='stun:stun.l.google.com:19302')
    ice_server = RTCIceServer(urls='stun:stun.miwifi.com:3478')
    pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=[ice_server]))
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s" % pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)
            nerfreals.pop(sessionid, None)
        # Deliberately not ``elif``: after ``pc.close()`` above the state reads
        # "closed" (and aiortc may also re-fire this handler for the close), so
        # cleanup can run more than once -- ``pop(..., None)`` keeps it safe.
        if pc.connectionState == "closed":
            pcs.discard(pc)
            nerfreals.pop(sessionid, None)
            gc.collect()

    player = HumanPlayer(nerfreal)
    pc.addTrack(player.audio)
    pc.addTrack(player.video)

    # Video codec preference: H264 first, then VP8, plus rtx.
    codecs = RTCRtpSender.getCapabilities("video").codecs
    preferences = [c for c in codecs if c.name == "H264"]
    preferences += [c for c in codecs if c.name == "VP8"]
    preferences += [c for c in codecs if c.name == "rtx"]
    # Index 1 is the video transceiver (tracks were added audio, then video).
    transceiver = pc.getTransceivers()[1]
    transceiver.setCodecPreferences(preferences)

    await pc.setRemoteDescription(remote_offer)
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)

    return _json_response(
        {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid": sessionid}
    )


async def human(request):
    """Make a session's avatar speak.

    Request JSON: ``sessionid``; ``type`` -- 'echo' speaks ``text`` verbatim,
    'chat' hands ``text`` to ``llm_response`` in a worker thread; optional
    ``interrupt`` flushes current speech first.
    """
    try:
        params = await request.json()
        sessionid = _session_id(params)
        if params.get('interrupt'):
            nerfreals[sessionid].flush_talk()

        if params['type'] == 'echo':
            nerfreals[sessionid].liv_speaking = True
            nerfreals[sessionid].put_msg_txt(params['text'])
        elif params['type'] == 'chat':
            # Fire-and-forget: the LLM call runs in the default executor.
            asyncio.get_running_loop().run_in_executor(
                None, llm_response, params['text'], nerfreals[sessionid]
            )
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def interrupt_talk(request):
    """Flush whatever the session's avatar is currently saying."""
    try:
        params = await request.json()
        sessionid = _session_id(params)
        nerfreals[sessionid].flush_talk()
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def humanaudio(request):
    """Feed an uploaded audio file (multipart field ``file``) to a session's renderer."""
    try:
        form = await request.post()
        sessionid = _session_id(form)
        fileobj = form["file"]
        filebytes = fileobj.file.read()
        nerfreals[sessionid].put_audio_file(filebytes)
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def set_audiotype(request):
    """Forward ``audiotype`` and ``reinit`` to the session's ``set_custom_state``."""
    try:
        params = await request.json()
        sessionid = _session_id(params)
        nerfreals[sessionid].set_custom_state(params['audiotype'], params['reinit'])
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def record(request):
    """Start ('start_record') or stop ('end_record') recording for a session."""
    try:
        params = await request.json()
        sessionid = _session_id(params)
        if params['type'] == 'start_record':
            nerfreals[sessionid].start_recording()
        elif params['type'] == 'end_record':
            nerfreals[sessionid].stop_recording()
        return _json_response({"code": 0, "msg": "ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def is_speaking(request):
    """Return the session renderer's ``is_speaking()`` result as ``data``.

    Errors (unknown session, bad JSON) are reported as ``{"code": -1, ...}``
    like the other handlers instead of escaping as an HTTP 500.
    """
    try:
        params = await request.json()
        sessionid = _session_id(params)
        return _json_response({"code": 0, "data": nerfreals[sessionid].is_speaking()})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def on_shutdown(app):
    """aiohttp shutdown hook: close every peer connection and the shared DB connection."""
    global conn
    # close peer connections
    coros = [pc.close() for pc in pcs]
    await asyncio.gather(*coros)
    pcs.clear()
    # Close the database connection opened by load_db_config.
    if conn:
        conn.close()
        conn = None
        print("[INFO] 已关闭数据库连接")


async def post(url, data):
    """POST ``data`` to ``url`` and return the response body as text.

    Returns None (after logging) on any aiohttp client error, including a
    non-2xx status -- an error page must not be mistaken for an SDP answer.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                response.raise_for_status()
                return await response.text()
    except aiohttp.ClientError as e:
        logger.error(f'Error: {e}')
    return None
async def run(push_url, sessionid):
    """Build a renderer for ``sessionid`` and push its stream to ``push_url``.

    Creates a local SDP offer, POSTs it to ``push_url`` (a WHIP endpoint by
    default) and applies the returned SDP as the answer.

    Raises:
        RuntimeError: if the push server returned no SDP answer.
    """
    nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal, sessionid)
    nerfreals[sessionid] = nerfreal

    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s" % pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreals[sessionid])
    pc.addTrack(player.audio)
    pc.addTrack(player.video)

    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url, pc.localDescription.sdp)
    if answer is None:
        # post() logs and returns None on client/HTTP errors; fail clearly here
        # instead of handing sdp=None to aiortc.
        await pc.close()
        pcs.discard(pc)
        raise RuntimeError(f"push server {push_url} returned no SDP answer")
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))


# Database config key -> argparse attribute it may override.
_DB_KEY_TO_ARG = {
    'listen_port': 'listenport',
    'avatar_id': 'avatar_id',
    'ref_file': 'REF_FILE',
    'ref_text': 'REF_TEXT',
    'tts_server': 'TTS_SERVER',
}


def _build_arg_parser():
    """Return the server's command-line parser (defaults may later be overridden by DB/JSON config)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='../config.json', help="配置文件路径")
    # audio FPS
    parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=8)
    parser.add_argument('-r', type=int, default=10)

    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")

    # musetalk opt
    parser.add_argument('--avatar_id', type=str, default='model_4_1', help="define which avatar in data/avatars")
    parser.add_argument('--batch_size', type=int, default=16, help="infer batch")

    parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")

    parser.add_argument('--tts', type=str, default='gpt-sovits', help="tts service type")  # xtts gpt-sovits cosyvoice
    parser.add_argument('--REF_FILE', type=str, default="input/gentle_girl.wav")
    parser.add_argument('--REF_TEXT', type=str, default="刚进直播间的宝子们,左上角先点个关注,点亮咱们家的粉丝灯牌!我是你们的主播陈婉婉,今天给大家准备了超级重磅的福利")
    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880')  # http://localhost:9000

    parser.add_argument('--model', type=str, default='wav2lip')  # musetalk wav2lip ultralight

    parser.add_argument('--transport', type=str, default='webrtc')  # webrtc rtcpush virtualcam
    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream')  # rtmp://localhost/live/livestream

    parser.add_argument('--max_session', type=int, default=1)  # multi session count
    parser.add_argument('--listenport', type=int, default=8010, help="web listen port")
    return parser


def _apply_db_overrides(opt, parser, db_config):
    """Copy mapped DB settings onto ``opt`` for options still at their parser default.

    A value is coerced to the default's type when possible (e.g. '8010' ->
    8010), otherwise stored as the raw DB value.  An option passed on the
    command line with exactly its default value is indistinguishable from an
    unset one and is therefore overridden too.
    """
    for db_key, db_val in db_config.items():
        arg_name = _DB_KEY_TO_ARG.get(db_key)
        if arg_name is None:
            continue
        default = parser.get_default(arg_name)
        if getattr(opt, arg_name) != default:
            continue  # user supplied a non-default value; keep it
        try:
            setattr(opt, arg_name, type(default)(db_val))
        except (TypeError, ValueError):
            setattr(opt, arg_name, db_val)


def _apply_json_config(opt, parser, config_path):
    """Fill options still at their parser default from the JSON file at ``config_path``.

    Keys unknown to the parser are copied as well (both ``getattr`` fallback
    and ``get_default`` are None).  A missing or unreadable file is logged and
    ignored.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            cfg = json.load(f)
        for key, val in cfg.items():
            if getattr(opt, key, None) == parser.get_default(key):
                setattr(opt, key, val)
    except FileNotFoundError:
        logger.warning(f"配置文件未找到:{config_path},将全部使用命令行/默认参数")
    except Exception as e:
        logger.warning(f"加载配置文件时出错:{e},将全部使用命令行/默认参数")


def _load_model_and_avatar(opt):
    """Load and warm up the lip-sync model and avatar selected by ``opt.model``.

    Returns:
        ``(model, avatar)``.

    Raises:
        ValueError: for a model this server cannot build sessions for, so
        startup fails instead of every later session failing.
    """
    if opt.model == 'wav2lip':
        from lipreal import load_model, load_avatar, warm_up
        logger.info(opt)
        loaded_model = load_model("./models/wav2lip384.pth")
        loaded_avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size, loaded_model, 384)
        return loaded_model, loaded_avatar
    if opt.model == 'ultralight':
        from lightreal import load_model, load_avatar, warm_up
        logger.info(opt)
        loaded_model = load_model(opt)
        loaded_avatar = load_avatar(opt.avatar_id)
        # lightreal's warm_up is called with the avatar, as in the original code.
        warm_up(opt.batch_size, loaded_avatar, 160)
        return loaded_model, loaded_avatar
    raise ValueError(f"unsupported model: {opt.model}")


def _build_app():
    """Create the aiohttp app: API routes, static ``web`` directory and permissive CORS."""
    application = web.Application(client_max_size=1024 ** 2 * 100)
    application.on_shutdown.append(on_shutdown)
    for route_path, handler in (
        ("/offer", offer),
        ("/human", human),
        ("/humanaudio", humanaudio),
        ("/set_audiotype", set_audiotype),
        ("/record", record),
        ("/interrupt_talk", interrupt_talk),
        ("/is_speaking", is_speaking),
    ):
        application.router.add_post(route_path, handler)
    application.router.add_static('/', path='web')

    # Default CORS settings, applied to every route.
    cors = aiohttp_cors.setup(application, defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        )
    })
    for route in list(application.router.routes()):
        cors.add(route)
    return application


def _page_name(transport):
    """Return the demo page that matches ``transport``."""
    if transport == 'rtmp':
        return 'echoapi.html'
    if transport == 'rtcpush':
        return 'rtcpushapi.html'
    return 'webrtcapi.html'


def _run_server(runner, opt, db_path):
    """Serve ``runner`` forever on a fresh event loop.

    Status 2 (started) is written only after the listening socket is bound.
    In 'rtcpush' mode, one push session per ``opt.max_session`` is started.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(runner.setup())
    site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
    loop.run_until_complete(site.start())
    # Service is up: status '2'.
    set_status(db_path, 2)
    if opt.transport == 'rtcpush':
        for k in range(opt.max_session):
            push_url = opt.push_url if k == 0 else opt.push_url + str(k)
            loop.run_until_complete(run(push_url, k))
    loop.run_forever()


##########################################
if __name__ == '__main__':
    # Computed outside ``try`` so the failure handler can always use it.
    db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'live_chat.db')
    try:
        # Startup begins: write 1 (starting).
        set_status(db_path, 1)
        load_db_config(db_path)

        mp.set_start_method('spawn')
        parser = _build_arg_parser()
        opt = parser.parse_args()

        # Precedence for options left at their default: DB value, then JSON file.
        _apply_db_overrides(opt, parser, config_dict)
        _apply_json_config(opt, parser, opt.config)

        opt.customopt = []
        if opt.customvideo_config != '':
            with open(opt.customvideo_config, 'r') as file:
                opt.customopt = json.load(file)

        model, avatar = _load_model_and_avatar(opt)

        if opt.transport == 'virtualcam':
            thread_quit = Event()
            nerfreals[0] = build_nerfreal(0)
            rendthrd = Thread(target=nerfreals[0].render, args=(thread_quit,))
            rendthrd.start()

        appasync = _build_app()

        listen_port = opt.listenport
        if not check_port(listen_port):
            logger.error(f"端口 {listen_port} 已被占用,无法启动服务。")
            set_status(db_path, 3)  # startup failed
            raise SystemExit(1)  # not caught by ``except Exception`` below

        pagename = _page_name(opt.transport)
        logger.info('start http server; http://:' + str(opt.listenport) + '/' + pagename)
        logger.info(
            '如果使用webrtc,推荐访问webrtc集成前端: http://:' + str(opt.listenport) + '/dashboard.html')

        _run_server(web.AppRunner(appasync), opt, db_path)
    except Exception:
        logger.exception("启动失败")
        # Startup failed: status '3'.  Guarded so a DB error here does not hide
        # the original failure.
        try:
            set_status(db_path, 3)
        except Exception:
            logger.exception("Failed to record startup failure status")
        try:
            input("发生异常,按回车键退出…")  # wait for Enter before exiting
        except EOFError:
            pass  # no interactive stdin (e.g. launched as a service)