You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

585 lines
22 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# server.py
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
import base64
import json
#import gevent
#from gevent import pywsgi
#from geventwebsocket.handler import WebSocketHandler
import re
import numpy as np
from threading import Thread,Event
#import multiprocessing
import torch.multiprocessing as mp
from aiohttp import web
import aiohttp
import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceServer, RTCConfiguration
from aiortc.rtcrtpsender import RTCRtpSender
from webrtc import HumanPlayer
from basereal import BaseReal
from llm import llm_response
import argparse
import random
import shutil
import asyncio
import torch
from typing import Dict
from logger import logger
import gc
import json
import multiprocessing as mp
import sqlite3
import os
# NOTE(review): this Flask app is never run in the visible code — the HTTP
# server started in __main__ is aiohttp (`web.Application`). Kept as-is in
# case something else imports it.
app = Flask(__name__)
#sockets = Sockets(app)
# Live avatar sessions: sessionid -> BaseReal renderer. offer() stores None
# under a freshly reserved id while the renderer is still being built.
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
# Parsed command-line options; assigned in the __main__ block.
opt = None
# Lip-sync model and avatar data shared by every session; loaded in __main__
# and passed to each renderer by build_nerfreal().
model = None
avatar = None
#####webrtc###############################
# Every RTCPeerConnection created by offer()/run(), so on_shutdown() can
# close them all.
pcs = set()
def randN(N) -> int:
    """Return a uniformly random integer that has exactly ``N`` decimal digits.

    For example ``randN(6)`` returns a value in ``[100000, 999999]``.

    Args:
        N: number of digits; must be at least 1.

    Returns:
        A random ``int`` with ``N`` digits (no leading zero).

    Raises:
        ValueError: if ``N < 1``. There is no 0-digit number, and the old
            ``pow(10, N - 1)`` bound silently became a float in that case.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N!r}")
    # Local names avoid shadowing the builtins min()/max().
    lower = 10 ** (N - 1)
    upper = 10 ** N - 1
    return random.randint(lower, upper)
# Global database connection and configuration dictionary.
# conn: shared SQLite connection opened by load_db_config() and closed by
#       on_shutdown().
# config_dict: {key: value} snapshot of the livetalking_config table, filled
#       by load_db_config().
conn = None
config_dict = {}
def load_db_config(db_path):
    """Open the shared SQLite connection and cache the ``livetalking_config`` table.

    Rebinds two module globals:
      * ``conn`` -- a connection that may be reused from other threads
        (``check_same_thread=False``), with ``sqlite3.Row`` rows;
      * ``config_dict`` -- a ``{key: value}`` snapshot of every config row.

    Args:
        db_path: filesystem path of the SQLite database file.
    """
    global conn, config_dict
    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    # Tuning: WAL journaling so readers are not blocked by a writer, relaxed
    # fsync behaviour, and a larger page cache.
    for pragma in (
        "PRAGMA journal_mode = WAL;",
        "PRAGMA synchronous = NORMAL;",
        "PRAGMA cache_size = 10000;",
    ):
        conn.execute(pragma)
    # Pull every key/value pair with a single query.
    rows = conn.execute("SELECT key, value FROM livetalking_config;").fetchall()
    config_dict = dict((row["key"], row["value"]) for row in rows)
def set_enable_status(db_path):
    """Set the ``livetlking_enable_status`` row of ``live_config`` to ``'1'``.

    Uses its own short-lived connection (always closed) instead of the shared
    global ``conn``.

    NOTE(review): the key is spelled ``livetlking`` and the table is
    ``live_config``, whereas load_db_config() reads ``livetalking_config``.
    Both names are kept exactly as they were because they must match the
    existing database schema -- confirm against the DB before changing them.

    Args:
        db_path: filesystem path of the SQLite database file.

    Returns:
        True if the row was updated; False if no matching row exists. The
        previous version printed its success message in both cases.
    """
    db = sqlite3.connect(db_path, check_same_thread=False)
    try:
        # `with db:` commits on success and rolls back if the UPDATE raises.
        with db:
            cur = db.execute(
                "UPDATE live_config SET value = '1' WHERE key = 'livetlking_enable_status';"
            )
        if cur.rowcount == 0:
            print("[WARN] live_config 中未找到 livetlking_enable_status,未更新任何记录")
            return False
        print("livetlking_enable_status 对应的 value 字段改为 '1'")
        return True
    finally:
        db.close()
def build_nerfreal(sessionid: int) -> "BaseReal":
    """Create the per-session avatar renderer for the configured ``opt.model``.

    The module-level ``model`` and ``avatar`` (loaded once in ``__main__``) are
    shared between sessions; only the renderer object is created per session.

    Args:
        sessionid: id written to ``opt.sessionid`` before the renderer is built.

    Returns:
        A ``LipReal`` for ``opt.model == 'wav2lip'`` or a ``LightReal`` for
        ``opt.model == 'ultralight'``.

    Raises:
        ValueError: for any other ``opt.model``. Previously every model other
            than wav2lip -- including ``ultralight``, which ``__main__`` does
            load -- fell through to an unbound ``nerfreal`` (UnboundLocalError).
    """
    # NOTE(review): opt is a single shared object, and offer() calls this from a
    # thread pool, so concurrent builds overwrite each other's sessionid here.
    opt.sessionid = sessionid
    if opt.model == 'wav2lip':
        from lipreal import LipReal
        return LipReal(opt, model, avatar)
    if opt.model == 'ultralight':
        # assumes LightReal takes (opt, model, avatar) like LipReal -- TODO confirm
        # against lightreal.py
        from lightreal import LightReal
        return LightReal(opt, model, avatar)
    raise ValueError(f"unsupported model: {opt.model!r}")
def _json_response(payload: dict) -> web.Response:
    """Wrap *payload* in an ``application/json`` aiohttp response."""
    return web.Response(content_type="application/json", text=json.dumps(payload))


def _new_sessionid() -> int:
    """Pick a random 6-digit session id that is not already in ``nerfreals``."""
    while True:
        sessionid = randN(6)
        if sessionid not in nerfreals:
            return sessionid


async def _drop_session(pc, sessionid: int) -> None:
    """Close *pc*, then forget both the connection and the session renderer."""
    try:
        await pc.close()
    finally:
        pcs.discard(pc)
        nerfreals.pop(sessionid, None)


async def offer(request):
    """Handle a WebRTC SDP offer and answer it with a new avatar session.

    Expects a JSON object with ``sdp`` and ``type``. On success responds
    ``{"sdp", "type", "sessionid"}``. On failure responds ``{"code", "msg"}``
    with one of these codes:
      * -2: body is not valid JSON;
      * -3: ``sdp``/``type`` missing;
      * -4: the offer could not be built;
      * -5: building the renderer failed;
      * -6: creating the peer connection failed;
      * -7: adding media tracks failed;
      * -8: SDP negotiation failed.
    Every failure after the session id is reserved releases that session (and
    its peer connection, once one exists).
    """
    logger.info("开始处理offer请求")
    # Parse the request body.
    try:
        params = await request.json()
        logger.debug("接收到的请求参数: %s", params)
    except Exception as e:
        logger.error("解析请求参数失败: %s", str(e))
        return _json_response({"code": -2, "msg": "解析参数失败"})

    # Validate required fields. A non-object JSON body is treated as missing them.
    if not isinstance(params, dict) or "sdp" not in params or "type" not in params:
        logger.warning("请求参数缺少sdp或type字段: %s", params)
        return _json_response({"code": -3, "msg": "缺少必要参数"})

    try:
        remote_offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
        logger.debug("成功创建RTCSessionDescription类型: %s", params["type"])
    except Exception as e:
        logger.error("创建RTCSessionDescription失败: %s", str(e))
        return _json_response({"code": -4, "msg": "创建offer失败"})

    # Reserve an unused id with a None placeholder. There is no await between
    # picking and storing it, so a concurrent offer cannot take the same id.
    # Previously a random collision could overwrite a live session.
    sessionid = _new_sessionid()
    nerfreals[sessionid] = None
    logger.info('生成新会话sessionid=%d, 当前会话数量=%d', sessionid, len(nerfreals))

    # Building the renderer is blocking work, so run it in the default executor.
    try:
        logger.debug("开始构建nerfrealsessionid=%d", sessionid)
        nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal, sessionid)
        nerfreals[sessionid] = nerfreal
        logger.info("成功构建nerfrealsessionid=%d", sessionid)
    except Exception as e:
        logger.error("构建nerfreal失败sessionid=%d: %s", sessionid, str(e))
        nerfreals.pop(sessionid, None)
        return _json_response({"code": -5, "msg": "构建会话失败"})

    # Create the peer connection with a public STUN server.
    try:
        ice_server = RTCIceServer(urls='stun:stun.miwifi.com:3478')
        logger.debug("使用ICE服务器: %s", ice_server.urls)
        pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=[ice_server]))
        pcs.add(pc)
        logger.info("创建RTCPeerConnection成功sessionid=%d", sessionid)
    except Exception as e:
        logger.error("创建RTCPeerConnection失败sessionid=%d: %s", sessionid, str(e))
        nerfreals.pop(sessionid, None)
        return _json_response({"code": -6, "msg": "创建连接失败"})

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        state = pc.connectionState
        logger.info("sessionid=%d, 连接状态变更为: %s", sessionid, state)
        # `elif` plus pop(): the old code fell through to the "closed" check after
        # close() and ran `del nerfreals[sessionid]` twice (KeyError).
        if state == "failed":
            logger.warning("sessionid=%d, 连接失败,关闭连接", sessionid)
            await _drop_session(pc, sessionid)
        elif state == "closed":
            logger.info("sessionid=%d, 连接已关闭", sessionid)
            pcs.discard(pc)
            nerfreals.pop(sessionid, None)
            gc.collect()
            logger.debug("sessionid=%d, 已清理资源", sessionid)

    # Attach the avatar's audio then video track. Order matters: the video
    # transceiver is looked up at index 1 below.
    try:
        player = HumanPlayer(nerfreal)
        pc.addTrack(player.audio)
        pc.addTrack(player.video)
        logger.debug("sessionid=%d, 已添加音频和视频轨道", sessionid)
    except Exception as e:
        logger.error("sessionid=%d, 添加媒体轨道失败: %s", sessionid, str(e))
        await _drop_session(pc, sessionid)
        return _json_response({"code": -7, "msg": "添加媒体轨道失败"})

    # Prefer H264, then VP8, then rtx for video. A failure here is only logged.
    try:
        capabilities = RTCRtpSender.getCapabilities("video")
        preferences = [
            codec
            for wanted in ("H264", "VP8", "rtx")
            for codec in capabilities.codecs
            if codec.name == wanted
        ]
        logger.debug("sessionid=%d, 编解码器偏好: %s", sessionid, [p.name for p in preferences])
        transceiver = pc.getTransceivers()[1]
        transceiver.setCodecPreferences(preferences)
        logger.debug("sessionid=%d, 已设置编解码器偏好", sessionid)
    except Exception as e:
        logger.error("sessionid=%d, 配置编解码器失败: %s", sessionid, str(e))

    # Apply the remote offer and produce our answer.
    try:
        await pc.setRemoteDescription(remote_offer)
        logger.debug("sessionid=%d, 已设置远程描述", sessionid)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        logger.debug("sessionid=%d, 已创建并设置本地answer", sessionid)
    except Exception as e:
        logger.error("sessionid=%d, 处理offer创建answer失败: %s", sessionid, str(e))
        await _drop_session(pc, sessionid)
        return _json_response({"code": -8, "msg": "处理offer失败"})

    response_data = {
        "sdp": pc.localDescription.sdp,
        "type": pc.localDescription.type,
        "sessionid": sessionid,
    }
    logger.info("sessionid=%d, offer请求处理完成返回响应", sessionid)
    return _json_response(response_data)
def _log_llm_failure(future) -> None:
    """Done-callback for background ``llm_response`` calls: log any exception."""
    if not future.cancelled() and future.exception() is not None:
        logger.error("llm_response failed", exc_info=future.exception())


async def human(request):
    """Queue text for a session to speak.

    JSON body:
      * ``sessionid``: target session (default 0; numeric strings accepted);
      * ``interrupt``: if truthy, ``flush_talk()`` is called first;
      * ``type``: ``'echo'`` passes ``text`` to ``put_msg_txt``; ``'chat'`` runs
        ``llm_response(text, session)`` in a worker thread without waiting for
        it. Other values are accepted and ignored;
      * ``text``: the text to speak or send.
    Responds ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": <error>}``.
    """
    try:
        params = await request.json()
        # int() so a string id works too, matching humanaudio's conversion.
        sessionid = int(params.get('sessionid', 0))
        if params.get('interrupt'):
            nerfreals[sessionid].flush_talk()
        if params['type'] == 'echo':
            nerfreals[sessionid].speaking = True
            nerfreals[sessionid].put_msg_txt(params['text'])
        elif params['type'] == 'chat':
            # Fire-and-forget. The callback logs failures that would otherwise be
            # lost on the unobserved executor future.
            future = asyncio.get_running_loop().run_in_executor(
                None, llm_response, params['text'], nerfreals[sessionid]
            )
            future.add_done_callback(_log_llm_failure)
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": 0, "msg": "ok"}),
        )
    except Exception as e:
        logger.exception('exception:')
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": -1, "msg": str(e)}),
        )
async def interrupt_talk(request):
    """Stop whatever the given session is currently saying.

    JSON body: ``{"sessionid": <id, default 0>}``. Calls ``flush_talk()`` on
    that session. Responds ``{"code": 0, "msg": "ok"}`` or
    ``{"code": -1, "msg": <error>}``.
    """
    try:
        body = await request.json()
        nerfreals[body.get('sessionid', 0)].flush_talk()
    except Exception as err:
        logger.exception('exception:')
        payload = {"code": -1, "msg": str(err)}
    else:
        payload = {"code": 0, "msg": "ok"}
    return web.Response(content_type="application/json", text=json.dumps(payload))
async def humanaudio(request):
    """Hand an uploaded audio file's raw bytes to a session's ``put_audio_file``.

    Multipart form fields:
      * ``sessionid``: integer as text, default 0;
      * ``file``: the upload.
    Responds ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": <error>}``.
    """
    try:
        form = await request.post()
        sessionid = int(form.get('sessionid', 0))
        upload = form["file"]
        # A plain text field has no .file. Report that clearly instead of an
        # opaque AttributeError. (The unused `filename` local was dropped.)
        if not hasattr(upload, "file"):
            raise ValueError("form field 'file' must be a file upload")
        nerfreals[sessionid].put_audio_file(upload.file.read())
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": 0, "msg": "ok"}),
        )
    except Exception as e:
        logger.exception('exception:')
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": -1, "msg": str(e)}),
        )
async def set_audiotype(request):
    """Forward ``audiotype`` and ``reinit`` to a session's ``set_custom_state``.

    JSON body: ``sessionid`` (default 0), ``audiotype``, ``reinit``.
    Responds ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": <error>}``.
    """
    try:
        data = await request.json()
        session = nerfreals[data.get('sessionid', 0)]
        session.set_custom_state(data['audiotype'], data['reinit'])
    except Exception as err:
        logger.exception('exception:')
        result = {"code": -1, "msg": str(err)}
    else:
        result = {"code": 0, "msg": "ok"}
    return web.Response(content_type="application/json", text=json.dumps(result))
async def record(request):
    """Start or stop recording for a session.

    JSON body: ``sessionid`` (default 0) and ``type``. ``'start_record'`` calls
    ``start_recording()`` and ``'end_record'`` calls ``stop_recording()``; any
    other type is accepted and ignored.
    Responds ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": <error>}``.
    """
    try:
        body = await request.json()
        sid = body.get('sessionid', 0)
        action = body['type']
        if action == 'start_record':
            nerfreals[sid].start_recording()
        elif action == 'end_record':
            nerfreals[sid].stop_recording()
    except Exception as err:
        logger.exception('exception:')
        reply = {"code": -1, "msg": str(err)}
    else:
        reply = {"code": 0, "msg": "ok"}
    return web.Response(content_type="application/json", text=json.dumps(reply))
async def is_speaking(request):
    """Report a session's ``is_speaking()`` result.

    JSON body: ``{"sessionid": <id, default 0>}``. Responds
    ``{"code": 0, "data": <is_speaking() result>}``. On a bad body, an unknown
    session, or an unserializable result it responds
    ``{"code": -1, "msg": <error>}``, like the sibling handlers. Previously the
    exception escaped the handler and aiohttp answered with HTTP 500.
    """
    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        text = json.dumps({"code": 0, "data": nerfreals[sessionid].is_speaking()})
    except Exception as e:
        logger.exception('exception:')
        text = json.dumps({"code": -1, "msg": str(e)})
    return web.Response(content_type="application/json", text=text)
async def on_shutdown(app):
    """aiohttp shutdown hook: close all peer connections, then the shared DB connection.

    Peer connections are closed concurrently with ``return_exceptions=True``, so
    one failing ``close()`` is reported instead of aborting the hook before the
    database connection is closed. ``conn`` is reset to None afterwards.

    Args:
        app: the aiohttp application being shut down (unused).
    """
    global conn
    # Close peer connections.
    results = await asyncio.gather(*(pc.close() for pc in pcs), return_exceptions=True)
    for result in results:
        if isinstance(result, BaseException):
            print(f"[WARN] 关闭RTCPeerConnection失败: {result!r}")
    pcs.clear()
    # Close the database connection.
    if conn:
        conn.close()
        conn = None
        print("[INFO] 已关闭数据库连接")
async def post(url, data):
    """POST *data* to *url* and return the response body as text.

    Args:
        url: target URL.
        data: request body passed to aiohttp as ``data=``.

    Returns:
        The response text, or None (after logging) on a client/network error
        or an HTTP error status (>= 400). Error responses used to be returned
        as if they were a valid body. ``raise_for_status()`` now sends them to
        the error path.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                response.raise_for_status()
                return await response.text()
    except aiohttp.ClientError as e:
        logger.error('Error: %s', e)
        return None
async def run(push_url, sessionid):
    """Start a push session that streams avatar *sessionid* to *push_url*.

    Builds the renderer and sends a local SDP offer to *push_url* via
    ``post()``. The response body is used as the SDP answer. (The default
    ``--push_url`` is a WHIP endpoint.)

    Args:
        push_url: URL the offer SDP is POSTed to.
        sessionid: key under which the renderer is stored in ``nerfreals``.

    Raises:
        RuntimeError: if no answer came back (``post`` returns None on error).
            The peer connection is closed first. Previously None was passed on
            as the SDP and failed later with an obscure error.
    """
    nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal, sessionid)
    nerfreals[sessionid] = nerfreal
    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s", pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreal)
    pc.addTrack(player.audio)
    pc.addTrack(player.video)

    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url, pc.localDescription.sdp)
    if not answer:
        await pc.close()
        pcs.discard(pc)
        raise RuntimeError(f"no SDP answer received from push server {push_url}")
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
##########################################
# Script entry point. Settings are resolved with this precedence:
#   command line > database (livetalking_config) > JSON config file > argparse defaults.
# The script then loads the lip-sync model and avatar and serves the aiohttp
# API plus the static pages under web/ until the process is stopped.
if __name__ == '__main__':
    try:
        # The SQLite database sits two directories above the directory that
        # contains this file.
        db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'live_chat.db')
        load_db_config(db_path)
        # NOTE(review): `mp` here is the stdlib multiprocessing module. The later
        # `import multiprocessing as mp` at the top of the file rebinds the earlier
        # `import torch.multiprocessing as mp`.
        mp.set_start_method('spawn')
        parser = argparse.ArgumentParser()
        parser.add_argument('--config', type=str, default='config/config.json', help="配置文件路径")
        # audio FPS
        parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
        # sliding window left-middle-right length (unit: 20ms)
        parser.add_argument('-l', type=int, default=10)
        parser.add_argument('-m', type=int, default=8)
        parser.add_argument('-r', type=int, default=10)
        parser.add_argument('--W', type=int, default=450, help="GUI width")
        parser.add_argument('--H', type=int, default=450, help="GUI height")
        #musetalk opt
        parser.add_argument('--avatar_id', type=str, default='model_4_1', help="define which avatar in data/avatars")
        #parser.add_argument('--bbox_shift', type=int, default=5)
        parser.add_argument('--batch_size', type=int, default=16, help="infer batch")
        parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")
        parser.add_argument('--tts', type=str, default='gpt-sovits', help="tts service type") #xtts gpt-sovits cosyvoice
        parser.add_argument('--REF_FILE', type=str, default="input/gentle_girl.wav")
        parser.add_argument('--REF_TEXT', type=str, default="刚进直播间的宝子们,左上角先点个关注,点亮咱们家的粉丝灯牌!我是你们的主播陈婉婉,今天给大家准备了超级重磅的福利")
        parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
        # parser.add_argument('--CHARACTER', type=str, default='test')
        # parser.add_argument('--EMOTION', type=str, default='default')
        parser.add_argument('--model', type=str, default='wav2lip') #musetalk wav2lip ultralight
        parser.add_argument('--transport', type=str, default='webrtc') #webrtc rtcpush virtualcam
        parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
        parser.add_argument('--max_session', type=int, default=1) #multi session count
        parser.add_argument('--listenport', type=int, default=8010, help="web listen port")
        opt = parser.parse_args()
        #app.config.from_object(opt)
        #print(app.config)
        # Map database config keys to the matching opt attribute names.
        name_map = {
            'listen_port': 'listenport',
            'avatar_id': 'avatar_id',
            'ref_file': 'REF_FILE',
            'ref_text': 'REF_TEXT',
            'tts_server': 'TTS_SERVER',
        }
        for db_key, db_val in config_dict.items():
            if db_key not in name_map:
                continue
            arg_name = name_map[db_key]
            # Only override when the option still holds its argparse default,
            # i.e. it was not passed explicitly.
            # NOTE(review): an explicit CLI value equal to the default is
            # indistinguishable from "not passed" and will be overridden.
            current = getattr(opt, arg_name)
            default = parser.get_default(arg_name)
            if current == default:
                try:
                    # DB values are converted to the default's type (e.g. int for
                    # listenport); if that fails the raw DB value is used.
                    setattr(opt, arg_name, type(default)(db_val))
                except Exception:
                    setattr(opt, arg_name, db_val)
        try:
            with open(opt.config, 'r', encoding='utf-8') as f:
                cfg = json.load(f)
            for key, val in cfg.items():
                # Use the config-file value only if opt.<key> still equals the
                # parser default. Values already taken from the DB above no longer
                # match, so the DB wins over the file. Keys the parser does not
                # know compare None == None and are added to opt as new attributes.
                if getattr(opt, key, None) == parser.get_default(key):
                    setattr(opt, key, val)
        except FileNotFoundError:
            logger.warning(f"配置文件未找到:{opt.config},将全部使用命令行/默认参数")
        except Exception as e:
            logger.warning(f"加载配置文件时出错:{e},将全部使用命令行/默认参数")
        # Optional custom-action definitions, loaded from a JSON file.
        opt.customopt = []
        if opt.customvideo_config!='':
            with open(opt.customvideo_config,'r') as file:
                opt.customopt = json.load(file)
        # if opt.model == 'ernerf':
        #     from nerfreal import NeRFReal,load_model,load_avatar
        #     model = load_model(opt)
        #     avatar = load_avatar(opt)
        # Load the model and avatar once into the module globals `model` and
        # `avatar`, which build_nerfreal() passes to every session.
        if opt.model == 'wav2lip':
            from lipreal import LipReal,load_model,load_avatar,warm_up
            logger.info(opt)
            model = load_model("./models/wav2lip384.pth")
            avatar = load_avatar(opt.avatar_id)
            warm_up(opt.batch_size,model,384)
        elif opt.model == 'ultralight':
            from lightreal import LightReal,load_model,load_avatar,warm_up
            logger.info(opt)
            model = load_model(opt)
            avatar = load_avatar(opt.avatar_id)
            warm_up(opt.batch_size,avatar,160)
        # virtualcam: build session 0 up front and run its render loop in a
        # background thread. thread_quit is passed to render() but is not set
        # anywhere in the visible code.
        if opt.transport=='virtualcam':
            thread_quit = Event()
            nerfreals[0] = build_nerfreal(0)
            rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
            rendthrd.start()
        #############################################################################
        # aiohttp app: request bodies up to 100 MiB (e.g. audio uploads), JSON API
        # routes, and the static pages served from web/.
        appasync = web.Application(client_max_size=1024**2*100)
        appasync.on_shutdown.append(on_shutdown)
        appasync.router.add_post("/offer", offer)
        appasync.router.add_post("/human", human)
        appasync.router.add_post("/humanaudio", humanaudio)
        appasync.router.add_post("/set_audiotype", set_audiotype)
        appasync.router.add_post("/record", record)
        appasync.router.add_post("/interrupt_talk", interrupt_talk)
        appasync.router.add_post("/is_speaking", is_speaking)
        appasync.router.add_static('/',path='web')
        # Configure default CORS settings.
        cors = aiohttp_cors.setup(appasync, defaults={
                "*": aiohttp_cors.ResourceOptions(
                    allow_credentials=True,
                    expose_headers="*",
                    allow_headers="*",
                )
            })
        # Configure CORS on all routes.
        for route in list(appasync.router.routes()):
            cors.add(route)
        # Demo page to advertise in the startup log, chosen by transport.
        pagename='webrtcapi.html'
        if opt.transport=='rtmp':
            pagename='echoapi.html'
        elif opt.transport=='rtcpush':
            pagename='rtcpushapi.html'
        logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
        logger.info('如果使用webrtc推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
        # On service start, set the enable_status value in the database to '1'.
        # Note: this runs before run_server() below starts listening.
        set_enable_status(db_path)
        def run_server(runner):
            # Run the aiohttp app on a fresh event loop, listening on all interfaces.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(runner.setup())
            site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
            loop.run_until_complete(site.start())
            # rtcpush: start max_session push streams. Session k > 0 pushes to
            # push_url with k appended.
            if opt.transport=='rtcpush':
                for k in range(opt.max_session):
                    push_url = opt.push_url
                    if k!=0:
                        push_url = opt.push_url+str(k)
                    loop.run_until_complete(run(push_url,k))
            loop.run_forever()
        run_server(web.AppRunner(appasync))
    except Exception:
        import traceback
        traceback.print_exc() # print the full stack trace
        input("发生异常,按回车键退出…") # wait for the user to press Enter before exiting