添加更新sessionid功能

main
fanpt 1 month ago
parent f6113c0bbb
commit 4c38f17568

269
app.py

@ -16,17 +16,17 @@
###############################################################################
# server.py
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask import Flask, render_template, send_from_directory, request, jsonify
from flask_sockets import Sockets
import base64
import json
#import gevent
#from gevent import pywsgi
#from geventwebsocket.handler import WebSocketHandler
# import gevent
# from gevent import pywsgi
# from geventwebsocket.handler import WebSocketHandler
import re
import numpy as np
from threading import Thread,Event
#import multiprocessing
from threading import Thread, Event
# import multiprocessing
import torch.multiprocessing as mp
from aiohttp import web
@ -52,8 +52,8 @@ import sqlite3
import os
app = Flask(__name__)
#sockets = Sockets(app)
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
# sockets = Sockets(app)
nerfreals: Dict[int, BaseReal] = {} # sessionid:BaseReal
opt = None
model = None
avatar = None
@ -61,7 +61,8 @@ avatar = None
#####webrtc###############################
pcs = set()
def randN(N)->int:
def randN(N) -> int:
'''生成长度为 N的随机数 '''
min = pow(10, N - 1)
max = pow(10, N)
@ -101,172 +102,116 @@ def set_enable_status(db_path):
finally:
conn.close()
def build_nerfreal(sessionid:int)->BaseReal:
opt.sessionid=sessionid
# 用于更新数据库中的 sessionid
def update_sessionid_in_db(sessionid):
try:
# 打开数据库连接
conn = sqlite3.connect(db_path, check_same_thread=False)
cursor = conn.cursor()
# 更新 livetalking_sessionid 字段
cursor.execute("UPDATE live_config SET value = ? WHERE key = 'livetalking_sessionid';", (str(sessionid),))
conn.commit()
logger.info(f"Successfully updated livetalking_sessionid to {sessionid}")
except Exception as e:
logger.exception("Error updating livetalking_sessionid in database:")
finally:
if conn:
conn.close()
def build_nerfreal(sessionid: int) -> BaseReal:
opt.sessionid = sessionid
if opt.model == 'wav2lip':
from lipreal import LipReal
nerfreal = LipReal(opt,model,avatar)
nerfreal = LipReal(opt, model, avatar)
return nerfreal
#@app.route('/offer', methods=['POST'])
async def offer(request):
logger.info("开始处理offer请求")
# 接收请求参数
try:
# @app.route('/offer', methods=['POST'])
async def offer(request):
params = await request.json()
logger.debug("接收到的请求参数: %s", params)
except Exception as e:
logger.error("解析请求参数失败: %s", str(e))
return web.Response(
content_type="application/json",
text=json.dumps({"code": -2, "msg": "解析参数失败"})
)
# 验证必要参数
if "sdp" not in params or "type" not in params:
logger.warning("请求参数缺少sdp或type字段: %s", params)
return web.Response(
content_type="application/json",
text=json.dumps({"code": -3, "msg": "缺少必要参数"})
)
# 创建RTCSessionDescription
try:
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
logger.debug("成功创建RTCSessionDescription类型: %s", params["type"])
except Exception as e:
logger.error("创建RTCSessionDescription失败: %s", str(e))
return web.Response(
content_type="application/json",
text=json.dumps({"code": -4, "msg": "创建offer失败"})
)
# 生成会话ID并记录
sessionid = randN(6)
# if len(nerfreals) >= opt.max_session:
# logger.info('reach max session')
# return web.Response(
# content_type="application/json",
# text=json.dumps(
# {"code": -1, "msg": "reach max session"}
# ),
# )
sessionid = randN(6) # len(nerfreals)
nerfreals[sessionid] = None
logger.info('生成新会话sessionid=%d, 当前会话数量=%d', sessionid, len(nerfreals))
logger.info('sessionid=%d, session num=%d', sessionid, len(nerfreals))
update_sessionid_in_db(sessionid)
# 构建nerfreal
try:
logger.debug("开始构建nerfrealsessionid=%d", sessionid)
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal, sessionid)
nerfreals[sessionid] = nerfreal
logger.info("成功构建nerfrealsessionid=%d", sessionid)
except Exception as e:
logger.error("构建nerfreal失败sessionid=%d: %s", sessionid, str(e))
del nerfreals[sessionid]
return web.Response(
content_type="application/json",
text=json.dumps({"code": -5, "msg": "构建会话失败"})
)
# 配置ICE服务器并创建RTCPeerConnection
try:
# ice_server = RTCIceServer(urls='stun:stun.l.google.com:19302')
ice_server = RTCIceServer(urls='stun:stun.miwifi.com:3478')
logger.debug("使用ICE服务器: %s", ice_server.urls)
pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=[ice_server]))
pcs.add(pc)
logger.info("创建RTCPeerConnection成功sessionid=%d", sessionid)
except Exception as e:
logger.error("创建RTCPeerConnection失败sessionid=%d: %s", sessionid, str(e))
del nerfreals[sessionid]
return web.Response(
content_type="application/json",
text=json.dumps({"code": -6, "msg": "创建连接失败"})
)
# 连接状态变化处理
@pc.on("connectionstatechange")
async def on_connectionstatechange():
logger.info("sessionid=%d, 连接状态变更为: %s", sessionid, pc.connectionState)
logger.info("Connection state is %s" % pc.connectionState)
if pc.connectionState == "failed":
logger.warning("sessionid=%d, 连接失败,关闭连接", sessionid)
await pc.close()
pcs.discard(pc)
del nerfreals[sessionid]
if pc.connectionState == "closed":
logger.info("sessionid=%d, 连接已关闭", sessionid)
pcs.discard(pc)
del nerfreals[sessionid]
gc.collect()
logger.debug("sessionid=%d, 已清理资源", sessionid)
# 添加媒体轨道
try:
player = HumanPlayer(nerfreals[sessionid])
audio_sender = pc.addTrack(player.audio)
video_sender = pc.addTrack(player.video)
logger.debug("sessionid=%d, 已添加音频和视频轨道", sessionid)
except Exception as e:
logger.error("sessionid=%d, 添加媒体轨道失败: %s", sessionid, str(e))
return web.Response(
content_type="application/json",
text=json.dumps({"code": -7, "msg": "添加媒体轨道失败"})
)
# 配置编解码器偏好
try:
capabilities = RTCRtpSender.getCapabilities("video")
preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
logger.debug("sessionid=%d, 编解码器偏好: %s", sessionid, [p.name for p in preferences])
transceiver = pc.getTransceivers()[1]
transceiver.setCodecPreferences(preferences)
logger.debug("sessionid=%d, 已设置编解码器偏好", sessionid)
except Exception as e:
logger.error("sessionid=%d, 配置编解码器失败: %s", sessionid, str(e))
# 处理offer并创建answer
try:
await pc.setRemoteDescription(offer)
logger.debug("sessionid=%d, 已设置远程描述", sessionid)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
logger.debug("sessionid=%d, 已创建并设置本地answer", sessionid)
except Exception as e:
logger.error("sessionid=%d, 处理offer创建answer失败: %s", sessionid, str(e))
return web.Response(
content_type="application/json",
text=json.dumps({"code": -8, "msg": "处理offer失败"})
)
# 返回结果
response_data = {
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type,
"sessionid": sessionid
}
logger.info("sessionid=%d, offer请求处理完成返回响应", sessionid)
# return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
return web.Response(
content_type="application/json",
text=json.dumps(response_data)
text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid": sessionid}
),
)
async def human(request):
try:
params = await request.json()
sessionid = params.get('sessionid',0)
sessionid = params.get('sessionid', 0)
if params.get('interrupt'):
nerfreals[sessionid].flush_talk()
if params['type']=='echo':
if params['type'] == 'echo':
nerfreals[sessionid].speaking = True
nerfreals[sessionid].put_msg_txt(params['text'])
elif params['type']=='chat':
asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'],nerfreals[sessionid])
#nerfreals[sessionid].put_msg_txt(res)
elif params['type'] == 'chat':
asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'], nerfreals[sessionid])
# nerfreals[sessionid].put_msg_txt(res)
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "msg":"ok"}
{"code": 0, "msg": "ok"}
),
)
except Exception as e:
@ -278,17 +223,18 @@ async def human(request):
),
)
async def interrupt_talk(request):
try:
params = await request.json()
sessionid = params.get('sessionid',0)
sessionid = params.get('sessionid', 0)
nerfreals[sessionid].flush_talk()
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "msg":"ok"}
{"code": 0, "msg": "ok"}
),
)
except Exception as e:
@ -300,19 +246,20 @@ async def interrupt_talk(request):
),
)
async def humanaudio(request):
try:
form= await request.post()
sessionid = int(form.get('sessionid',0))
form = await request.post()
sessionid = int(form.get('sessionid', 0))
fileobj = form["file"]
filename=fileobj.filename
filebytes=fileobj.file.read()
filename = fileobj.filename
filebytes = fileobj.file.read()
nerfreals[sessionid].put_audio_file(filebytes)
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "msg":"ok"}
{"code": 0, "msg": "ok"}
),
)
except Exception as e:
@ -324,17 +271,18 @@ async def humanaudio(request):
),
)
async def set_audiotype(request):
try:
params = await request.json()
sessionid = params.get('sessionid',0)
nerfreals[sessionid].set_custom_state(params['audiotype'],params['reinit'])
sessionid = params.get('sessionid', 0)
nerfreals[sessionid].set_custom_state(params['audiotype'], params['reinit'])
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "msg":"ok"}
{"code": 0, "msg": "ok"}
),
)
except Exception as e:
@ -346,20 +294,21 @@ async def set_audiotype(request):
),
)
async def record(request):
try:
params = await request.json()
sessionid = params.get('sessionid',0)
if params['type']=='start_record':
sessionid = params.get('sessionid', 0)
if params['type'] == 'start_record':
# nerfreals[sessionid].put_msg_txt(params['text'])
nerfreals[sessionid].start_recording()
elif params['type']=='end_record':
elif params['type'] == 'end_record':
nerfreals[sessionid].stop_recording()
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "msg":"ok"}
{"code": 0, "msg": "ok"}
),
)
except Exception as e:
@ -371,10 +320,11 @@ async def record(request):
),
)
async def is_speaking(request):
params = await request.json()
sessionid = params.get('sessionid',0)
sessionid = params.get('sessionid', 0)
return web.Response(
content_type="application/json",
text=json.dumps(
@ -395,16 +345,17 @@ async def on_shutdown(app):
print("[INFO] 已关闭数据库连接")
async def post(url,data):
async def post(url, data):
try:
async with aiohttp.ClientSession() as session:
async with session.post(url,data=data) as response:
async with session.post(url, data=data) as response:
return await response.text()
except aiohttp.ClientError as e:
logger.info(f'Error: {e}')
async def run(push_url,sessionid):
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
async def run(push_url, sessionid):
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal, sessionid)
nerfreals[sessionid] = nerfreal
pc = RTCPeerConnection()
@ -422,8 +373,10 @@ async def run(push_url,sessionid):
video_sender = pc.addTrack(player.video)
await pc.setLocalDescription(await pc.createOffer())
answer = await post(push_url,pc.localDescription.sdp)
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
answer = await post(push_url, pc.localDescription.sdp)
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
##########################################
if __name__ == '__main__':
@ -434,7 +387,7 @@ if __name__ == '__main__':
mp.set_start_method('spawn')
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/config.json', help="配置文件路径")
parser.add_argument('--config', type=str, default='../config.json', help="配置文件路径")
# audio FPS
parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
@ -446,9 +399,9 @@ if __name__ == '__main__':
parser.add_argument('--W', type=int, default=450, help="GUI width")
parser.add_argument('--H', type=int, default=450, help="GUI height")
#musetalk opt
# musetalk opt
parser.add_argument('--avatar_id', type=str, default='model_4_1', help="define which avatar in data/avatars")
#parser.add_argument('--bbox_shift', type=int, default=5)
# parser.add_argument('--bbox_shift', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=16, help="infer batch")
parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")
@ -460,17 +413,17 @@ if __name__ == '__main__':
# parser.add_argument('--CHARACTER', type=str, default='test')
# parser.add_argument('--EMOTION', type=str, default='default')
parser.add_argument('--model', type=str, default='wav2lip') #musetalk wav2lip ultralight
parser.add_argument('--model', type=str, default='wav2lip') # musetalk wav2lip ultralight
parser.add_argument('--transport', type=str, default='webrtc') #webrtc rtcpush virtualcam
parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
parser.add_argument('--max_session', type=int, default=1) #multi session count
parser.add_argument('--max_session', type=int, default=1) # multi session count
parser.add_argument('--listenport', type=int, default=8010, help="web listen port")
opt = parser.parse_args()
#app.config.from_object(opt)
#print(app.config)
# app.config.from_object(opt)
# print(app.config)
# —— 新增:数据库字段名到 opt 属性名的映射 ——
name_map = {
'listen_port': 'listenport',
@ -505,8 +458,8 @@ if __name__ == '__main__':
logger.warning(f"加载配置文件时出错:{e},将全部使用命令行/默认参数")
opt.customopt = []
if opt.customvideo_config!='':
with open(opt.customvideo_config,'r') as file:
if opt.customvideo_config != '':
with open(opt.customvideo_config, 'r') as file:
opt.customopt = json.load(file)
# if opt.model == 'ernerf':
@ -518,22 +471,23 @@ if __name__ == '__main__':
logger.info(opt)
model = load_model("./models/wav2lip384.pth")
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,model,384)
warm_up(opt.batch_size, model, 384)
elif opt.model == 'ultralight':
from lightreal import LightReal,load_model,load_avatar,warm_up
from lightreal import LightReal, load_model, load_avatar, warm_up
logger.info(opt)
model = load_model(opt)
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,avatar,160)
warm_up(opt.batch_size, avatar, 160)
if opt.transport=='virtualcam':
if opt.transport == 'virtualcam':
thread_quit = Event()
nerfreals[0] = build_nerfreal(0)
rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
rendthrd = Thread(target=nerfreals[0].render, args=(thread_quit,))
rendthrd.start()
#############################################################################
appasync = web.Application(client_max_size=1024**2*100)
appasync = web.Application(client_max_size=1024 ** 2 * 100)
appasync.on_shutdown.append(on_shutdown)
appasync.router.add_post("/offer", offer)
appasync.router.add_post("/human", human)
@ -542,7 +496,7 @@ if __name__ == '__main__':
appasync.router.add_post("/record", record)
appasync.router.add_post("/interrupt_talk", interrupt_talk)
appasync.router.add_post("/is_speaking", is_speaking)
appasync.router.add_static('/',path='web')
appasync.router.add_static('/', path='web')
# Configure default CORS settings.
cors = aiohttp_cors.setup(appasync, defaults={
@ -565,21 +519,26 @@ if __name__ == '__main__':
logger.info('如果使用webrtc推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
# 服务开启将数据库中enable_status 对应的 value 字段改为 '1'
set_enable_status(db_path)
def run_server(runner):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(runner.setup())
site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
loop.run_until_complete(site.start())
if opt.transport=='rtcpush':
if opt.transport == 'rtcpush':
for k in range(opt.max_session):
push_url = opt.push_url
if k!=0:
push_url = opt.push_url+str(k)
loop.run_until_complete(run(push_url,k))
if k != 0:
push_url = opt.push_url + str(k)
loop.run_until_complete(run(push_url, k))
loop.run_forever()
run_server(web.AppRunner(appasync))
except Exception:
import traceback
traceback.print_exc() # 打印完整的错误堆栈
input("发生异常,按回车键退出…") # 等待用户按回车再退出
Loading…
Cancel
Save