Optimize multiple sessions

main
lipku 8 months ago
parent dbe508cb65
commit 00e3c8a23d

app.py

@@ -21,9 +21,9 @@ from flask_sockets import Sockets
 import base64
 import time
 import json
-import gevent
-from gevent import pywsgi
-from geventwebsocket.handler import WebSocketHandler
+#import gevent
+#from gevent import pywsgi
+#from geventwebsocket.handler import WebSocketHandler
 import os
 import re
 import numpy as np
@@ -39,6 +39,7 @@ from aiortc.rtcrtpsender import RTCRtpSender
 from webrtc import HumanPlayer
 import argparse
+import random
 import shutil
 import asyncio
@@ -46,29 +47,11 @@ import torch
 app = Flask(__name__)
-sockets = Sockets(app)
-nerfreals = []
-statreals = []
+#sockets = Sockets(app)
+nerfreals = {}
+opt = None
+model = None
+avatar = None

-@sockets.route('/humanecho')
-def echo_socket(ws):
-    # get the WebSocket object
-    #ws = request.environ.get('wsgi.websocket')
-    # if it was not obtained, return an error message
-    if not ws:
-        print('connection not established!')
-        return 'Please use WebSocket'
-    # otherwise, loop receiving and sending messages
-    else:
-        print('connection established!')
-        while True:
-            message = ws.receive()
-            if not message or len(message)==0:
-                return 'input message is empty'
-            else:
-                nerfreal.put_msg_txt(message)

 # def llm_response(message):
@@ -124,43 +107,41 @@ def llm_response(message,nerfreal):
     print(f"llm Time to last chunk: {end-start}s")
     nerfreal.put_msg_txt(result)

-@sockets.route('/humanchat')
-def chat_socket(ws):
-    # get the WebSocket object
-    #ws = request.environ.get('wsgi.websocket')
-    # if it was not obtained, return an error message
-    if not ws:
-        print('connection not established!')
-        return 'Please use WebSocket'
-    # otherwise, loop receiving and sending messages
-    else:
-        print('connection established!')
-        while True:
-            message = ws.receive()
-            if len(message)==0:
-                return 'input message is empty'
-            else:
-                res=llm_response(message)
-                nerfreal.put_msg_txt(res)

 #####webrtc###############################
 pcs = set()

+def randN(N):
+    '''generate a random number with N digits'''
+    min = pow(10, N - 1)
+    max = pow(10, N)
+    return random.randint(min, max - 1)
+
+def build_nerfreal(sessionid):
+    opt.sessionid=sessionid
+    if opt.model == 'wav2lip':
+        from lipreal import LipReal
+        nerfreal = LipReal(opt,model,avatar)
+    elif opt.model == 'musetalk':
+        from musereal import MuseReal
+        nerfreal = MuseReal(opt,model,avatar)
+    elif opt.model == 'ernerf':
+        from nerfreal import NeRFReal
+        nerfreal = NeRFReal(opt,model,avatar)
+    return nerfreal
+
 #@app.route('/offer', methods=['POST'])
 async def offer(request):
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

-    sessionid = len(nerfreals)
-    for index,value in enumerate(statreals):
-        if value == 0:
-            sessionid = index
-            break
-    if sessionid>=len(nerfreals):
+    if len(nerfreals) >= opt.max_session:
         print('reach max session')
         return -1
-    statreals[sessionid] = 1
+    sessionid = randN(6) #len(nerfreals)
+    print('sessionid=',sessionid)
+    nerfreals[sessionid] = None
+    nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
+    nerfreals[sessionid] = nerfreal

     pc = RTCPeerConnection()
     pcs.add(pc)
@@ -171,10 +152,10 @@ async def offer(request):
         if pc.connectionState == "failed":
             await pc.close()
             pcs.discard(pc)
-            statreals[sessionid] = 0
+            del nerfreals[sessionid]
         if pc.connectionState == "closed":
             pcs.discard(pc)
-            statreals[sessionid] = 0
+            del nerfreals[sessionid]

     player = HumanPlayer(nerfreals[sessionid])
     audio_sender = pc.addTrack(player.audio)
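For reference, a minimal client-side sketch (not part of this commit) of how a caller would negotiate one of these sessions against the /offer handler above. The listen address/port and the assumption that the answer JSON also carries the generated sessionid are guesses based on this diff and the hidden sessionid field added to the HTML page further down.

import asyncio
import aiohttp
from aiortc import RTCPeerConnection, RTCSessionDescription

async def negotiate(server='http://127.0.0.1:8010'):   # hypothetical host/port
    pc = RTCPeerConnection()
    pc.addTransceiver('video', direction='recvonly')
    pc.addTransceiver('audio', direction='recvonly')
    await pc.setLocalDescription(await pc.createOffer())
    async with aiohttp.ClientSession() as sess:
        async with sess.post(server + '/offer',
                             json={'sdp': pc.localDescription.sdp,
                                   'type': pc.localDescription.type}) as resp:
            answer = await resp.json()
    await pc.setRemoteDescription(
        RTCSessionDescription(sdp=answer['sdp'], type=answer['type']))
    # 'sessionid' in the answer is an assumption; the diff only shows it being
    # generated on the server and echoed into the page's hidden input.
    return pc, answer.get('sessionid')

# asyncio.run(negotiate())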
@@ -205,7 +186,7 @@ async def human(request):
     sessionid = params.get('sessionid',0)
     if params.get('interrupt'):
-        nerfreals[sessionid].pause_talk()
+        nerfreals[sessionid].flush_talk()

     if params['type']=='echo':
         nerfreals[sessionid].put_msg_txt(params['text'])
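Once a session exists, text is pushed to that particular avatar through /human with the sessionid included. A small sketch of such a request, assuming the server host/port; the field names mirror the fetch() call in the HTML hunk near the end of this diff.

import requests

def send_text(sessionid, text, server='http://127.0.0.1:8010'):  # assumed host/port
    payload = {
        'text': text,
        'type': 'echo',       # 'echo' speaks the text as-is; the HTML page sends 'chat' to route through llm_response
        'interrupt': True,    # triggers nerfreals[sessionid].flush_talk() per the hunk above
        'sessionid': sessionid,
    }
    return requests.post(server + '/human', json=payload)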
@@ -298,7 +279,10 @@ async def post(url,data):
     except aiohttp.ClientError as e:
         print(f'Error: {e}')

-async def run(push_url):
+async def run(push_url,sessionid):
+    nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
+    nerfreals[sessionid] = nerfreal
+
     pc = RTCPeerConnection()
     pcs.add(pc)
@@ -309,7 +293,7 @@ async def run(push_url):
             await pc.close()
             pcs.discard(pc)

-    player = HumanPlayer(nerfreals[0])
+    player = HumanPlayer(nerfreals[sessionid])
     audio_sender = pc.addTrack(player.audio)
     video_sender = pc.addTrack(player.video)
@@ -467,94 +451,37 @@ if __name__ == '__main__':
            opt.customopt = json.load(file)

     if opt.model == 'ernerf':
-        from ernerf.nerf_triplane.provider import NeRFDataset_Test
-        from ernerf.nerf_triplane.utils import *
-        from ernerf.nerf_triplane.network import NeRFNetwork
-        from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
-        from nerfreal import NeRFReal
-        # assert test mode
-        opt.test = True
-        opt.test_train = False
-        #opt.train_camera =True
-        # explicit smoothing
-        opt.smooth_path = True
-        opt.smooth_lips = True
-
-        assert opt.pose != '', 'Must provide a pose source'
-
-        # if opt.O:
-        opt.fp16 = True
-        opt.cuda_ray = True
-        opt.exp_eye = True
-        opt.smooth_eye = True
-
-        if opt.torso_imgs=='': #no img,use model output
-            opt.torso = True
-
-        # assert opt.cuda_ray, "Only support CUDA ray mode."
-        opt.asr = True
-
-        if opt.patch_size > 1:
-            # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
-            assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
-        seed_everything(opt.seed)
-        print(opt)
-
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        model = NeRFNetwork(opt)
-
-        criterion = torch.nn.MSELoss(reduction='none')
-        metrics = [] # use no metric in GUI for faster initialization...
-        print(model)
-        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
-
-        test_loader = NeRFDataset_Test(opt, device=device).dataloader()
-        model.aud_features = test_loader._data.auds
-        model.eye_areas = test_loader._data.eye_area
-
-        print(f'[INFO] loading ASR model {opt.asr_model}...')
-        if 'hubert' in opt.asr_model:
-            audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
-            audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
-        else:
-            audio_processor = AutoProcessor.from_pretrained(opt.asr_model)
-            audio_model = AutoModelForCTC.from_pretrained(opt.asr_model).to(device)
+        from nerfreal import NeRFReal,load_model,load_avatar
+        model = load_model(opt)
+        avatar = load_avatar(opt)

         # we still need test_loader to provide audio features for testing.
-        for k in range(opt.max_session):
-            opt.sessionid=k
-            nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
-            nerfreals.append(nerfreal)
+        # for k in range(opt.max_session):
+        #     opt.sessionid=k
+        #     nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
+        #     nerfreals.append(nerfreal)
     elif opt.model == 'musetalk':
-        from musereal import MuseReal
-        from musetalk.utils.utils import load_all_model
+        from musereal import MuseReal,load_model,load_avatar
         print(opt)
-        audio_processor,vae, unet, pe = load_all_model()
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        timesteps = torch.tensor([0], device=device)
-        pe = pe.half()
-        vae.vae = vae.vae.half()
-        #vae.vae.share_memory()
-        unet.model = unet.model.half()
-        #unet.model.share_memory()
-        for k in range(opt.max_session):
-            opt.sessionid=k
-            nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
-            nerfreals.append(nerfreal)
+        model = load_model()
+        avatar = load_avatar(opt.avatar_id)
+        # for k in range(opt.max_session):
+        #     opt.sessionid=k
+        #     nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
+        #     nerfreals.append(nerfreal)
     elif opt.model == 'wav2lip':
-        from lipreal import LipReal,load_model
+        from lipreal import LipReal,load_model,load_avatar
        print(opt)
         model = load_model("./models/wav2lip.pth")
-        for k in range(opt.max_session):
-            opt.sessionid=k
-            nerfreal = LipReal(opt,model)
-            nerfreals.append(nerfreal)
-
-    for _ in range(opt.max_session):
-        statreals.append(0)
+        avatar = load_avatar(opt.avatar_id)
+        # for k in range(opt.max_session):
+        #     opt.sessionid=k
+        #     nerfreal = LipReal(opt,model)
+        #     nerfreals.append(nerfreal)

     if opt.transport=='rtmp':
         thread_quit = Event()
+        nerfreals[0] = build_nerfreal(0)
         rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
         rendthrd.start()
@@ -594,7 +521,11 @@ if __name__ == '__main__':
         site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
         loop.run_until_complete(site.start())
         if opt.transport=='rtcpush':
-            loop.run_until_complete(run(opt.push_url))
+            for k in range(opt.max_session):
+                push_url = opt.push_url
+                if k!=0:
+                    push_url = opt.push_url+str(k)
+                loop.run_until_complete(run(push_url,k))
         loop.run_forever()
     #Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
     run_server(web.AppRunner(appasync))
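In the rtcpush branch above, one push stream is now started per session: session 0 keeps opt.push_url unchanged and session k appends k to it. A tiny sketch of that mapping (the example URL is hypothetical):

def session_push_url(push_url, k):
    # mirrors the loop in run_server: session 0 -> push_url, session k -> push_url + str(k)
    return push_url if k == 0 else push_url + str(k)

# session_push_url('http://example.com/rtc/v1/whip/?app=live&stream=livestream', 2)
# -> 'http://example.com/rtc/v1/whip/?app=live&stream=livestream2'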

@@ -44,7 +44,7 @@ class BaseASR:
         #self.warm_up()

-    def pause_talk(self):
+    def flush_talk(self):
         self.queue.queue.clear()

     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm

@@ -109,9 +109,9 @@ class BaseReal:
         return stream

-    def pause_talk(self):
-        self.tts.pause_talk()
-        self.asr.pause_talk()
+    def flush_talk(self):
+        self.tts.flush_talk()
+        self.asr.flush_talk()

     def is_speaking(self)->bool:
         return self.speaking

@@ -67,6 +67,24 @@ def load_model(path):
     model = model.to(device)
     return model.eval()

+def load_avatar(avatar_id):
+    avatar_path = f"./data/avatars/{avatar_id}"
+    full_imgs_path = f"{avatar_path}/full_imgs"
+    face_imgs_path = f"{avatar_path}/face_imgs"
+    coords_path = f"{avatar_path}/coords.pkl"
+
+    with open(coords_path, 'rb') as f:
+        coord_list_cycle = pickle.load(f)
+    input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    frame_list_cycle = read_imgs(input_img_list)
+    #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
+    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    face_list_cycle = read_imgs(input_face_list)
+
+    return frame_list_cycle,face_list_cycle,coord_list_cycle
+
 def read_imgs(img_list):
     frames = []
     print('reading images...')
@@ -156,7 +174,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
 class LipReal(BaseReal):
     @torch.no_grad()
-    def __init__(self, opt, model):
+    def __init__(self, opt, model, avatar):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -164,35 +182,21 @@ class LipReal(BaseReal):
         self.fps = opt.fps # 20 ms per frame

-        #### musetalk
-        self.avatar_id = opt.avatar_id
-        self.avatar_path = f"./data/avatars/{self.avatar_id}"
-        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
-        self.face_imgs_path = f"{self.avatar_path}/face_imgs"
-        self.coords_path = f"{self.avatar_path}/coords.pkl"
         self.batch_size = opt.batch_size
         self.idx = 0
         self.res_frame_queue = Queue(self.batch_size*2)  #mp.Queue

-        #self.__loadmodels()
-        self.__loadavatar()
+        #self.__loadavatar()
+        self.model = model
+        self.frame_list_cycle,self.face_list_cycle,self.coord_list_cycle = avatar

         self.asr = LipASR(opt,self)
         self.asr.warm_up()
         #self.__warm_up()
-        self.model = model

         self.render_event = mp.Event()

-    def __loadavatar(self):
-        with open(self.coords_path, 'rb') as f:
-            self.coord_list_cycle = pickle.load(f)
-        input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
-        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-        self.frame_list_cycle = read_imgs(input_img_list)
-        #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
-        input_face_list = glob.glob(os.path.join(self.face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
-        input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-        self.face_list_cycle = read_imgs(input_face_list)
+    def __del__(self):
+        print(f'lipreal({self.sessionid}) delete')

     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):

@@ -46,6 +46,49 @@ from av import AudioFrame, VideoFrame
 from basereal import BaseReal
 from tqdm import tqdm

+def load_model():
+    # load model weights
+    audio_processor,vae, unet, pe = load_all_model()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    timesteps = torch.tensor([0], device=device)
+    pe = pe.half()
+    vae.vae = vae.vae.half()
+    #vae.vae.share_memory()
+    unet.model = unet.model.half()
+    #unet.model.share_memory()
+    return vae, unet, pe, timesteps, audio_processor
+
+def load_avatar(avatar_id):
+    #self.video_path = '' #video_path
+    #self.bbox_shift = opt.bbox_shift
+    avatar_path = f"./data/avatars/{avatar_id}"
+    full_imgs_path = f"{avatar_path}/full_imgs"
+    coords_path = f"{avatar_path}/coords.pkl"
+    latents_out_path= f"{avatar_path}/latents.pt"
+    video_out_path = f"{avatar_path}/vid_output/"
+    mask_out_path =f"{avatar_path}/mask"
+    mask_coords_path =f"{avatar_path}/mask_coords.pkl"
+    avatar_info_path = f"{avatar_path}/avator_info.json"
+    # self.avatar_info = {
+    #     "avatar_id":self.avatar_id,
+    #     "video_path":self.video_path,
+    #     "bbox_shift":self.bbox_shift
+    # }
+
+    input_latent_list_cycle = torch.load(latents_out_path,weights_only=True)
+    with open(coords_path, 'rb') as f:
+        coord_list_cycle = pickle.load(f)
+    input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    frame_list_cycle = read_imgs(input_img_list)
+    with open(mask_coords_path, 'rb') as f:
+        mask_coords_list_cycle = pickle.load(f)
+    input_mask_list = glob.glob(os.path.join(mask_out_path, '*.[jpJP][pnPN]*[gG]'))
+    input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    mask_list_cycle = read_imgs(input_mask_list)
+
+    return frame_list_cycle,mask_list_cycle,coord_list_cycle,mask_coords_list_cycle,input_latent_list_cycle
+
 def read_imgs(img_list):
     frames = []
     print('reading images...')
@@ -145,7 +188,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
 class MuseReal(BaseReal):
     @torch.no_grad()
-    def __init__(self, opt, audio_processor:Audio2Feature,vae, unet, pe,timesteps):
+    def __init__(self, opt, model, avatar):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -153,63 +196,22 @@ class MuseReal(BaseReal):
         self.fps = opt.fps # 20 ms per frame

-        #### musetalk
-        self.avatar_id = opt.avatar_id
-        self.video_path = '' #video_path
-        self.bbox_shift = opt.bbox_shift
-        self.avatar_path = f"./data/avatars/{self.avatar_id}"
-        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
-        self.coords_path = f"{self.avatar_path}/coords.pkl"
-        self.latents_out_path= f"{self.avatar_path}/latents.pt"
-        self.video_out_path = f"{self.avatar_path}/vid_output/"
-        self.mask_out_path =f"{self.avatar_path}/mask"
-        self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
-        self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
-        self.avatar_info = {
-            "avatar_id":self.avatar_id,
-            "video_path":self.video_path,
-            "bbox_shift":self.bbox_shift
-        }
         self.batch_size = opt.batch_size
         self.idx = 0
         self.res_frame_queue = mp.Queue(self.batch_size*2)

-        #self.__loadmodels()
-        self.audio_processor= audio_processor
-        self.__loadavatar()
+        self.vae, self.unet, self.pe, self.timesteps, self.audio_processor = model
+        self.frame_list_cycle,self.mask_list_cycle,self.coord_list_cycle,self.mask_coords_list_cycle, self.input_latent_list_cycle = avatar
+        #self.__loadavatar()

         self.asr = MuseASR(opt,self,self.audio_processor)
         self.asr.warm_up()
         #self.__warm_up()

         self.render_event = mp.Event()
-        self.vae = vae
-        self.unet = unet
-        self.pe = pe
-        self.timesteps = timesteps
-
-    # def __loadmodels(self):
-    #     # load model weights
-    #     self.audio_processor= load_audio_model()
-    #     self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
-    #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    #     self.timesteps = torch.tensor([0], device=device)
-    #     self.pe = self.pe.half()
-    #     self.vae.vae = self.vae.vae.half()
-    #     self.unet.model = self.unet.model.half()
-
-    def __loadavatar(self):
-        self.input_latent_list_cycle = torch.load(self.latents_out_path,weights_only=True)
-        with open(self.coords_path, 'rb') as f:
-            self.coord_list_cycle = pickle.load(f)
-        input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
-        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-        self.frame_list_cycle = read_imgs(input_img_list)
-        with open(self.mask_coords_path, 'rb') as f:
-            self.mask_coords_list_cycle = pickle.load(f)
-        input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
-        input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-        self.mask_list_cycle = read_imgs(input_mask_list)
+    def __del__(self):
+        print(f'musereal({self.sessionid}) delete')

     def __mirror_index(self, index):

@@ -33,6 +33,10 @@ from av import AudioFrame, VideoFrame
 from basereal import BaseReal
 #from imgcache import ImgCache
+from ernerf.nerf_triplane.provider import NeRFDataset_Test
+from ernerf.nerf_triplane.utils import *
+from ernerf.nerf_triplane.network import NeRFNetwork
+from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
 from tqdm import tqdm

 def read_imgs(img_list):
@@ -43,15 +47,76 @@ def read_imgs(img_list):
         frames.append(frame)
     return frames

+def load_model(opt):
+    # assert test mode
+    opt.test = True
+    opt.test_train = False
+    #opt.train_camera =True
+    # explicit smoothing
+    opt.smooth_path = True
+    opt.smooth_lips = True
+
+    assert opt.pose != '', 'Must provide a pose source'
+
+    # if opt.O:
+    opt.fp16 = True
+    opt.cuda_ray = True
+    opt.exp_eye = True
+    opt.smooth_eye = True
+
+    if opt.torso_imgs=='': #no img,use model output
+        opt.torso = True
+
+    # assert opt.cuda_ray, "Only support CUDA ray mode."
+    opt.asr = True
+
+    if opt.patch_size > 1:
+        # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
+        assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
+    seed_everything(opt.seed)
+    print(opt)
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = NeRFNetwork(opt)
+
+    criterion = torch.nn.MSELoss(reduction='none')
+    metrics = [] # use no metric in GUI for faster initialization...
+    print(model)
+    trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
+
+    test_loader = NeRFDataset_Test(opt, device=device).dataloader()
+    model.aud_features = test_loader._data.auds
+    model.eye_areas = test_loader._data.eye_area
+
+    print(f'[INFO] loading ASR model {opt.asr_model}...')
+    if 'hubert' in opt.asr_model:
+        audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
+        audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
+    else:
+        audio_processor = AutoProcessor.from_pretrained(opt.asr_model)
+        audio_model = AutoModelForCTC.from_pretrained(opt.asr_model).to(device)
+    return trainer,test_loader,audio_processor,audio_model
+
+def load_avatar(opt):
+    fullbody_list_cycle = None
+    if opt.fullbody:
+        input_img_list = glob.glob(os.path.join(self.opt.fullbody_img, '*.[jpJP][pnPN]*[gG]'))
+        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+        #print('input_img_list:',input_img_list)
+        fullbody_list_cycle = read_imgs(input_img_list) #[:frame_total_num]
+        #self.imagecache = ImgCache(frame_total_num,self.opt.fullbody_img,1000)
+    return fullbody_list_cycle
+
 class NeRFReal(BaseReal):
-    def __init__(self, opt, trainer, data_loader, audio_processor,audio_model, debug=True):
+    def __init__(self, opt, model,avatar, debug=True):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
         self.H = opt.H
-        self.trainer = trainer
-        self.data_loader = data_loader
+        #self.trainer = trainer
+        #self.data_loader = data_loader
+        self.trainer, self.data_loader, audio_processor,audio_model = model

         # use dataloader's bg
         #bg_img = data_loader._data.bg_img #.view(1, -1, 3)
@@ -70,14 +135,10 @@ class NeRFReal(BaseReal):
         #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()

         # playing seq from dataloader, or pause.
-        self.loader = iter(data_loader)
-        frame_total_num = data_loader._data.end_index
-        if opt.fullbody:
-            input_img_list = glob.glob(os.path.join(self.opt.fullbody_img, '*.[jpJP][pnPN]*[gG]'))
-            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-            #print('input_img_list:',input_img_list)
-            self.fullbody_list_cycle = read_imgs(input_img_list[:frame_total_num])
-            #self.imagecache = ImgCache(frame_total_num,self.opt.fullbody_img,1000)
+        self.loader = iter(self.data_loader)
+        frame_total_num = self.data_loader._data.end_index
+        self.fullbody_list_cycle = avatar

         #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
         #self.need_update = True # camera moved, should reset accumulation
@@ -135,6 +196,8 @@ class NeRFReal(BaseReal):
             #self.test_step()
     '''
+    def __del__(self):
+        print(f'nerfreal({self.sessionid}) delete')

     def __enter__(self):
         return self

@@ -49,7 +49,7 @@ class BaseTTS:
         self.msgqueue = Queue()
         self.state = State.RUNNING

-    def pause_talk(self):
+    def flush_talk(self):
         self.msgqueue.queue.clear()
         self.state = State.PAUSE

@@ -30,6 +30,10 @@
         </div>
         <button id="start" onclick="start()">Start</button>
         <button id="stop" style="display: none" onclick="stop()">Stop</button>
+        <button class="btn btn-primary" id="btn_start_record">Start Recording</button>
+        <button class="btn btn-primary" id="btn_stop_record" disabled>Stop Recording</button>
+        <!-- <button class="btn btn-primary" id="btn_download">Download Video</button> -->
+        <input type="hidden" id="sessionid" value="0">

         <form class="form-inline" id="echo-form">
             <div class="form-group">
                 <p>input text</p>
@@ -75,11 +79,13 @@
                 e.preventDefault();
                 var message = $('#message').val();
                 console.log('Sending: ' + message);
+                console.log('sessionid: ',document.getElementById('sessionid').value);
                 fetch('/human', {
                     body: JSON.stringify({
                         text: message,
                         type: 'chat',
                         interrupt: true,
+                        sessionid:parseInt(document.getElementById('sessionid').value),
                     }),
                     headers: {
                         'Content-Type': 'application/json'
@@ -89,6 +95,90 @@
                 //ws.send(message);
                 $('#message').val('');
             });

+            $('#btn_start_record').click(function() {
+                // start recording
+                console.log('Starting recording...');
+                fetch('/record', {
+                    body: JSON.stringify({
+                        type: 'start_record',
+                    }),
+                    headers: {
+                        'Content-Type': 'application/json'
+                    },
+                    method: 'POST'
+                }).then(function(response) {
+                    if (response.ok) {
+                        console.log('Recording started.');
+                        $('#btn_start_record').prop('disabled', true);
+                        $('#btn_stop_record').prop('disabled', false);
+                        // $('#btn_download').prop('disabled', true);
+                    } else {
+                        console.error('Failed to start recording.');
+                    }
+                }).catch(function(error) {
+                    console.error('Error:', error);
+                });
+            });
+
+            $('#btn_stop_record').click(function() {
+                // stop recording
+                console.log('Stopping recording...');
+                fetch('/record', {
+                    body: JSON.stringify({
+                        type: 'end_record',
+                    }),
+                    headers: {
+                        'Content-Type': 'application/json'
+                    },
+                    method: 'POST'
+                }).then(function(response) {
+                    if (response.ok) {
+                        console.log('Recording stopped.');
+                        $('#btn_start_record').prop('disabled', false);
+                        $('#btn_stop_record').prop('disabled', true);
+                        // $('#btn_download').prop('disabled', false);
+                    } else {
+                        console.error('Failed to stop recording.');
+                    }
+                }).catch(function(error) {
+                    console.error('Error:', error);
+                });
+            });
+
+            // $('#btn_download').click(function() {
+            //     // download the video file
+            //     console.log('Downloading video...');
+            //     fetch('/record_lasted.mp4', {
+            //         method: 'GET'
+            //     }).then(function(response) {
+            //         if (response.ok) {
+            //             return response.blob();
+            //         } else {
+            //             throw new Error('Failed to download the video.');
+            //         }
+            //     }).then(function(blob) {
+            //         // create a Blob URL
+            //         const url = window.URL.createObjectURL(blob);
+            //         // create a hidden download link
+            //         const a = document.createElement('a');
+            //         a.style.display = 'none';
+            //         a.href = url;
+            //         a.download = 'record_lasted.mp4';
+            //         document.body.appendChild(a);
+            //         // trigger the download
+            //         a.click();
+            //         // clean up
+            //         window.URL.revokeObjectURL(url);
+            //         document.body.removeChild(a);
+            //         console.log('Video downloaded successfully.');
+            //     }).catch(function(error) {
+            //         console.error('Error:', error);
+            //     });
+            // });
         });
     </script>
 </html>
