You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

501 lines
18 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# server.py
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
import base64
import json
#import gevent
#from gevent import pywsgi
#from geventwebsocket.handler import WebSocketHandler
import re
import numpy as np
from threading import Thread,Event
#import multiprocessing
import torch.multiprocessing as mp
from aiohttp import web
import aiohttp
import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription,RTCIceServer,RTCConfiguration
from aiortc.rtcrtpsender import RTCRtpSender
from webrtc import HumanPlayer
from basereal import BaseReal
from llm import llm_response
import argparse
import random
import shutil
import asyncio
import torch
from typing import Dict
from logger import logger
import gc
import json
import multiprocessing as mp
import sqlite3
import os
app = Flask(__name__)
# sockets = Sockets(app)

# Live avatar renderers, keyed by session id (sessionid -> BaseReal).
nerfreals: Dict[int, BaseReal] = dict()

# Assigned by the __main__ section: parsed options, loaded model, loaded avatar.
opt = model = avatar = None

##### webrtc ###############################
# Every RTCPeerConnection created so far; on_shutdown closes them all.
pcs = set()
def randN(N) -> int:
    """Return a random integer with exactly *N* decimal digits.

    The value is drawn uniformly from ``[10**(N-1), 10**N - 1]``.

    Raises:
        ValueError: if *N* < 1 (no integer has zero or negative digits).
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    # Named low/high rather than min/max so the builtins are not shadowed.
    low = 10 ** (N - 1)
    high = 10 ** N - 1
    return random.randint(low, high)
# Module-wide SQLite connection and cached key/value configuration;
# both are (re)assigned by load_db_config().
config_dict = {}
conn = None
def load_db_config(db_path):
    """Open the shared database connection and cache every configuration row.

    Rebinds two module globals:
      * ``conn``: a connection opened with ``check_same_thread=False`` so it
        can be reused from other threads; rows come back as ``sqlite3.Row``.
      * ``config_dict``: every ``key -> value`` pair of the
        ``livetalking_config`` table.

    Args:
        db_path: Path to the SQLite database file.
    """
    global conn, config_dict
    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    # Performance tuning: WAL journal (readers not blocked), relaxed sync,
    # larger page cache.
    tuning = (
        "PRAGMA journal_mode = WAL;",
        "PRAGMA synchronous = NORMAL;",
        "PRAGMA cache_size = 10000;",
    )
    for statement in tuning:
        conn.execute(statement)
    # Load all key/value pairs in one pass.
    loaded = {}
    for row in conn.execute("SELECT key, value FROM livetalking_config;"):
        loaded[row["key"]] = row["value"]
    config_dict = loaded
def set_enable_status(db_path):
    """Mark the service as enabled in the database.

    Sets ``value = '1'`` on the ``enable_status`` row of the ``live_config``
    table, using a short-lived connection that is always closed.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        The number of rows updated (0 when no ``enable_status`` row exists).
        Callers that ignore the return value are unaffected.
    """
    # Named `db` so it does not shadow the module-level shared `conn`.
    db = sqlite3.connect(db_path, check_same_thread=False)
    try:
        cursor = db.execute(
            "UPDATE live_config SET value = '1' WHERE key = 'enable_status';"
        )
        db.commit()
        updated = cursor.rowcount
    finally:
        db.close()
    if updated > 0:
        print("enable_status 对应的 value 字段改为 '1'")
    else:
        # Report honestly instead of claiming success when nothing matched.
        print("enable_status row not found in live_config; nothing updated")
    return updated
def build_nerfreal(sessionid:int)->BaseReal:
    """Create the avatar renderer for one session.

    Records *sessionid* on the shared ``opt`` namespace, then builds the
    renderer class selected by ``opt.model`` from the module-level ``model``
    and ``avatar`` loaded at startup.

    Raises:
        ValueError: if ``opt.model`` has no renderer here. Previously an
            unsupported model surfaced as an UnboundLocalError.
    """
    opt.sessionid = sessionid
    if opt.model == 'wav2lip':
        from lipreal import LipReal
        return LipReal(opt, model, avatar)
    if opt.model == 'ultralight':
        # __main__ loads model/avatar for 'ultralight' too. This assumes
        # LightReal takes the same (opt, model, avatar) arguments as LipReal.
        # TODO confirm against lightreal.py.
        from lightreal import LightReal
        return LightReal(opt, model, avatar)
    raise ValueError(f"unsupported model: {opt.model!r}")
#@app.route('/offer', methods=['POST'])
async def offer(request):
    """Answer a browser's WebRTC offer and start a new avatar session.

    Expects a JSON body with the browser's ``sdp`` and ``type``. The handler:
      * builds a renderer for a fresh, unused 6-digit session id, off the
        event loop via the default executor;
      * attaches the renderer's audio/video tracks to a new
        ``RTCPeerConnection``;
      * returns the SDP answer as JSON together with the ``sessionid`` the
        client must pass to the other endpoints.
    """
    params = await request.json()
    remote_offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    # if len(nerfreals) >= opt.max_session:
    #     logger.info('reach max session')
    #     return web.Response(
    #         content_type="application/json",
    #         text=json.dumps({"code": -1, "msg": "reach max session"}),
    #     )

    # randN() can repeat a value; reusing a live id would overwrite that session.
    sessionid = randN(6)
    while sessionid in nerfreals:
        sessionid = randN(6)
    # Reserve the id with a None placeholder while the renderer is built.
    nerfreals[sessionid] = None
    logger.info('sessionid=%d, session num=%d', sessionid, len(nerfreals))
    try:
        nerfreal = await asyncio.get_running_loop().run_in_executor(
            None, build_nerfreal, sessionid
        )
    except Exception:
        # Do not leave a dangling placeholder behind a failed build.
        nerfreals.pop(sessionid, None)
        raise
    nerfreals[sessionid] = nerfreal

    # ice_server = RTCIceServer(urls='stun:stun.l.google.com:19302')
    ice_server = RTCIceServer(urls='stun:stun.miwifi.com:3478')
    pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=[ice_server]))
    pcs.add(pc)

    def release_session():
        # Forget this connection and its session; safe to call repeatedly.
        pcs.discard(pc)
        if nerfreals.pop(sessionid, None) is not None:
            gc.collect()

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        state = pc.connectionState
        logger.info("Connection state is %s", state)
        if state == "failed":
            await pc.close()
        if state in ("failed", "closed"):
            # aiortc reports "closed" after close(), so cleanup can run more
            # than once for the same connection. release_session() uses pop()
            # instead of `del`, which raised KeyError on the second run.
            release_session()

    try:
        player = HumanPlayer(nerfreal)
        pc.addTrack(player.audio)
        pc.addTrack(player.video)

        # Codec preference for video: H264, then VP8, plus rtx.
        # Transceiver index 1 is video because the audio track was added first.
        capabilities = RTCRtpSender.getCapabilities("video")
        preferences = [c for c in capabilities.codecs if c.name == "H264"]
        preferences += [c for c in capabilities.codecs if c.name == "VP8"]
        preferences += [c for c in capabilities.codecs if c.name == "rtx"]
        pc.getTransceivers()[1].setCodecPreferences(preferences)

        await pc.setRemoteDescription(remote_offer)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
    except Exception:
        # Negotiation failed: tear down instead of leaking the renderer and pc.
        await pc.close()
        release_session()
        raise

    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid": sessionid}
        ),
    )
async def human(request):
    """Send text to one session's avatar.

    JSON body fields:
      * ``sessionid`` (optional, default 0): the target session.
      * ``interrupt`` (optional): when truthy, ``flush_talk()`` is called first.
      * ``type``: ``'echo'`` speaks ``text`` verbatim; ``'chat'`` runs
        ``llm_response(text, session)`` on the default executor without
        waiting for it.
      * ``text``: the message.

    Replies ``{"code": 0, "msg": "ok"}`` on success, or
    ``{"code": -1, "msg": <error>}`` if anything raises.
    """
    try:
        params = await request.json()
        # humanaudio casts its sessionid with int(); do the same so a string
        # id ("123456") still finds the integer-keyed session.
        sessionid = int(params.get('sessionid', 0))
        nerfreal = nerfreals[sessionid]
        if params.get('interrupt'):
            nerfreal.flush_talk()
        if params['type'] == 'echo':
            nerfreal.speaking = True
            nerfreal.put_msg_txt(params['text'])
        elif params['type'] == 'chat':
            future = asyncio.get_running_loop().run_in_executor(
                None, llm_response, params['text'], nerfreal
            )

            def log_failure(fut):
                # Nobody awaits this future, so surface its errors here.
                if not fut.cancelled() and fut.exception() is not None:
                    logger.error('llm_response failed', exc_info=fut.exception())

            future.add_done_callback(log_failure)
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": 0, "msg": "ok"}
            ),
        )
    except Exception as e:
        logger.exception('exception:')
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": -1, "msg": str(e)}
            ),
        )
async def interrupt_talk(request):
    """Flush whatever the given session's avatar is currently saying.

    JSON body: ``{"sessionid": ...}`` (defaults to 0).
    Replies ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": <error>}``.
    """
    reply = {"code": 0, "msg": "ok"}
    try:
        params = await request.json()
        target = nerfreals[params.get('sessionid', 0)]
        target.flush_talk()
    except Exception as e:
        logger.exception('exception:')
        reply = {"code": -1, "msg": str(e)}
    return web.Response(content_type="application/json", text=json.dumps(reply))
async def humanaudio(request):
    """Feed an uploaded audio file to one session's avatar.

    Expects multipart form data with ``sessionid`` (default 0) and ``file``
    (the uploaded audio). The file's bytes are passed to ``put_audio_file``.
    Replies ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": <error>}``.
    """
    try:
        form = await request.post()
        sessionid = int(form.get('sessionid', 0))
        upload = form.get("file")
        # A missing field or a plain text field would otherwise fail with an
        # opaque KeyError / AttributeError message.
        if upload is None or not hasattr(upload, "file"):
            raise ValueError("form field 'file' must be an uploaded file")
        audio_bytes = upload.file.read()
        nerfreals[sessionid].put_audio_file(audio_bytes)
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": 0, "msg": "ok"}
            ),
        )
    except Exception as e:
        logger.exception('exception:')
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": -1, "msg": str(e)}
            ),
        )
async def set_audiotype(request):
    """Switch a session's custom audio/action state.

    JSON body: ``sessionid`` (default 0), ``audiotype`` and ``reinit``, which
    are forwarded to ``set_custom_state(audiotype, reinit)``.
    Replies ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": <error>}``.
    """
    reply = {"code": 0, "msg": "ok"}
    try:
        params = await request.json()
        target = nerfreals[params.get('sessionid', 0)]
        target.set_custom_state(params['audiotype'], params['reinit'])
    except Exception as e:
        logger.exception('exception:')
        reply = {"code": -1, "msg": str(e)}
    return web.Response(content_type="application/json", text=json.dumps(reply))
async def record(request):
    """Start or stop recording for one session.

    JSON body: ``sessionid`` (default 0) and ``type``. ``'start_record'``
    calls ``start_recording()``, ``'end_record'`` calls ``stop_recording()``,
    and any other type does nothing.
    Replies ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": <error>}``.
    """
    reply = {"code": 0, "msg": "ok"}
    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        action = params['type']
        if action == 'start_record':
            nerfreals[sessionid].start_recording()
        elif action == 'end_record':
            nerfreals[sessionid].stop_recording()
    except Exception as e:
        logger.exception('exception:')
        reply = {"code": -1, "msg": str(e)}
    return web.Response(content_type="application/json", text=json.dumps(reply))
async def is_speaking(request):
    """Report whether a session's avatar is currently speaking.

    JSON body: ``{"sessionid": ...}`` (defaults to 0).
    Replies ``{"code": 0, "data": <session.is_speaking()>}``. Like the other
    handlers, it replies ``{"code": -1, "msg": <error>}`` on failure (e.g. an
    unknown session) instead of letting the exception become an HTTP 500.
    """
    try:
        params = await request.json()
        sessionid = int(params.get('sessionid', 0))
        speaking = nerfreals[sessionid].is_speaking()
    except Exception as e:
        logger.exception('exception:')
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": -1, "msg": str(e)}
            ),
        )
    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"code": 0, "data": speaking}
        ),
    )
async def on_shutdown(app):
    """aiohttp shutdown hook: close every tracked peer connection, then forget them."""
    pending = []
    for peer in pcs:
        pending.append(peer.close())
    await asyncio.gather(*pending)
    pcs.clear()
async def post(url,data):
    """POST *data* to *url* and return the response body as text.

    Returns:
        The response text, or ``None`` if the request failed with an
        ``aiohttp.ClientError`` (the failure is logged).
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                return await response.text()
    except aiohttp.ClientError as e:
        # A failed push is an error, not routine information; log it visibly
        # together with the target URL.
        logger.error('POST %s failed: %s', url, e)
        return None
async def run(push_url,sessionid):
    """Build a session renderer and push its audio/video to *push_url*.

    Creates an SDP offer, POSTs it to *push_url* via ``post()``, and applies
    the returned text as the SDP answer.

    Raises:
        RuntimeError: if no SDP answer was received (``post()`` returned None).
    """
    loop = asyncio.get_running_loop()
    nerfreal = await loop.run_in_executor(None, build_nerfreal, sessionid)
    nerfreals[sessionid] = nerfreal

    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s", pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreal)
    pc.addTrack(player.audio)
    pc.addTrack(player.video)

    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url, pc.localDescription.sdp)
    if answer is None:
        # post() swallows client errors and returns None. Fail with a clear
        # message and release the connection instead of handing sdp=None
        # to aiortc.
        await pc.close()
        pcs.discard(pc)
        raise RuntimeError(f'push to {push_url} failed: no SDP answer received')
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
##########################################
if __name__ == '__main__':
    try:
        # The SQLite database lives one directory above this script.
        db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'live_chat.db')
        load_db_config(db_path)
        mp.set_start_method('spawn')

        parser = argparse.ArgumentParser()
        parser.add_argument('--config', type=str, default='config/config.json', help="配置文件路径")
        # audio FPS
        parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
        # sliding window left-middle-right length (unit: 20ms)
        parser.add_argument('-l', type=int, default=10)
        parser.add_argument('-m', type=int, default=8)
        parser.add_argument('-r', type=int, default=10)
        parser.add_argument('--W', type=int, default=450, help="GUI width")
        parser.add_argument('--H', type=int, default=450, help="GUI height")
        # musetalk opt
        parser.add_argument('--avatar_id', type=str, default='model_4_1', help="define which avatar in data/avatars")
        # parser.add_argument('--bbox_shift', type=int, default=5)
        parser.add_argument('--batch_size', type=int, default=16, help="infer batch")
        parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")
        parser.add_argument('--tts', type=str, default='gpt-sovits', help="tts service type")  # xtts gpt-sovits cosyvoice
        parser.add_argument('--REF_FILE', type=str, default="input/gentle_girl.wav")
        parser.add_argument('--REF_TEXT', type=str, default="刚进直播间的宝子们,左上角先点个关注,点亮咱们家的粉丝灯牌!我是你们的主播陈婉婉,今天给大家准备了超级重磅的福利")
        parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880')  # http://localhost:9000
        # parser.add_argument('--CHARACTER', type=str, default='test')
        # parser.add_argument('--EMOTION', type=str, default='default')
        parser.add_argument('--model', type=str, default='wav2lip')  # musetalk wav2lip ultralight
        parser.add_argument('--transport', type=str, default='webrtc')  # webrtc rtcpush virtualcam
        parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream')  # rtmp://localhost/live/livestream
        parser.add_argument('--max_session', type=int, default=1)  # multi session count
        parser.add_argument('--listenport', type=int, default=8010, help="web listen port")
        opt = parser.parse_args()

        # Precedence: command line > database > config file > parser default.
        # A value equal to the parser default counts as "not passed".
        # Map database config keys to opt attribute names.
        name_map = {
            'listen_port': 'listenport',
            'avatar_id': 'avatar_id',
            'ref_file': 'REF_FILE',
            'ref_text': 'REF_TEXT',
            'tts_server': 'TTS_SERVER',
        }
        for db_key, db_val in config_dict.items():
            arg_name = name_map.get(db_key)
            if arg_name is None:
                continue
            default = parser.get_default(arg_name)
            # Only override values the user did not pass explicitly.
            if getattr(opt, arg_name) != default:
                continue
            try:
                setattr(opt, arg_name, type(default)(db_val))
            except Exception:
                # Keep the raw value, but say so; e.g. a non-numeric
                # listen_port would otherwise only fail much later.
                logger.warning('cannot convert database value %r for %s; using it unchanged', db_val, arg_name)
                setattr(opt, arg_name, db_val)

        try:
            with open(opt.config, 'r', encoding='utf-8') as f:
                cfg = json.load(f)
            for key, val in cfg.items():
                # Use the config-file value only where opt still holds the parser default.
                if getattr(opt, key, None) == parser.get_default(key):
                    setattr(opt, key, val)
        except FileNotFoundError:
            logger.warning(f"配置文件未找到:{opt.config},将全部使用命令行/默认参数")
        except Exception as e:
            logger.warning(f"加载配置文件时出错:{e},将全部使用命令行/默认参数")

        opt.customopt = []
        if opt.customvideo_config != '':
            # Read as UTF-8 like opt.config above, not the platform locale encoding.
            with open(opt.customvideo_config, 'r', encoding='utf-8') as file:
                opt.customopt = json.load(file)

        # if opt.model == 'ernerf':
        #     from nerfreal import NeRFReal,load_model,load_avatar
        #     model = load_model(opt)
        #     avatar = load_avatar(opt)
        if opt.model == 'wav2lip':
            from lipreal import LipReal, load_model, load_avatar, warm_up
            logger.info(opt)
            model = load_model("./models/wav2lip384.pth")
            avatar = load_avatar(opt.avatar_id)
            warm_up(opt.batch_size, model, 384)
        elif opt.model == 'ultralight':
            from lightreal import LightReal, load_model, load_avatar, warm_up
            logger.info(opt)
            model = load_model(opt)
            avatar = load_avatar(opt.avatar_id)
            warm_up(opt.batch_size, avatar, 160)

        if opt.transport == 'virtualcam':
            thread_quit = Event()
            nerfreals[0] = build_nerfreal(0)
            rendthrd = Thread(target=nerfreals[0].render, args=(thread_quit,))
            rendthrd.start()

        #############################################################################
        appasync = web.Application(client_max_size=1024**2*100)
        appasync.on_shutdown.append(on_shutdown)
        appasync.router.add_post("/offer", offer)
        appasync.router.add_post("/human", human)
        appasync.router.add_post("/humanaudio", humanaudio)
        appasync.router.add_post("/set_audiotype", set_audiotype)
        appasync.router.add_post("/record", record)
        appasync.router.add_post("/interrupt_talk", interrupt_talk)
        appasync.router.add_post("/is_speaking", is_speaking)
        appasync.router.add_static('/', path='web')

        # Configure default CORS settings.
        cors = aiohttp_cors.setup(appasync, defaults={
            "*": aiohttp_cors.ResourceOptions(
                allow_credentials=True,
                expose_headers="*",
                allow_headers="*",
            )
        })
        # Configure CORS on all routes.
        for route in list(appasync.router.routes()):
            cors.add(route)

        pagename = 'webrtcapi.html'
        if opt.transport == 'rtmp':
            pagename = 'echoapi.html'
        elif opt.transport == 'rtcpush':
            pagename = 'rtcpushapi.html'
        logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
        logger.info('如果使用webrtc推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')

        # On startup, set the enable_status value in the database to '1'.
        set_enable_status(db_path)

        def run_server(runner):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(runner.setup())
            site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
            loop.run_until_complete(site.start())
            if opt.transport == 'rtcpush':
                for k in range(opt.max_session):
                    push_url = opt.push_url
                    if k != 0:
                        push_url = opt.push_url+str(k)
                    loop.run_until_complete(run(push_url, k))
            loop.run_forever()

        run_server(web.AppRunner(appasync))
    except Exception:
        import traceback
        traceback.print_exc()  # print the full stack trace
        try:
            # Wait for Enter so the console stays open long enough to read the error.
            input("发生异常,按回车键退出…")
        except EOFError:
            # No interactive stdin (e.g. launched as a background service);
            # the traceback above has already been printed.
            pass