From 0c63e9a11be04229c5f2d697c21bc4bc88feb944 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 17 Jul 2024 08:21:31 +0800 Subject: [PATCH] support multi session --- app.py | 48 +++++++++++++++++++++++++++++++++++----------- web/client.js | 1 + web/webrtcapi.html | 3 +++ 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/app.py b/app.py index ec4b54c..4f3c255 100644 --- a/app.py +++ b/app.py @@ -27,7 +27,8 @@ import asyncio app = Flask(__name__) sockets = Sockets(app) -global nerfreal +nerfreals = [] +statreals = [] @sockets.route('/humanecho') @@ -87,6 +88,16 @@ async def offer(request): params = await request.json() offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) + sessionid = len(nerfreals) + for index,value in enumerate(statreals): + if value == 0: + sessionid = index + break + if sessionid>=len(nerfreals): + print('reach max session') + return -1 + statreals[sessionid] = 1 + pc = RTCPeerConnection() pcs.add(pc) @@ -96,8 +107,12 @@ async def offer(request): if pc.connectionState == "failed": await pc.close() pcs.discard(pc) + statreals[sessionid] = 0 + if pc.connectionState == "closed": + pcs.discard(pc) + statreals[sessionid] = 0 - player = HumanPlayer(nerfreal) + player = HumanPlayer(nerfreals[sessionid]) audio_sender = pc.addTrack(player.audio) video_sender = pc.addTrack(player.video) @@ -111,21 +126,22 @@ async def offer(request): return web.Response( content_type="application/json", text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid":sessionid} ), ) async def human(request): params = await request.json() + sessionid = params.get('sessionid',0) if params.get('interrupt'): - nerfreal.pause_talk() + nerfreals[sessionid].pause_talk() if params['type']=='echo': - nerfreal.put_msg_txt(params['text']) + nerfreals[sessionid].put_msg_txt(params['text']) elif params['type']=='chat': res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) - nerfreal.put_msg_txt(res) + nerfreals[sessionid].put_msg_txt(res) return web.Response( content_type="application/json", @@ -159,7 +175,7 @@ async def run(push_url): await pc.close() pcs.discard(pc) - player = HumanPlayer(nerfreal) + player = HumanPlayer(nerfreals[0]) audio_sender = pc.addTrack(player.audio) video_sender = pc.addTrack(player.video) @@ -303,6 +319,7 @@ if __name__ == '__main__': parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream + parser.add_argument('--max_session', type=int, default=1) #multi session count parser.add_argument('--listenport', type=int, default=8010) opt = parser.parse_args() @@ -355,20 +372,29 @@ if __name__ == '__main__': model.eye_areas = test_loader._data.eye_area # we still need test_loader to provide audio features for testing. - nerfreal = NeRFReal(opt, trainer, test_loader) + for _ in range(opt.max_session): + nerfreal = NeRFReal(opt, trainer, test_loader) + nerfreals.append(nerfreal) elif opt.model == 'musetalk': from musereal import MuseReal print(opt) - nerfreal = MuseReal(opt) + for _ in range(opt.max_session): + nerfreal = MuseReal(opt) + nerfreals.append(nerfreal) elif opt.model == 'wav2lip': from lipreal import LipReal print(opt) - nerfreal = LipReal(opt) + for _ in range(opt.max_session): + nerfreal = LipReal(opt) + nerfreals.append(nerfreal) + + for _ in range(opt.max_session): + statreals.append(0) #txt_to_audio('我是中国人,我来自北京') if opt.transport=='rtmp': thread_quit = Event() - rendthrd = Thread(target=nerfreal.render,args=(thread_quit,)) + rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,)) rendthrd.start() ############################################################################# diff --git a/web/client.js b/web/client.js index c48b064..b23ecc1 100644 --- a/web/client.js +++ b/web/client.js @@ -35,6 +35,7 @@ function negotiate() { }).then((response) => { return response.json(); }).then((answer) => { + document.getElementById('sessionid').value = answer.sessionid return pc.setRemoteDescription(answer); }).catch((e) => { alert(e); diff --git a/web/webrtcapi.html b/web/webrtcapi.html index eff0287..7f874a9 100644 --- a/web/webrtcapi.html +++ b/web/webrtcapi.html @@ -30,6 +30,7 @@ +

input text

@@ -75,11 +76,13 @@ e.preventDefault(); var message = $('#message').val(); console.log('Sending: ' + message); + console.log('sessionid: ',document.getElementById('sessionid').value); fetch('/human', { body: JSON.stringify({ text: message, type: 'echo', interrupt: true, + sessionid:parseInt(document.getElementById('sessionid').value), }), headers: { 'Content-Type': 'application/json'