support multi session

main
unknown 1 year ago
parent 2883b2243e
commit 0c63e9a11b

@ -27,7 +27,8 @@ import asyncio
app = Flask(__name__) app = Flask(__name__)
sockets = Sockets(app) sockets = Sockets(app)
global nerfreal nerfreals = []
statreals = []
@sockets.route('/humanecho') @sockets.route('/humanecho')
@ -87,6 +88,16 @@ async def offer(request):
params = await request.json() params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
sessionid = len(nerfreals)
for index,value in enumerate(statreals):
if value == 0:
sessionid = index
break
if sessionid>=len(nerfreals):
print('reach max session')
return -1
statreals[sessionid] = 1
pc = RTCPeerConnection() pc = RTCPeerConnection()
pcs.add(pc) pcs.add(pc)
@ -96,8 +107,12 @@ async def offer(request):
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await pc.close()
pcs.discard(pc) pcs.discard(pc)
statreals[sessionid] = 0
if pc.connectionState == "closed":
pcs.discard(pc)
statreals[sessionid] = 0
player = HumanPlayer(nerfreal) player = HumanPlayer(nerfreals[sessionid])
audio_sender = pc.addTrack(player.audio) audio_sender = pc.addTrack(player.audio)
video_sender = pc.addTrack(player.video) video_sender = pc.addTrack(player.video)
@ -111,21 +126,22 @@ async def offer(request):
return web.Response( return web.Response(
content_type="application/json", content_type="application/json",
text=json.dumps( text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid":sessionid}
), ),
) )
async def human(request): async def human(request):
params = await request.json() params = await request.json()
sessionid = params.get('sessionid',0)
if params.get('interrupt'): if params.get('interrupt'):
nerfreal.pause_talk() nerfreals[sessionid].pause_talk()
if params['type']=='echo': if params['type']=='echo':
nerfreal.put_msg_txt(params['text']) nerfreals[sessionid].put_msg_txt(params['text'])
elif params['type']=='chat': elif params['type']=='chat':
res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text']))
nerfreal.put_msg_txt(res) nerfreals[sessionid].put_msg_txt(res)
return web.Response( return web.Response(
content_type="application/json", content_type="application/json",
@ -159,7 +175,7 @@ async def run(push_url):
await pc.close() await pc.close()
pcs.discard(pc) pcs.discard(pc)
player = HumanPlayer(nerfreal) player = HumanPlayer(nerfreals[0])
audio_sender = pc.addTrack(player.audio) audio_sender = pc.addTrack(player.audio)
video_sender = pc.addTrack(player.video) video_sender = pc.addTrack(player.video)
@ -303,6 +319,7 @@ if __name__ == '__main__':
parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
parser.add_argument('--max_session', type=int, default=1) #multi session count
parser.add_argument('--listenport', type=int, default=8010) parser.add_argument('--listenport', type=int, default=8010)
opt = parser.parse_args() opt = parser.parse_args()
@ -355,20 +372,29 @@ if __name__ == '__main__':
model.eye_areas = test_loader._data.eye_area model.eye_areas = test_loader._data.eye_area
# we still need test_loader to provide audio features for testing. # we still need test_loader to provide audio features for testing.
nerfreal = NeRFReal(opt, trainer, test_loader) for _ in range(opt.max_session):
nerfreal = NeRFReal(opt, trainer, test_loader)
nerfreals.append(nerfreal)
elif opt.model == 'musetalk': elif opt.model == 'musetalk':
from musereal import MuseReal from musereal import MuseReal
print(opt) print(opt)
nerfreal = MuseReal(opt) for _ in range(opt.max_session):
nerfreal = MuseReal(opt)
nerfreals.append(nerfreal)
elif opt.model == 'wav2lip': elif opt.model == 'wav2lip':
from lipreal import LipReal from lipreal import LipReal
print(opt) print(opt)
nerfreal = LipReal(opt) for _ in range(opt.max_session):
nerfreal = LipReal(opt)
nerfreals.append(nerfreal)
for _ in range(opt.max_session):
statreals.append(0)
#txt_to_audio('我是中国人,我来自北京') #txt_to_audio('我是中国人,我来自北京')
if opt.transport=='rtmp': if opt.transport=='rtmp':
thread_quit = Event() thread_quit = Event()
rendthrd = Thread(target=nerfreal.render,args=(thread_quit,)) rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
rendthrd.start() rendthrd.start()
############################################################################# #############################################################################

@ -35,6 +35,7 @@ function negotiate() {
}).then((response) => { }).then((response) => {
return response.json(); return response.json();
}).then((answer) => { }).then((answer) => {
document.getElementById('sessionid').value = answer.sessionid
return pc.setRemoteDescription(answer); return pc.setRemoteDescription(answer);
}).catch((e) => { }).catch((e) => {
alert(e); alert(e);

@ -30,6 +30,7 @@
</div> </div>
<button id="start" onclick="start()">Start</button> <button id="start" onclick="start()">Start</button>
<button id="stop" style="display: none" onclick="stop()">Stop</button> <button id="stop" style="display: none" onclick="stop()">Stop</button>
<input type="hidden" id="sessionid" value="1234">
<form class="form-inline" id="echo-form"> <form class="form-inline" id="echo-form">
<div class="form-group"> <div class="form-group">
<p>input text</p> <p>input text</p>
@ -75,11 +76,13 @@
e.preventDefault(); e.preventDefault();
var message = $('#message').val(); var message = $('#message').val();
console.log('Sending: ' + message); console.log('Sending: ' + message);
console.log('sessionid: ',document.getElementById('sessionid').value);
fetch('/human', { fetch('/human', {
body: JSON.stringify({ body: JSON.stringify({
text: message, text: message,
type: 'echo', type: 'echo',
interrupt: true, interrupt: true,
sessionid:parseInt(document.getElementById('sessionid').value),
}), }),
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'

Loading…
Cancel
Save