|
|
@ -21,9 +21,9 @@ from flask_sockets import Sockets
|
|
|
|
import base64
|
|
|
|
import base64
|
|
|
|
import time
|
|
|
|
import time
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
import gevent
|
|
|
|
#import gevent
|
|
|
|
from gevent import pywsgi
|
|
|
|
#from gevent import pywsgi
|
|
|
|
from geventwebsocket.handler import WebSocketHandler
|
|
|
|
#from geventwebsocket.handler import WebSocketHandler
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
@ -39,6 +39,7 @@ from aiortc.rtcrtpsender import RTCRtpSender
|
|
|
|
from webrtc import HumanPlayer
|
|
|
|
from webrtc import HumanPlayer
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
import random
|
|
|
|
|
|
|
|
|
|
|
|
import shutil
|
|
|
|
import shutil
|
|
|
|
import asyncio
|
|
|
|
import asyncio
|
|
|
@ -46,29 +47,11 @@ import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = Flask(__name__)
|
|
|
|
app = Flask(__name__)
|
|
|
|
sockets = Sockets(app)
|
|
|
|
#sockets = Sockets(app)
|
|
|
|
nerfreals = []
|
|
|
|
nerfreals = {}
|
|
|
|
statreals = []
|
|
|
|
opt = None
|
|
|
|
|
|
|
|
model = None
|
|
|
|
|
|
|
|
avatar = None
|
|
|
|
@sockets.route('/humanecho')
|
|
|
|
|
|
|
|
def echo_socket(ws):
|
|
|
|
|
|
|
|
# 获取WebSocket对象
|
|
|
|
|
|
|
|
#ws = request.environ.get('wsgi.websocket')
|
|
|
|
|
|
|
|
# 如果没有获取到,返回错误信息
|
|
|
|
|
|
|
|
if not ws:
|
|
|
|
|
|
|
|
print('未建立连接!')
|
|
|
|
|
|
|
|
return 'Please use WebSocket'
|
|
|
|
|
|
|
|
# 否则,循环接收和发送消息
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
print('建立连接!')
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
|
|
|
message = ws.receive()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not message or len(message)==0:
|
|
|
|
|
|
|
|
return '输入信息为空'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
nerfreal.put_msg_txt(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def llm_response(message):
|
|
|
|
# def llm_response(message):
|
|
|
@ -124,43 +107,41 @@ def llm_response(message,nerfreal):
|
|
|
|
print(f"llm Time to last chunk: {end-start}s")
|
|
|
|
print(f"llm Time to last chunk: {end-start}s")
|
|
|
|
nerfreal.put_msg_txt(result)
|
|
|
|
nerfreal.put_msg_txt(result)
|
|
|
|
|
|
|
|
|
|
|
|
@sockets.route('/humanchat')
|
|
|
|
|
|
|
|
def chat_socket(ws):
|
|
|
|
|
|
|
|
# 获取WebSocket对象
|
|
|
|
|
|
|
|
#ws = request.environ.get('wsgi.websocket')
|
|
|
|
|
|
|
|
# 如果没有获取到,返回错误信息
|
|
|
|
|
|
|
|
if not ws:
|
|
|
|
|
|
|
|
print('未建立连接!')
|
|
|
|
|
|
|
|
return 'Please use WebSocket'
|
|
|
|
|
|
|
|
# 否则,循环接收和发送消息
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
print('建立连接!')
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
|
|
|
message = ws.receive()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(message)==0:
|
|
|
|
|
|
|
|
return '输入信息为空'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
res=llm_response(message)
|
|
|
|
|
|
|
|
nerfreal.put_msg_txt(res)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#####webrtc###############################
|
|
|
|
#####webrtc###############################
|
|
|
|
pcs = set()
|
|
|
|
pcs = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def randN(N):
|
|
|
|
|
|
|
|
'''生成长度为 N的随机数 '''
|
|
|
|
|
|
|
|
min = pow(10, N - 1)
|
|
|
|
|
|
|
|
max = pow(10, N)
|
|
|
|
|
|
|
|
return random.randint(min, max - 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_nerfreal(sessionid):
|
|
|
|
|
|
|
|
opt.sessionid=sessionid
|
|
|
|
|
|
|
|
if opt.model == 'wav2lip':
|
|
|
|
|
|
|
|
from lipreal import LipReal
|
|
|
|
|
|
|
|
nerfreal = LipReal(opt,model,avatar)
|
|
|
|
|
|
|
|
elif opt.model == 'musetalk':
|
|
|
|
|
|
|
|
from musereal import MuseReal
|
|
|
|
|
|
|
|
nerfreal = MuseReal(opt,model,avatar)
|
|
|
|
|
|
|
|
elif opt.model == 'ernerf':
|
|
|
|
|
|
|
|
from nerfreal import NeRFReal
|
|
|
|
|
|
|
|
nerfreal = NeRFReal(opt,model,avatar)
|
|
|
|
|
|
|
|
return nerfreal
|
|
|
|
|
|
|
|
|
|
|
|
#@app.route('/offer', methods=['POST'])
|
|
|
|
#@app.route('/offer', methods=['POST'])
|
|
|
|
async def offer(request):
|
|
|
|
async def offer(request):
|
|
|
|
params = await request.json()
|
|
|
|
params = await request.json()
|
|
|
|
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
|
|
|
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
|
|
|
|
|
|
|
|
|
|
|
sessionid = len(nerfreals)
|
|
|
|
if len(nerfreals) >= opt.max_session:
|
|
|
|
for index,value in enumerate(statreals):
|
|
|
|
|
|
|
|
if value == 0:
|
|
|
|
|
|
|
|
sessionid = index
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
if sessionid>=len(nerfreals):
|
|
|
|
|
|
|
|
print('reach max session')
|
|
|
|
print('reach max session')
|
|
|
|
return -1
|
|
|
|
return -1
|
|
|
|
statreals[sessionid] = 1
|
|
|
|
sessionid = randN(6) #len(nerfreals)
|
|
|
|
|
|
|
|
print('sessionid=',sessionid)
|
|
|
|
|
|
|
|
nerfreals[sessionid] = None
|
|
|
|
|
|
|
|
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
|
|
|
|
|
|
|
|
nerfreals[sessionid] = nerfreal
|
|
|
|
|
|
|
|
|
|
|
|
pc = RTCPeerConnection()
|
|
|
|
pc = RTCPeerConnection()
|
|
|
|
pcs.add(pc)
|
|
|
|
pcs.add(pc)
|
|
|
@ -171,10 +152,10 @@ async def offer(request):
|
|
|
|
if pc.connectionState == "failed":
|
|
|
|
if pc.connectionState == "failed":
|
|
|
|
await pc.close()
|
|
|
|
await pc.close()
|
|
|
|
pcs.discard(pc)
|
|
|
|
pcs.discard(pc)
|
|
|
|
statreals[sessionid] = 0
|
|
|
|
del nerfreals[sessionid]
|
|
|
|
if pc.connectionState == "closed":
|
|
|
|
if pc.connectionState == "closed":
|
|
|
|
pcs.discard(pc)
|
|
|
|
pcs.discard(pc)
|
|
|
|
statreals[sessionid] = 0
|
|
|
|
del nerfreals[sessionid]
|
|
|
|
|
|
|
|
|
|
|
|
player = HumanPlayer(nerfreals[sessionid])
|
|
|
|
player = HumanPlayer(nerfreals[sessionid])
|
|
|
|
audio_sender = pc.addTrack(player.audio)
|
|
|
|
audio_sender = pc.addTrack(player.audio)
|
|
|
@ -205,7 +186,7 @@ async def human(request):
|
|
|
|
|
|
|
|
|
|
|
|
sessionid = params.get('sessionid',0)
|
|
|
|
sessionid = params.get('sessionid',0)
|
|
|
|
if params.get('interrupt'):
|
|
|
|
if params.get('interrupt'):
|
|
|
|
nerfreals[sessionid].pause_talk()
|
|
|
|
nerfreals[sessionid].flush_talk()
|
|
|
|
|
|
|
|
|
|
|
|
if params['type']=='echo':
|
|
|
|
if params['type']=='echo':
|
|
|
|
nerfreals[sessionid].put_msg_txt(params['text'])
|
|
|
|
nerfreals[sessionid].put_msg_txt(params['text'])
|
|
|
@ -298,7 +279,10 @@ async def post(url,data):
|
|
|
|
except aiohttp.ClientError as e:
|
|
|
|
except aiohttp.ClientError as e:
|
|
|
|
print(f'Error: {e}')
|
|
|
|
print(f'Error: {e}')
|
|
|
|
|
|
|
|
|
|
|
|
async def run(push_url):
|
|
|
|
async def run(push_url,sessionid):
|
|
|
|
|
|
|
|
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
|
|
|
|
|
|
|
|
nerfreals[sessionid] = nerfreal
|
|
|
|
|
|
|
|
|
|
|
|
pc = RTCPeerConnection()
|
|
|
|
pc = RTCPeerConnection()
|
|
|
|
pcs.add(pc)
|
|
|
|
pcs.add(pc)
|
|
|
|
|
|
|
|
|
|
|
@ -309,7 +293,7 @@ async def run(push_url):
|
|
|
|
await pc.close()
|
|
|
|
await pc.close()
|
|
|
|
pcs.discard(pc)
|
|
|
|
pcs.discard(pc)
|
|
|
|
|
|
|
|
|
|
|
|
player = HumanPlayer(nerfreals[0])
|
|
|
|
player = HumanPlayer(nerfreals[sessionid])
|
|
|
|
audio_sender = pc.addTrack(player.audio)
|
|
|
|
audio_sender = pc.addTrack(player.audio)
|
|
|
|
video_sender = pc.addTrack(player.video)
|
|
|
|
video_sender = pc.addTrack(player.video)
|
|
|
|
|
|
|
|
|
|
|
@ -467,94 +451,37 @@ if __name__ == '__main__':
|
|
|
|
opt.customopt = json.load(file)
|
|
|
|
opt.customopt = json.load(file)
|
|
|
|
|
|
|
|
|
|
|
|
if opt.model == 'ernerf':
|
|
|
|
if opt.model == 'ernerf':
|
|
|
|
from ernerf.nerf_triplane.provider import NeRFDataset_Test
|
|
|
|
from nerfreal import NeRFReal,load_model,load_avatar
|
|
|
|
from ernerf.nerf_triplane.utils import *
|
|
|
|
model = load_model(opt)
|
|
|
|
from ernerf.nerf_triplane.network import NeRFNetwork
|
|
|
|
avatar = load_avatar(opt)
|
|
|
|
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
|
|
|
|
|
|
|
|
from nerfreal import NeRFReal
|
|
|
|
|
|
|
|
# assert test mode
|
|
|
|
|
|
|
|
opt.test = True
|
|
|
|
|
|
|
|
opt.test_train = False
|
|
|
|
|
|
|
|
#opt.train_camera =True
|
|
|
|
|
|
|
|
# explicit smoothing
|
|
|
|
|
|
|
|
opt.smooth_path = True
|
|
|
|
|
|
|
|
opt.smooth_lips = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert opt.pose != '', 'Must provide a pose source'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# if opt.O:
|
|
|
|
|
|
|
|
opt.fp16 = True
|
|
|
|
|
|
|
|
opt.cuda_ray = True
|
|
|
|
|
|
|
|
opt.exp_eye = True
|
|
|
|
|
|
|
|
opt.smooth_eye = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if opt.torso_imgs=='': #no img,use model output
|
|
|
|
|
|
|
|
opt.torso = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# assert opt.cuda_ray, "Only support CUDA ray mode."
|
|
|
|
|
|
|
|
opt.asr = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if opt.patch_size > 1:
|
|
|
|
|
|
|
|
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
|
|
|
|
|
|
|
|
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
|
|
|
|
|
|
|
|
seed_everything(opt.seed)
|
|
|
|
|
|
|
|
print(opt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
|
model = NeRFNetwork(opt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
criterion = torch.nn.MSELoss(reduction='none')
|
|
|
|
|
|
|
|
metrics = [] # use no metric in GUI for faster initialization...
|
|
|
|
|
|
|
|
print(model)
|
|
|
|
|
|
|
|
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_loader = NeRFDataset_Test(opt, device=device).dataloader()
|
|
|
|
|
|
|
|
model.aud_features = test_loader._data.auds
|
|
|
|
|
|
|
|
model.eye_areas = test_loader._data.eye_area
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(f'[INFO] loading ASR model {opt.asr_model}...')
|
|
|
|
|
|
|
|
if 'hubert' in opt.asr_model:
|
|
|
|
|
|
|
|
audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
|
|
|
|
|
|
|
|
audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
audio_processor = AutoProcessor.from_pretrained(opt.asr_model)
|
|
|
|
|
|
|
|
audio_model = AutoModelForCTC.from_pretrained(opt.asr_model).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# we still need test_loader to provide audio features for testing.
|
|
|
|
# we still need test_loader to provide audio features for testing.
|
|
|
|
for k in range(opt.max_session):
|
|
|
|
# for k in range(opt.max_session):
|
|
|
|
opt.sessionid=k
|
|
|
|
# opt.sessionid=k
|
|
|
|
nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
|
|
|
|
# nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
|
|
|
|
nerfreals.append(nerfreal)
|
|
|
|
# nerfreals.append(nerfreal)
|
|
|
|
elif opt.model == 'musetalk':
|
|
|
|
elif opt.model == 'musetalk':
|
|
|
|
from musereal import MuseReal
|
|
|
|
from musereal import MuseReal,load_model,load_avatar
|
|
|
|
from musetalk.utils.utils import load_all_model
|
|
|
|
|
|
|
|
print(opt)
|
|
|
|
print(opt)
|
|
|
|
audio_processor,vae, unet, pe = load_all_model()
|
|
|
|
model = load_model()
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
avatar = load_avatar(opt.avatar_id)
|
|
|
|
timesteps = torch.tensor([0], device=device)
|
|
|
|
# for k in range(opt.max_session):
|
|
|
|
pe = pe.half()
|
|
|
|
# opt.sessionid=k
|
|
|
|
vae.vae = vae.vae.half()
|
|
|
|
# nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
|
|
|
|
#vae.vae.share_memory()
|
|
|
|
# nerfreals.append(nerfreal)
|
|
|
|
unet.model = unet.model.half()
|
|
|
|
|
|
|
|
#unet.model.share_memory()
|
|
|
|
|
|
|
|
for k in range(opt.max_session):
|
|
|
|
|
|
|
|
opt.sessionid=k
|
|
|
|
|
|
|
|
nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
|
|
|
|
|
|
|
|
nerfreals.append(nerfreal)
|
|
|
|
|
|
|
|
elif opt.model == 'wav2lip':
|
|
|
|
elif opt.model == 'wav2lip':
|
|
|
|
from lipreal import LipReal,load_model
|
|
|
|
from lipreal import LipReal,load_model,load_avatar
|
|
|
|
print(opt)
|
|
|
|
print(opt)
|
|
|
|
model = load_model("./models/wav2lip.pth")
|
|
|
|
model = load_model("./models/wav2lip.pth")
|
|
|
|
for k in range(opt.max_session):
|
|
|
|
avatar = load_avatar(opt.avatar_id)
|
|
|
|
opt.sessionid=k
|
|
|
|
# for k in range(opt.max_session):
|
|
|
|
nerfreal = LipReal(opt,model)
|
|
|
|
# opt.sessionid=k
|
|
|
|
nerfreals.append(nerfreal)
|
|
|
|
# nerfreal = LipReal(opt,model)
|
|
|
|
|
|
|
|
# nerfreals.append(nerfreal)
|
|
|
|
for _ in range(opt.max_session):
|
|
|
|
|
|
|
|
statreals.append(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if opt.transport=='rtmp':
|
|
|
|
if opt.transport=='rtmp':
|
|
|
|
thread_quit = Event()
|
|
|
|
thread_quit = Event()
|
|
|
|
|
|
|
|
nerfreals[0] = build_nerfreal(0)
|
|
|
|
rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
|
|
|
|
rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
|
|
|
|
rendthrd.start()
|
|
|
|
rendthrd.start()
|
|
|
|
|
|
|
|
|
|
|
@ -594,7 +521,11 @@ if __name__ == '__main__':
|
|
|
|
site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
|
|
|
|
site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
|
|
|
|
loop.run_until_complete(site.start())
|
|
|
|
loop.run_until_complete(site.start())
|
|
|
|
if opt.transport=='rtcpush':
|
|
|
|
if opt.transport=='rtcpush':
|
|
|
|
loop.run_until_complete(run(opt.push_url))
|
|
|
|
for k in range(opt.max_session):
|
|
|
|
|
|
|
|
push_url = opt.push_url
|
|
|
|
|
|
|
|
if k!=0:
|
|
|
|
|
|
|
|
push_url = opt.push_url+str(k)
|
|
|
|
|
|
|
|
loop.run_until_complete(run(push_url,k))
|
|
|
|
loop.run_forever()
|
|
|
|
loop.run_forever()
|
|
|
|
#Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
|
|
|
|
#Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
|
|
|
|
run_server(web.AppRunner(appasync))
|
|
|
|
run_server(web.AppRunner(appasync))
|
|
|
|