diff --git a/app.py b/app.py
index 05a696b..3cd81d3 100644
--- a/app.py
+++ b/app.py
@@ -21,9 +21,9 @@ from flask_sockets import Sockets
 import base64
 import time
 import json
-import gevent
-from gevent import pywsgi
-from geventwebsocket.handler import WebSocketHandler
+#import gevent
+#from gevent import pywsgi
+#from geventwebsocket.handler import WebSocketHandler
 import os
 import re
 import numpy as np
@@ -39,6 +39,7 @@ from aiortc.rtcrtpsender import RTCRtpSender
 from webrtc import HumanPlayer
 
 import argparse
+import random
 import shutil
 import asyncio
@@ -46,29 +47,11 @@ import torch
 
 app = Flask(__name__)
-sockets = Sockets(app)
-nerfreals = []
-statreals = []
-
-
-@sockets.route('/humanecho')
-def echo_socket(ws):
-    # get the WebSocket object
-    #ws = request.environ.get('wsgi.websocket')
-    # if there is none, return an error message
-    if not ws:
-        print('Connection not established!')
-        return 'Please use WebSocket'
-    # otherwise, receive and forward messages in a loop
-    else:
-        print('Connection established!')
-        while True:
-            message = ws.receive()
-
-            if not message or len(message)==0:
-                return 'Input message is empty'
-            else:
-                nerfreal.put_msg_txt(message)
+#sockets = Sockets(app)
+nerfreals = {}
+opt = None
+model = None
+avatar = None
 
 
 # def llm_response(message):
@@ -124,43 +107,41 @@ def llm_response(message,nerfreal):
     print(f"llm Time to last chunk: {end-start}s")
     nerfreal.put_msg_txt(result)
 
-@sockets.route('/humanchat')
-def chat_socket(ws):
-    # get the WebSocket object
-    #ws = request.environ.get('wsgi.websocket')
-    # if there is none, return an error message
-    if not ws:
-        print('Connection not established!')
-        return 'Please use WebSocket'
-    # otherwise, receive and forward messages in a loop
-    else:
-        print('Connection established!')
-        while True:
-            message = ws.receive()
-
-            if len(message)==0:
-                return 'Input message is empty'
-            else:
-                res=llm_response(message)
-                nerfreal.put_msg_txt(res)
-
 #####webrtc###############################
 pcs = set()
 
+def randN(N):
+    '''Generate a random number with N digits'''
+    min = pow(10, N - 1)
+    max = pow(10, N)
+    return random.randint(min, max - 1)
+
+def build_nerfreal(sessionid):
+    opt.sessionid=sessionid
+    if opt.model == 'wav2lip':
+        from lipreal import LipReal
+        nerfreal = LipReal(opt,model,avatar)
+    elif opt.model == 'musetalk':
+        from musereal import MuseReal
+        nerfreal = MuseReal(opt,model,avatar)
+    elif opt.model == 'ernerf':
+        from nerfreal import NeRFReal
+        nerfreal = NeRFReal(opt,model,avatar)
+    return nerfreal
+
 #@app.route('/offer', methods=['POST'])
 async def offer(request):
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
 
-    sessionid = len(nerfreals)
-    for index,value in enumerate(statreals):
-        if value == 0:
-            sessionid = index
-            break
-    if sessionid>=len(nerfreals):
+    if len(nerfreals) >= opt.max_session:
        print('reach max session')
        return -1
-    statreals[sessionid] = 1
+    sessionid = randN(6) #len(nerfreals)
+    print('sessionid=',sessionid)
+    nerfreals[sessionid] = None
+    nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
+    nerfreals[sessionid] = nerfreal
 
     pc = RTCPeerConnection()
     pcs.add(pc)
@@ -171,10 +152,10 @@ async def offer(request):
         if pc.connectionState == "failed":
             await pc.close()
             pcs.discard(pc)
-            statreals[sessionid] = 0
+            del nerfreals[sessionid]
         if pc.connectionState == "closed":
             pcs.discard(pc)
-            statreals[sessionid] = 0
+            del nerfreals[sessionid]
 
     player = HumanPlayer(nerfreals[sessionid])
     audio_sender = pc.addTrack(player.audio)
@@ -205,7 +186,7 @@ async def human(request):
     sessionid = params.get('sessionid',0)
 
     if params.get('interrupt'):
-        nerfreals[sessionid].pause_talk()
+        nerfreals[sessionid].flush_talk()
 
     if params['type']=='echo':
         nerfreals[sessionid].put_msg_txt(params['text'])
@@ -298,7 +279,10 @@ async def post(url,data):
     except aiohttp.ClientError as e:
         print(f'Error: {e}')
 
-async def run(push_url):
+async def run(push_url,sessionid):
+    nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
+    nerfreals[sessionid] = nerfreal
+
     pc = RTCPeerConnection()
     pcs.add(pc)
@@ -309,7 +293,7 @@ async def run(push_url):
             await pc.close()
             pcs.discard(pc)
 
-    player = HumanPlayer(nerfreals[0])
+    player = HumanPlayer(nerfreals[sessionid])
     audio_sender = pc.addTrack(player.audio)
     video_sender = pc.addTrack(player.video)
 
@@ -466,95 +450,38 @@ if __name__ == '__main__':
         with open(opt.customvideo_config,'r') as file:
             opt.customopt = json.load(file)
 
-    if opt.model == 'ernerf':
-        from ernerf.nerf_triplane.provider import NeRFDataset_Test
-        from ernerf.nerf_triplane.utils import *
-        from ernerf.nerf_triplane.network import NeRFNetwork
-        from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
-        from nerfreal import NeRFReal
-        # assert test mode
-        opt.test = True
-        opt.test_train = False
-        #opt.train_camera =True
-        # explicit smoothing
-        opt.smooth_path = True
-        opt.smooth_lips = True
-
-        assert opt.pose != '', 'Must provide a pose source'
-
-        # if opt.O:
-        opt.fp16 = True
-        opt.cuda_ray = True
-        opt.exp_eye = True
-        opt.smooth_eye = True
-
-        if opt.torso_imgs=='': #no img,use model output
-            opt.torso = True
-
-        # assert opt.cuda_ray, "Only support CUDA ray mode."
-        opt.asr = True
-
-        if opt.patch_size > 1:
-            # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
-            assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
-        seed_everything(opt.seed)
-        print(opt)
-
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        model = NeRFNetwork(opt)
-
-        criterion = torch.nn.MSELoss(reduction='none')
-        metrics = [] # use no metric in GUI for faster initialization...
-        print(model)
-        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
-
-        test_loader = NeRFDataset_Test(opt, device=device).dataloader()
-        model.aud_features = test_loader._data.auds
-        model.eye_areas = test_loader._data.eye_area
-
-        print(f'[INFO] loading ASR model {opt.asr_model}...')
-        if 'hubert' in opt.asr_model:
-            audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
-            audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
-        else:
-            audio_processor = AutoProcessor.from_pretrained(opt.asr_model)
-            audio_model = AutoModelForCTC.from_pretrained(opt.asr_model).to(device)
-
+    if opt.model == 'ernerf':
+        from nerfreal import NeRFReal,load_model,load_avatar
+        model = load_model(opt)
+        avatar = load_avatar(opt)
+        # we still need test_loader to provide audio features for testing.
-        for k in range(opt.max_session):
-            opt.sessionid=k
-            nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
-            nerfreals.append(nerfreal)
+        # for k in range(opt.max_session):
+        #     opt.sessionid=k
+        #     nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
+        #     nerfreals.append(nerfreal)
     elif opt.model == 'musetalk':
-        from musereal import MuseReal
-        from musetalk.utils.utils import load_all_model
+        from musereal import MuseReal,load_model,load_avatar
         print(opt)
-        audio_processor,vae, unet, pe = load_all_model()
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        timesteps = torch.tensor([0], device=device)
-        pe = pe.half()
-        vae.vae = vae.vae.half()
-        #vae.vae.share_memory()
-        unet.model = unet.model.half()
-        #unet.model.share_memory()
-        for k in range(opt.max_session):
-            opt.sessionid=k
-            nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
-            nerfreals.append(nerfreal)
+        model = load_model()
+        avatar = load_avatar(opt.avatar_id)
+        # for k in range(opt.max_session):
+        #     opt.sessionid=k
+        #     nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
+        #     nerfreals.append(nerfreal)
     elif opt.model == 'wav2lip':
-        from lipreal import LipReal,load_model
+        from lipreal import LipReal,load_model,load_avatar
         print(opt)
         model = load_model("./models/wav2lip.pth")
-        for k in range(opt.max_session):
-            opt.sessionid=k
-            nerfreal = LipReal(opt,model)
-            nerfreals.append(nerfreal)
-
-    for _ in range(opt.max_session):
-        statreals.append(0)
+        avatar = load_avatar(opt.avatar_id)
+        # for k in range(opt.max_session):
+        #     opt.sessionid=k
+        #     nerfreal = LipReal(opt,model)
+        #     nerfreals.append(nerfreal)
 
     if opt.transport=='rtmp':
         thread_quit = Event()
+        nerfreals[0] = build_nerfreal(0)
         rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
         rendthrd.start()
@@ -594,7 +521,11 @@ if __name__ == '__main__':
         site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
         loop.run_until_complete(site.start())
         if opt.transport=='rtcpush':
-            loop.run_until_complete(run(opt.push_url))
+            for k in range(opt.max_session):
+                push_url = opt.push_url
+                if k!=0:
+                    push_url = opt.push_url+str(k)
+                loop.run_until_complete(run(push_url,k))
         loop.run_forever()
     #Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
     run_server(web.AppRunner(appasync))
diff --git a/baseasr.py b/baseasr.py
index 827bf90..d105cee 100644
--- a/baseasr.py
+++ b/baseasr.py
@@ -44,7 +44,7 @@ class BaseASR:
         #self.warm_up()
 
-    def pause_talk(self):
+    def flush_talk(self):
         self.queue.queue.clear()
 
     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
diff --git a/basereal.py b/basereal.py
index edc07c7..1a23657 100644
--- a/basereal.py
+++ b/basereal.py
@@ -109,9 +109,9 @@ class BaseReal:
 
         return stream
 
-    def pause_talk(self):
-        self.tts.pause_talk()
-        self.asr.pause_talk()
+    def flush_talk(self):
+        self.tts.flush_talk()
+        self.asr.flush_talk()
 
     def is_speaking(self)->bool:
         return self.speaking
diff --git a/lipreal.py b/lipreal.py
index 05beade..fea8700 100644
--- a/lipreal.py
+++ b/lipreal.py
@@ -67,6 +67,24 @@ def load_model(path):
     model = model.to(device)
 
     return model.eval()
 
+def load_avatar(avatar_id):
+    avatar_path = f"./data/avatars/{avatar_id}"
+    full_imgs_path = f"{avatar_path}/full_imgs"
+    face_imgs_path = f"{avatar_path}/face_imgs"
+    coords_path = f"{avatar_path}/coords.pkl"
+
+    with open(coords_path, 'rb') as f:
+        coord_list_cycle = pickle.load(f)
+    input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    frame_list_cycle = read_imgs(input_img_list)
+    #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
+    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    face_list_cycle = read_imgs(input_face_list)
+
+    return frame_list_cycle,face_list_cycle,coord_list_cycle
+
 def read_imgs(img_list):
     frames = []
     print('reading images...')
@@ -156,45 +174,31 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
 class LipReal(BaseReal):
     @torch.no_grad()
-    def __init__(self, opt, model):
+    def __init__(self, opt, model, avatar):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
         self.H = opt.H
         self.fps = opt.fps # 20 ms per frame
-
-        #### musetalk
-        self.avatar_id = opt.avatar_id
-        self.avatar_path = f"./data/avatars/{self.avatar_id}"
-        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
-        self.face_imgs_path = f"{self.avatar_path}/face_imgs"
-        self.coords_path = f"{self.avatar_path}/coords.pkl"
+
         self.batch_size = opt.batch_size
         self.idx = 0
         self.res_frame_queue = Queue(self.batch_size*2)  #mp.Queue
-        #self.__loadmodels()
-        self.__loadavatar()
+        #self.__loadavatar()
+        self.model = model
+        self.frame_list_cycle,self.face_list_cycle,self.coord_list_cycle = avatar
 
         self.asr = LipASR(opt,self)
         self.asr.warm_up()
         #self.__warm_up()
-        self.model = model
 
         self.render_event = mp.Event()
+
+    def __del__(self):
+        print(f'lipreal({self.sessionid}) delete')
 
-    def __loadavatar(self):
-        with open(self.coords_path, 'rb') as f:
-            self.coord_list_cycle = pickle.load(f)
-        input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
-        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-        self.frame_list_cycle = read_imgs(input_img_list)
-        #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
-        input_face_list = glob.glob(os.path.join(self.face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
-        input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-        self.face_list_cycle = read_imgs(input_face_list)
-
-
+
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
         while not quit_event.is_set():
diff --git a/musereal.py b/musereal.py
index 6891705..a05e5fc 100644
--- a/musereal.py
+++ b/musereal.py
@@ -46,6 +46,49 @@ from av import AudioFrame, VideoFrame
 from basereal import BaseReal
 from tqdm import tqdm
+
+def load_model():
+    # load model weights
+    audio_processor,vae, unet, pe = load_all_model()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    timesteps = torch.tensor([0], device=device)
+    pe = pe.half()
+    vae.vae = vae.vae.half()
+    #vae.vae.share_memory()
+    unet.model = unet.model.half()
+    #unet.model.share_memory()
+    return vae, unet, pe, timesteps, audio_processor
+
+def load_avatar(avatar_id):
+    #self.video_path = '' #video_path
+    #self.bbox_shift = opt.bbox_shift
+    avatar_path = f"./data/avatars/{avatar_id}"
+    full_imgs_path = f"{avatar_path}/full_imgs"
+    coords_path = f"{avatar_path}/coords.pkl"
+    latents_out_path= f"{avatar_path}/latents.pt"
+    video_out_path = f"{avatar_path}/vid_output/"
+    mask_out_path =f"{avatar_path}/mask"
+    mask_coords_path =f"{avatar_path}/mask_coords.pkl"
+    avatar_info_path = f"{avatar_path}/avator_info.json"
+    # self.avatar_info = {
+    #     "avatar_id":self.avatar_id,
+    #     "video_path":self.video_path,
+    #     "bbox_shift":self.bbox_shift
+    # }
+
+    input_latent_list_cycle = torch.load(latents_out_path,weights_only=True)
+    with open(coords_path, 'rb') as f:
+        coord_list_cycle = pickle.load(f)
+    input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    frame_list_cycle = read_imgs(input_img_list)
+    with open(mask_coords_path, 'rb') as f:
+        mask_coords_list_cycle = pickle.load(f)
+    input_mask_list = glob.glob(os.path.join(mask_out_path, '*.[jpJP][pnPN]*[gG]'))
+    input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    mask_list_cycle = read_imgs(input_mask_list)
+    return frame_list_cycle,mask_list_cycle,coord_list_cycle,mask_coords_list_cycle,input_latent_list_cycle
+
 def read_imgs(img_list):
     frames = []
     print('reading images...')
@@ -145,7 +188,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
 class MuseReal(BaseReal):
     @torch.no_grad()
-    def __init__(self, opt, audio_processor:Audio2Feature,vae, unet, pe,timesteps):
+    def __init__(self, opt, model, avatar):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -153,63 +196,22 @@ class MuseReal(BaseReal):
         self.fps = opt.fps # 20 ms per frame
-
-        #### musetalk
-        self.avatar_id = opt.avatar_id
-        self.video_path = '' #video_path
-        self.bbox_shift = opt.bbox_shift
-        self.avatar_path = f"./data/avatars/{self.avatar_id}"
-        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
-        self.coords_path = f"{self.avatar_path}/coords.pkl"
-        self.latents_out_path= f"{self.avatar_path}/latents.pt"
-        self.video_out_path = f"{self.avatar_path}/vid_output/"
-        self.mask_out_path =f"{self.avatar_path}/mask"
-        self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
-        self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
-        self.avatar_info = {
-            "avatar_id":self.avatar_id,
-            "video_path":self.video_path,
-            "bbox_shift":self.bbox_shift
-        }
         self.batch_size = opt.batch_size
         self.idx = 0
         self.res_frame_queue = mp.Queue(self.batch_size*2)
-        #self.__loadmodels()
-        self.audio_processor= audio_processor
-        self.__loadavatar()
+
+        self.vae, self.unet, self.pe, self.timesteps, self.audio_processor = model
+        self.frame_list_cycle,self.mask_list_cycle,self.coord_list_cycle,self.mask_coords_list_cycle, self.input_latent_list_cycle = avatar
+        #self.__loadavatar()
 
         self.asr = MuseASR(opt,self,self.audio_processor)
         self.asr.warm_up()
         #self.__warm_up()
 
         self.render_event = mp.Event()
-        self.vae = vae
-        self.unet = unet
-        self.pe = pe
-        self.timesteps = timesteps
-
-
-    # def __loadmodels(self):
-    #     # load model weights
-    #     self.audio_processor= load_audio_model()
-    #     self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
-    #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    #     self.timesteps = torch.tensor([0], device=device)
-    #     self.pe = self.pe.half()
-    #     self.vae.vae = self.vae.vae.half()
-    #     self.unet.model = self.unet.model.half()
 
-    def __loadavatar(self):
-        self.input_latent_list_cycle = torch.load(self.latents_out_path,weights_only=True)
-        with open(self.coords_path, 'rb') as f:
-            self.coord_list_cycle = pickle.load(f)
-        input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
-        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-        self.frame_list_cycle = read_imgs(input_img_list)
-        with open(self.mask_coords_path, 'rb') as f:
-            self.mask_coords_list_cycle = pickle.load(f)
-        input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
-        input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-        self.mask_list_cycle = read_imgs(input_mask_list)
+    def __del__(self):
+        print(f'musereal({self.sessionid}) delete')
 
     def __mirror_index(self, index):
diff --git a/nerfreal.py b/nerfreal.py
index 8c34851..f28d759 100644
--- a/nerfreal.py
+++ b/nerfreal.py
@@ -33,6 +33,10 @@ from av import AudioFrame, VideoFrame
 from basereal import BaseReal
 #from imgcache import ImgCache
+from ernerf.nerf_triplane.provider import NeRFDataset_Test
+from ernerf.nerf_triplane.utils import *
+from ernerf.nerf_triplane.network import NeRFNetwork
+from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
 from tqdm import tqdm
 
 def read_imgs(img_list):
@@ -43,15 +47,76 @@ def read_imgs(img_list):
         frames.append(frame)
     return frames
 
+def load_model(opt):
+    # assert test mode
+    opt.test = True
+    opt.test_train = False
+    #opt.train_camera =True
+    # explicit smoothing
+    opt.smooth_path = True
+    opt.smooth_lips = True
+
+    assert opt.pose != '', 'Must provide a pose source'
+
+    # if opt.O:
+    opt.fp16 = True
+    opt.cuda_ray = True
+    opt.exp_eye = True
+    opt.smooth_eye = True
+
+    if opt.torso_imgs=='': #no img,use model output
+        opt.torso = True
+
+    # assert opt.cuda_ray, "Only support CUDA ray mode."
+    opt.asr = True
+
+    if opt.patch_size > 1:
+        # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
+        assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
+    seed_everything(opt.seed)
+    print(opt)
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = NeRFNetwork(opt)
+
+    criterion = torch.nn.MSELoss(reduction='none')
+    metrics = [] # use no metric in GUI for faster initialization...
+    print(model)
+    trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
+
+    test_loader = NeRFDataset_Test(opt, device=device).dataloader()
+    model.aud_features = test_loader._data.auds
+    model.eye_areas = test_loader._data.eye_area
+
+    print(f'[INFO] loading ASR model {opt.asr_model}...')
+    if 'hubert' in opt.asr_model:
+        audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
+        audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
+    else:
+        audio_processor = AutoProcessor.from_pretrained(opt.asr_model)
+        audio_model = AutoModelForCTC.from_pretrained(opt.asr_model).to(device)
+    return trainer,test_loader,audio_processor,audio_model
+
+def load_avatar(opt):
+    fullbody_list_cycle = None
+    if opt.fullbody:
+        input_img_list = glob.glob(os.path.join(opt.fullbody_img, '*.[jpJP][pnPN]*[gG]'))
+        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+        #print('input_img_list:',input_img_list)
+        fullbody_list_cycle = read_imgs(input_img_list) #[:frame_total_num]
+        #self.imagecache = ImgCache(frame_total_num,self.opt.fullbody_img,1000)
+    return fullbody_list_cycle
+
 class NeRFReal(BaseReal):
-    def __init__(self, opt, trainer, data_loader, audio_processor,audio_model, debug=True):
+    def __init__(self, opt, model,avatar, debug=True):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
         self.H = opt.H
-        self.trainer = trainer
-        self.data_loader = data_loader
+        #self.trainer = trainer
+        #self.data_loader = data_loader
+        self.trainer, self.data_loader, audio_processor,audio_model = model
 
         # use dataloader's bg
         #bg_img = data_loader._data.bg_img #.view(1, -1, 3)
@@ -70,14 +135,10 @@ class NeRFReal(BaseReal):
         #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
 
         # playing seq from dataloader, or pause.
-        self.loader = iter(data_loader)
-        frame_total_num = data_loader._data.end_index
-        if opt.fullbody:
-            input_img_list = glob.glob(os.path.join(self.opt.fullbody_img, '*.[jpJP][pnPN]*[gG]'))
-            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-            #print('input_img_list:',input_img_list)
-            self.fullbody_list_cycle = read_imgs(input_img_list[:frame_total_num])
-            #self.imagecache = ImgCache(frame_total_num,self.opt.fullbody_img,1000)
+        self.loader = iter(self.data_loader)
+        frame_total_num = self.data_loader._data.end_index
+        self.fullbody_list_cycle = avatar
+
         #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
         #self.need_update = True # camera moved, should reset accumulation
@@ -134,7 +195,9 @@ class NeRFReal(BaseReal):
             self.fifo_audio = open(audio_path, 'wb')
         #self.test_step()
         '''
-
+
+    def __del__(self):
+        print(f'nerfreal({self.sessionid}) delete')
 
     def __enter__(self):
         return self
diff --git a/ttsreal.py b/ttsreal.py
index 636b306..a885394 100644
--- a/ttsreal.py
+++ b/ttsreal.py
@@ -49,7 +49,7 @@ class BaseTTS:
         self.msgqueue = Queue()
         self.state = State.RUNNING
 
-    def pause_talk(self):
+    def flush_talk(self):
         self.msgqueue.queue.clear()
         self.state = State.PAUSE
diff --git a/web/webrtcchat.html b/web/webrtcchat.html
index d389828..16d9bc8 100644
--- a/web/webrtcchat.html
+++ b/web/webrtcchat.html
@@ -30,6 +30,10 @@
+
+
+
+

         input text
 
@@ -75,11 +79,13 @@
             e.preventDefault();
             var message = $('#message').val();
             console.log('Sending: ' + message);
+            console.log('sessionid: ',document.getElementById('sessionid').value);
             fetch('/human', {
                 body: JSON.stringify({
                     text: message,
                     type: 'chat',
                     interrupt: true,
+                    sessionid:parseInt(document.getElementById('sessionid').value),
                 }),
                 headers: {
                     'Content-Type': 'application/json'
@@ -89,6 +95,90 @@
             //ws.send(message);
             $('#message').val('');
         });
+
+        $('#btn_start_record').click(function() {
+            // start recording
+            console.log('Starting recording...');
+            fetch('/record', {
+                body: JSON.stringify({
+                    type: 'start_record',
+                }),
+                headers: {
+                    'Content-Type': 'application/json'
+                },
+                method: 'POST'
+            }).then(function(response) {
+                if (response.ok) {
+                    console.log('Recording started.');
+                    $('#btn_start_record').prop('disabled', true);
+                    $('#btn_stop_record').prop('disabled', false);
+                    // $('#btn_download').prop('disabled', true);
+                } else {
+                    console.error('Failed to start recording.');
+                }
+            }).catch(function(error) {
+                console.error('Error:', error);
+            });
+        });
+
+        $('#btn_stop_record').click(function() {
+            // stop recording
+            console.log('Stopping recording...');
+            fetch('/record', {
+                body: JSON.stringify({
+                    type: 'end_record',
+                }),
+                headers: {
+                    'Content-Type': 'application/json'
+                },
+                method: 'POST'
+            }).then(function(response) {
+                if (response.ok) {
+                    console.log('Recording stopped.');
+                    $('#btn_start_record').prop('disabled', false);
+                    $('#btn_stop_record').prop('disabled', true);
+                    // $('#btn_download').prop('disabled', false);
+                } else {
+                    console.error('Failed to stop recording.');
+                }
+            }).catch(function(error) {
+                console.error('Error:', error);
+            });
+        });
+
+        // $('#btn_download').click(function() {
+        //     // download the video file
+        //     console.log('Downloading video...');
+        //     fetch('/record_lasted.mp4', {
+        //         method: 'GET'
+        //     }).then(function(response) {
+        //         if (response.ok) {
+        //             return response.blob();
+        //         } else {
+        //             throw new Error('Failed to download the video.');
+        //         }
+        //     }).then(function(blob) {
+        //         // create a Blob object
+        //         const url = window.URL.createObjectURL(blob);
+        //         // create a hidden download link
+        //         const a = document.createElement('a');
+        //         a.style.display = 'none';
+        //         a.href = url;
+        //         a.download = 'record_lasted.mp4';
+        //         document.body.appendChild(a);
+        //         // trigger the download
+        //         a.click();
+        //         // clean up
+        //         window.URL.revokeObjectURL(url);
+        //         document.body.removeChild(a);
+        //         console.log('Video downloaded successfully.');
+        //     }).catch(function(error) {
+        //         console.error('Error:', error);
+        //     });
+        // });
+        });
+
+