|
|
|
@ -61,6 +61,7 @@ avatar = None
|
|
|
|
|
#####webrtc###############################
|
|
|
|
|
pcs = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def randN(N) -> int:
|
|
|
|
|
'''生成长度为 N的随机数 '''
|
|
|
|
|
min = pow(10, N - 1)
|
|
|
|
@ -101,6 +102,25 @@ def set_enable_status(db_path):
|
|
|
|
|
finally:
|
|
|
|
|
conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 用于更新数据库中的 sessionid
|
|
|
|
|
def update_sessionid_in_db(sessionid):
|
|
|
|
|
try:
|
|
|
|
|
# 打开数据库连接
|
|
|
|
|
conn = sqlite3.connect(db_path, check_same_thread=False)
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
|
|
|
|
# 更新 livetalking_sessionid 字段
|
|
|
|
|
cursor.execute("UPDATE live_config SET value = ? WHERE key = 'livetalking_sessionid';", (str(sessionid),))
|
|
|
|
|
conn.commit()
|
|
|
|
|
logger.info(f"Successfully updated livetalking_sessionid to {sessionid}")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.exception("Error updating livetalking_sessionid in database:")
|
|
|
|
|
finally:
|
|
|
|
|
if conn:
|
|
|
|
|
conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_nerfreal(sessionid: int) -> BaseReal:
|
|
|
|
|
opt.sessionid = sessionid
|
|
|
|
|
if opt.model == 'wav2lip':
|
|
|
|
@ -108,145 +128,70 @@ def build_nerfreal(sessionid:int)->BaseReal:
|
|
|
|
|
nerfreal = LipReal(opt, model, avatar)
|
|
|
|
|
return nerfreal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @app.route('/offer', methods=['POST'])
|
|
|
|
|
async def offer(request):
|
|
|
|
|
logger.info("开始处理offer请求")
|
|
|
|
|
|
|
|
|
|
# 接收请求参数
|
|
|
|
|
try:
|
|
|
|
|
params = await request.json()
|
|
|
|
|
logger.debug("接收到的请求参数: %s", params)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error("解析请求参数失败: %s", str(e))
|
|
|
|
|
return web.Response(
|
|
|
|
|
content_type="application/json",
|
|
|
|
|
text=json.dumps({"code": -2, "msg": "解析参数失败"})
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 验证必要参数
|
|
|
|
|
if "sdp" not in params or "type" not in params:
|
|
|
|
|
logger.warning("请求参数缺少sdp或type字段: %s", params)
|
|
|
|
|
return web.Response(
|
|
|
|
|
content_type="application/json",
|
|
|
|
|
text=json.dumps({"code": -3, "msg": "缺少必要参数"})
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 创建RTCSessionDescription
|
|
|
|
|
try:
|
|
|
|
|
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
|
|
|
|
logger.debug("成功创建RTCSessionDescription,类型: %s", params["type"])
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error("创建RTCSessionDescription失败: %s", str(e))
|
|
|
|
|
return web.Response(
|
|
|
|
|
content_type="application/json",
|
|
|
|
|
text=json.dumps({"code": -4, "msg": "创建offer失败"})
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 生成会话ID并记录
|
|
|
|
|
sessionid = randN(6)
|
|
|
|
|
# if len(nerfreals) >= opt.max_session:
|
|
|
|
|
# logger.info('reach max session')
|
|
|
|
|
# return web.Response(
|
|
|
|
|
# content_type="application/json",
|
|
|
|
|
# text=json.dumps(
|
|
|
|
|
# {"code": -1, "msg": "reach max session"}
|
|
|
|
|
# ),
|
|
|
|
|
# )
|
|
|
|
|
sessionid = randN(6) # len(nerfreals)
|
|
|
|
|
nerfreals[sessionid] = None
|
|
|
|
|
logger.info('生成新会话,sessionid=%d, 当前会话数量=%d', sessionid, len(nerfreals))
|
|
|
|
|
logger.info('sessionid=%d, session num=%d', sessionid, len(nerfreals))
|
|
|
|
|
update_sessionid_in_db(sessionid)
|
|
|
|
|
|
|
|
|
|
# 构建nerfreal
|
|
|
|
|
try:
|
|
|
|
|
logger.debug("开始构建nerfreal,sessionid=%d", sessionid)
|
|
|
|
|
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal, sessionid)
|
|
|
|
|
nerfreals[sessionid] = nerfreal
|
|
|
|
|
logger.info("成功构建nerfreal,sessionid=%d", sessionid)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error("构建nerfreal失败,sessionid=%d: %s", sessionid, str(e))
|
|
|
|
|
del nerfreals[sessionid]
|
|
|
|
|
return web.Response(
|
|
|
|
|
content_type="application/json",
|
|
|
|
|
text=json.dumps({"code": -5, "msg": "构建会话失败"})
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 配置ICE服务器并创建RTCPeerConnection
|
|
|
|
|
try:
|
|
|
|
|
# ice_server = RTCIceServer(urls='stun:stun.l.google.com:19302')
|
|
|
|
|
ice_server = RTCIceServer(urls='stun:stun.miwifi.com:3478')
|
|
|
|
|
logger.debug("使用ICE服务器: %s", ice_server.urls)
|
|
|
|
|
|
|
|
|
|
pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=[ice_server]))
|
|
|
|
|
pcs.add(pc)
|
|
|
|
|
logger.info("创建RTCPeerConnection成功,sessionid=%d", sessionid)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error("创建RTCPeerConnection失败,sessionid=%d: %s", sessionid, str(e))
|
|
|
|
|
del nerfreals[sessionid]
|
|
|
|
|
return web.Response(
|
|
|
|
|
content_type="application/json",
|
|
|
|
|
text=json.dumps({"code": -6, "msg": "创建连接失败"})
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 连接状态变化处理
|
|
|
|
|
@pc.on("connectionstatechange")
|
|
|
|
|
async def on_connectionstatechange():
|
|
|
|
|
logger.info("sessionid=%d, 连接状态变更为: %s", sessionid, pc.connectionState)
|
|
|
|
|
logger.info("Connection state is %s" % pc.connectionState)
|
|
|
|
|
if pc.connectionState == "failed":
|
|
|
|
|
logger.warning("sessionid=%d, 连接失败,关闭连接", sessionid)
|
|
|
|
|
await pc.close()
|
|
|
|
|
pcs.discard(pc)
|
|
|
|
|
del nerfreals[sessionid]
|
|
|
|
|
if pc.connectionState == "closed":
|
|
|
|
|
logger.info("sessionid=%d, 连接已关闭", sessionid)
|
|
|
|
|
pcs.discard(pc)
|
|
|
|
|
del nerfreals[sessionid]
|
|
|
|
|
gc.collect()
|
|
|
|
|
logger.debug("sessionid=%d, 已清理资源", sessionid)
|
|
|
|
|
|
|
|
|
|
# 添加媒体轨道
|
|
|
|
|
try:
|
|
|
|
|
player = HumanPlayer(nerfreals[sessionid])
|
|
|
|
|
audio_sender = pc.addTrack(player.audio)
|
|
|
|
|
video_sender = pc.addTrack(player.video)
|
|
|
|
|
logger.debug("sessionid=%d, 已添加音频和视频轨道", sessionid)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error("sessionid=%d, 添加媒体轨道失败: %s", sessionid, str(e))
|
|
|
|
|
return web.Response(
|
|
|
|
|
content_type="application/json",
|
|
|
|
|
text=json.dumps({"code": -7, "msg": "添加媒体轨道失败"})
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 配置编解码器偏好
|
|
|
|
|
try:
|
|
|
|
|
capabilities = RTCRtpSender.getCapabilities("video")
|
|
|
|
|
preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
|
|
|
|
|
preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
|
|
|
|
|
preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
|
|
|
|
|
logger.debug("sessionid=%d, 编解码器偏好: %s", sessionid, [p.name for p in preferences])
|
|
|
|
|
|
|
|
|
|
transceiver = pc.getTransceivers()[1]
|
|
|
|
|
transceiver.setCodecPreferences(preferences)
|
|
|
|
|
logger.debug("sessionid=%d, 已设置编解码器偏好", sessionid)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error("sessionid=%d, 配置编解码器失败: %s", sessionid, str(e))
|
|
|
|
|
|
|
|
|
|
# 处理offer并创建answer
|
|
|
|
|
try:
|
|
|
|
|
await pc.setRemoteDescription(offer)
|
|
|
|
|
logger.debug("sessionid=%d, 已设置远程描述", sessionid)
|
|
|
|
|
|
|
|
|
|
answer = await pc.createAnswer()
|
|
|
|
|
await pc.setLocalDescription(answer)
|
|
|
|
|
logger.debug("sessionid=%d, 已创建并设置本地answer", sessionid)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error("sessionid=%d, 处理offer创建answer失败: %s", sessionid, str(e))
|
|
|
|
|
return web.Response(
|
|
|
|
|
content_type="application/json",
|
|
|
|
|
text=json.dumps({"code": -8, "msg": "处理offer失败"})
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 返回结果
|
|
|
|
|
response_data = {
|
|
|
|
|
"sdp": pc.localDescription.sdp,
|
|
|
|
|
"type": pc.localDescription.type,
|
|
|
|
|
"sessionid": sessionid
|
|
|
|
|
}
|
|
|
|
|
logger.info("sessionid=%d, offer请求处理完成,返回响应", sessionid)
|
|
|
|
|
# return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
|
|
|
|
|
|
|
|
|
|
return web.Response(
|
|
|
|
|
content_type="application/json",
|
|
|
|
|
text=json.dumps(response_data)
|
|
|
|
|
text=json.dumps(
|
|
|
|
|
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid": sessionid}
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def human(request):
|
|
|
|
|
try:
|
|
|
|
|
params = await request.json()
|
|
|
|
@ -278,6 +223,7 @@ async def human(request):
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def interrupt_talk(request):
|
|
|
|
|
try:
|
|
|
|
|
params = await request.json()
|
|
|
|
@ -300,6 +246,7 @@ async def interrupt_talk(request):
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def humanaudio(request):
|
|
|
|
|
try:
|
|
|
|
|
form = await request.post()
|
|
|
|
@ -324,6 +271,7 @@ async def humanaudio(request):
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def set_audiotype(request):
|
|
|
|
|
try:
|
|
|
|
|
params = await request.json()
|
|
|
|
@ -346,6 +294,7 @@ async def set_audiotype(request):
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def record(request):
|
|
|
|
|
try:
|
|
|
|
|
params = await request.json()
|
|
|
|
@ -371,6 +320,7 @@ async def record(request):
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def is_speaking(request):
|
|
|
|
|
params = await request.json()
|
|
|
|
|
|
|
|
|
@ -403,6 +353,7 @@ async def post(url,data):
|
|
|
|
|
except aiohttp.ClientError as e:
|
|
|
|
|
logger.info(f'Error: {e}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def run(push_url, sessionid):
|
|
|
|
|
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal, sessionid)
|
|
|
|
|
nerfreals[sessionid] = nerfreal
|
|
|
|
@ -424,6 +375,8 @@ async def run(push_url,sessionid):
|
|
|
|
|
await pc.setLocalDescription(await pc.createOffer())
|
|
|
|
|
answer = await post(push_url, pc.localDescription.sdp)
|
|
|
|
|
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##########################################
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
@ -434,7 +387,7 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
mp.set_start_method('spawn')
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
|
parser.add_argument('--config', type=str, default='config/config.json', help="配置文件路径")
|
|
|
|
|
parser.add_argument('--config', type=str, default='../config.json', help="配置文件路径")
|
|
|
|
|
|
|
|
|
|
# audio FPS
|
|
|
|
|
parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
|
|
|
|
@ -521,6 +474,7 @@ if __name__ == '__main__':
|
|
|
|
|
warm_up(opt.batch_size, model, 384)
|
|
|
|
|
elif opt.model == 'ultralight':
|
|
|
|
|
from lightreal import LightReal, load_model, load_avatar, warm_up
|
|
|
|
|
|
|
|
|
|
logger.info(opt)
|
|
|
|
|
model = load_model(opt)
|
|
|
|
|
avatar = load_avatar(opt.avatar_id)
|
|
|
|
@ -565,6 +519,8 @@ if __name__ == '__main__':
|
|
|
|
|
logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
|
|
|
|
|
# 服务开启,将数据库中enable_status 对应的 value 字段改为 '1'
|
|
|
|
|
set_enable_status(db_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_server(runner):
|
|
|
|
|
loop = asyncio.new_event_loop()
|
|
|
|
|
asyncio.set_event_loop(loop)
|
|
|
|
@ -578,8 +534,11 @@ if __name__ == '__main__':
|
|
|
|
|
push_url = opt.push_url + str(k)
|
|
|
|
|
loop.run_until_complete(run(push_url, k))
|
|
|
|
|
loop.run_forever()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
run_server(web.AppRunner(appasync))
|
|
|
|
|
except Exception:
|
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
|
|
traceback.print_exc() # 打印完整的错误堆栈
|
|
|
|
|
input("发生异常,按回车键退出…") # 等待用户按回车再退出
|