diff --git a/app.py b/app.py
index 2adac66..1d44282 100644
--- a/app.py
+++ b/app.py
@@ -19,12 +19,10 @@
 from flask import Flask, render_template,send_from_directory,request, jsonify
 from flask_sockets import Sockets
 import base64
-import time
 import json
 #import gevent
 #from gevent import pywsgi
 #from geventwebsocket.handler import WebSocketHandler
-import os
 import re
 import numpy as np
 from threading import Thread,Event
@@ -37,86 +35,36 @@
 import aiohttp_cors

 from aiortc import RTCPeerConnection, RTCSessionDescription
 from aiortc.rtcrtpsender import RTCRtpSender
 from webrtc import HumanPlayer
+from basereal import BaseReal
+from llm import llm_response

 import argparse
 import random
-
 import shutil
 import asyncio
 import torch
+from typing import Dict
+from logger import logger

 app = Flask(__name__)
 #sockets = Sockets(app)
-nerfreals = {}
+nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
 opt = None
 model = None
 avatar = None
-
-
-# def llm_response(message):
-#     from llm.LLM import LLM
-#     # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
-#     # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
-#     llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
-#     response = llm.chat(message)
-#     print(response)
-#     return response
-
-def llm_response(message,nerfreal):
-    start = time.perf_counter()
-    from openai import OpenAI
-    client = OpenAI(
-        # 如果您没有配置环境变量,请在此处用您的API Key进行替换
-        api_key=os.getenv("DASHSCOPE_API_KEY"),
-        # 填写DashScope SDK的base_url
-        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
-    )
-    end = time.perf_counter()
-    print(f"llm Time init: {end-start}s")
-    completion = client.chat.completions.create(
-        model="qwen-plus",
-        messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
-                  {'role': 'user', 'content': message}],
-        stream=True,
-        # 通过以下设置,在流式输出的最后一行展示token使用信息
-        stream_options={"include_usage": True}
-    )
-    result=""
-    first = True
-    for chunk in completion:
-        if len(chunk.choices)>0:
-            #print(chunk.choices[0].delta.content)
-            if first:
-                end = time.perf_counter()
-                print(f"llm Time to first chunk: {end-start}s")
-                first = False
-            msg = chunk.choices[0].delta.content
-            lastpos=0
-            #msglist = re.split('[,.!;:,。!?]',msg)
-            for i, char in enumerate(msg):
-                if char in ",.!;:,。!?:;" :
-                    result = result+msg[lastpos:i+1]
-                    lastpos = i+1
-                    if len(result)>10:
-                        print(result)
-                        nerfreal.put_msg_txt(result)
-                        result=""
-            result = result+msg[lastpos:]
-    end = time.perf_counter()
-    print(f"llm Time to last chunk: {end-start}s")
-    nerfreal.put_msg_txt(result)
+
 #####webrtc###############################
 pcs = set()

-def randN(N):
+def randN(N)->int:
     '''生成长度为 N的随机数 '''
     min = pow(10, N - 1)
     max = pow(10, N)
     return random.randint(min, max - 1)

-def build_nerfreal(sessionid):
+def build_nerfreal(sessionid:int)->BaseReal:
     opt.sessionid=sessionid
     if opt.model == 'wav2lip':
         from lipreal import LipReal
@@ -138,10 +86,10 @@ async def offer(request):
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

     if len(nerfreals) >= opt.max_session:
-        print('reach max session')
+        logger.info('reach max session')
         return -1
     sessionid = randN(6) #len(nerfreals)
-    print('sessionid=',sessionid)
+    logger.info('sessionid=%d',sessionid)
     nerfreals[sessionid] = None
     nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
     nerfreals[sessionid] = nerfreal
@@ -151,7 +99,7 @@ async def offer(request):

     @pc.on("connectionstatechange")
     async def on_connectionstatechange():
-        print("Connection state is %s" % pc.connectionState)
+        logger.info("Connection state is %s" % pc.connectionState)
         if pc.connectionState == "failed":
             await pc.close()
             pcs.discard(pc)
@@ -280,7 +228,7 @@ async def post(url,data):
             async with session.post(url,data=data) as response:
                 return await response.text()
     except aiohttp.ClientError as e:
-        print(f'Error: {e}')
+        logger.error(f'Error: {e}')

 async def run(push_url,sessionid):
     nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
@@ -291,7 +239,7 @@ async def run(push_url,sessionid):

     @pc.on("connectionstatechange")
     async def on_connectionstatechange():
-        print("Connection state is %s" % pc.connectionState)
+        logger.info("Connection state is %s" % pc.connectionState)
         if pc.connectionState == "failed":
             await pc.close()
             pcs.discard(pc)
@@ -465,7 +413,7 @@ if __name__ == '__main__':
         #     nerfreals.append(nerfreal)
     elif opt.model == 'musetalk':
         from musereal import MuseReal,load_model,load_avatar,warm_up
-        print(opt)
+        logger.info(opt)
         model = load_model()
         avatar = load_avatar(opt.avatar_id)
         warm_up(opt.batch_size,model)
@@ -475,7 +423,7 @@ if __name__ == '__main__':
         #     nerfreals.append(nerfreal)
     elif opt.model == 'wav2lip':
         from lipreal import LipReal,load_model,load_avatar,warm_up
-        print(opt)
+        logger.info(opt)
         model = load_model("./models/wav2lip.pth")
         avatar = load_avatar(opt.avatar_id)
         warm_up(opt.batch_size,model,256)
@@ -485,7 +433,7 @@ if __name__ == '__main__':
         #     nerfreals.append(nerfreal)
     elif opt.model == 'ultralight':
         from lightreal import LightReal,load_model,load_avatar,warm_up
-        print(opt)
+        logger.info(opt)
         model = load_model(opt)
         avatar = load_avatar(opt.avatar_id)
         warm_up(opt.batch_size,avatar,160)
@@ -524,7 +472,7 @@ if __name__ == '__main__':
         pagename='echoapi.html'
     elif opt.transport=='rtcpush':
         pagename='rtcpushapi.html'
-    print('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
+    logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
     def run_server(runner):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
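With llm_response extracted to llm.py, app.py only needs to hand the session's BaseReal to it from a worker thread so the streaming OpenAI call never blocks the aiohttp event loop. A minimal sketch of that wiring, assuming a /human-style chat handler; the handler name and request fields are illustrative and not part of this patch:

# Illustrative only: a chat handler in app.py's namespace (nerfreals and llm_response as above).
from aiohttp import web
import asyncio

async def human(request):
    params = await request.json()
    sessionid = params.get('sessionid', 0)   # assumed request field
    nerfreal = nerfreals[sessionid]          # Dict[int, BaseReal]
    # run the blocking, streaming LLM call in the default thread pool
    await asyncio.get_event_loop().run_in_executor(
        None, llm_response, params['text'], nerfreal)
    return web.json_response({"code": 0})
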
diff --git a/baseasr.py b/baseasr.py
index 8b353fd..9b11d42 100644
--- a/baseasr.py
+++ b/baseasr.py
@@ -22,9 +22,11 @@
 import queue
 from queue import Queue
 import torch.multiprocessing as mp

+from basereal import BaseReal
+
 class BaseASR:
-    def __init__(self, opt, parent=None):
+    def __init__(self, opt, parent:BaseReal|None = None):
         self.opt = opt
         self.parent = parent

diff --git a/basereal.py b/basereal.py
index 939217b..0f66ad6 100644
--- a/basereal.py
+++ b/basereal.py
@@ -36,11 +36,12 @@
 import av
 from fractions import Fraction

 from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS,FishTTS
+from logger import logger
 from tqdm import tqdm
 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -98,15 +99,15 @@ class BaseReal:
     def __create_bytes_stream(self,byte_stream):
         #byte_stream=BytesIO(buffer)
         stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
-        print(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
+        logger.info(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
         stream = stream.astype(np.float32)

         if stream.ndim > 1:
-            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
+            logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
             stream = stream[:, 0]

         if sample_rate != self.sample_rate and stream.shape[0]>0:
-            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
+            logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
             stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)

         return stream
@@ -120,7 +121,7 @@ class BaseReal:

     def __loadcustom(self):
         for item in self.opt.customopt:
-            print(item)
+            logger.info(item)
             input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]'))
             input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
             self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list)
@@ -137,7 +138,7 @@ class BaseReal:
             self.custom_index[key]=0

     def notify(self,eventpoint):
-        print("notify:",eventpoint)
+        logger.info("notify:%s",eventpoint)

     def start_recording(self):
         """开始录制视频"""
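A compatibility note on the BaseASR hint above: the BaseReal | None union syntax in an annotation that is evaluated at runtime requires Python 3.10 or newer. On older interpreters the same hint can be written with typing.Optional; a sketch of the equivalent spelling, not a change made by this patch:

from typing import Optional
from basereal import BaseReal

class BaseASR:
    def __init__(self, opt, parent: Optional[BaseReal] = None):
        self.opt = opt
        self.parent = parent   # the owning BaseReal session, if any
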
diff --git a/lightreal.py b/lightreal.py
index b79d5a3..51f544e 100644
--- a/lightreal.py
+++ b/lightreal.py
@@ -54,11 +54,11 @@
 from transformers import Wav2Vec2Processor, HubertModel
 from torch.utils.data import DataLoader
 from ultralight.unet import Model
 from ultralight.audio2feature import Audio2Feature
-
+from logger import logger

 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print('Using {} for inference.'.format(device))
+logger.info('Using {} for inference.'.format(device))

 def load_model(opt):
@@ -89,7 +89,7 @@ def load_avatar(avatar_id):

 @torch.no_grad()
 def warm_up(batch_size,avatar,modelres):
-    print('warmup model...')
+    logger.info('warmup model...')
     model,_,_,_ = avatar
     img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
     mel_batch = torch.ones(batch_size, 32, 32, 32).to(device)
@@ -97,7 +97,7 @@ def warm_up(batch_size,avatar,modelres):

 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -124,7 +124,7 @@ def get_audio_features(features, index):

 def read_lms(lms_list):
     land_marks = []
-    print('reading lms...')
+    logger.info('reading lms...')
     for lms_path in tqdm(lms_list):
         file_landmarks = []  # Store landmarks for this file
         with open(lms_path, "r") as f:
@@ -152,7 +152,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
     index = 0
     count = 0
     counttime = 0
-    print('start inference')
+    logger.info('start inference')

     while not quit_event.is_set():
         starttime=time.perf_counter()
@@ -206,7 +206,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
             counttime += (time.perf_counter() - t)
             count += batch_size
             if count >= 100:
-                print(f"------actual avg infer fps:{count / counttime:.4f}")
+                logger.info(f"------actual avg infer fps:{count / counttime:.4f}")
                 count = 0
                 counttime = 0
             for i,res_frame in enumerate(pred):
@@ -221,7 +221,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o

         #print('total batch time:', time.perf_counter() - starttime)

-    print('lightreal inference processor stop')
+    logger.info('lightreal inference processor stop')

 class LightReal(BaseReal):

@@ -248,7 +248,7 @@ class LightReal(BaseReal):
         self.render_event = mp.Event()

     def __del__(self):
-        print(f'lightreal({self.sessionid}) delete')
+        logger.info(f'lightreal({self.sessionid}) delete')

     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):

@@ -302,7 +302,7 @@ class LightReal(BaseReal):
                     asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
                     self.record_audio_data(frame)
             #self.notify(eventpoint)
-        print('lightreal process_frames thread stop')
+        logger.info('lightreal process_frames thread stop')

     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -331,13 +331,13 @@ class LightReal(BaseReal):
             #     print('sleep qsize=',video_track._queue.qsize())
             #     time.sleep(0.04*video_track._queue.qsize()*0.8)
             if video_track._queue.qsize()>=5:
-                print('sleep qsize=',video_track._queue.qsize())
+                logger.debug('sleep qsize=%d',video_track._queue.qsize())
                 time.sleep(0.04*video_track._queue.qsize()*0.8)

             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
             #     time.sleep(delay)

         #self.render_event.clear() #end infer process render
-        print('lightreal thread stop')
+        logger.info('lightreal thread stop')
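The noisiest message in the render loop above moves from print to logger.debug with %-style arguments: the string is only built when a handler actually accepts DEBUG records, and with the handlers configured in logger.py (added later in this patch: console at DEBUG, file at INFO) the qsize chatter stays out of livetalking.log. The routing in miniature, with illustrative names:

import logging

demo = logging.getLogger("qsize-demo")        # stand-in for the shared logger
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
logfile = logging.FileHandler("demo.log")
logfile.setLevel(logging.INFO)
demo.setLevel(logging.DEBUG)
demo.addHandler(console)
demo.addHandler(logfile)

qsize = 7
demo.debug('sleep qsize=%d', qsize)    # console only, formatted lazily
demo.info('lightreal thread stop')     # console and demo.log
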
diff --git a/lipreal.py b/lipreal.py
index b569182..97a5626 100644
--- a/lipreal.py
+++ b/lipreal.py
@@ -42,9 +42,10 @@
 from basereal import BaseReal
 #from imgcache import ImgCache
 from tqdm import tqdm
+from logger import logger

 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print('Using {} for inference.'.format(device))
+logger.info('Using {} for inference.'.format(device))

 def _load(checkpoint_path):
     if device == 'cuda':
@@ -56,7 +57,7 @@ def _load(checkpoint_path):

 def load_model(path):
     model = Wav2Lip()
-    print("Load checkpoint from: {}".format(path))
+    logger.info("Load checkpoint from: {}".format(path))
     checkpoint = _load(path)
     s = checkpoint["state_dict"]
     new_s = {}
@@ -88,14 +89,14 @@ def load_avatar(avatar_id):

 @torch.no_grad()
 def warm_up(batch_size,model,modelres): # 预热函数
-    print('warmup model...')
+    logger.info('warmup model...')
     img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
     mel_batch = torch.ones(batch_size, 1, 80, 16).to(device)
     model(mel_batch, img_batch)

 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -122,7 +123,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
     index = 0
     count=0
     counttime=0
-    print('start inference')
+    logger.info('start inference')
     while not quit_event.is_set():
         starttime=time.perf_counter()
         mel_batch = []
@@ -170,7 +171,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
             count += batch_size
             #_totalframe += 1
             if count>=100:
-                print(f"------actual avg infer fps:{count/counttime:.4f}")
+                logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
                 count=0
                 counttime=0
             for i,res_frame in enumerate(pred):
@@ -178,7 +179,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
                 res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                 index = index + 1
             #print('total batch time:',time.perf_counter()-starttime)
-    print('lipreal inference processor stop')
+    logger.info('lipreal inference processor stop')

 class LipReal(BaseReal):
     @torch.no_grad()
@@ -203,7 +204,7 @@ class LipReal(BaseReal):
         self.render_event = mp.Event()

     def __del__(self):
-        print(f'lipreal({self.sessionid}) delete')
+        logger.info(f'lipreal({self.sessionid}) delete')

     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):

@@ -256,7 +257,7 @@ class LipReal(BaseReal):
                 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
                 self.record_audio_data(frame)
             #self.notify(eventpoint)
-        print('lipreal process_frames thread stop')
+        logger.info('lipreal process_frames thread stop')

     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -286,12 +287,12 @@ class LipReal(BaseReal):
             #     print('sleep qsize=',video_track._queue.qsize())
             #     time.sleep(0.04*video_track._queue.qsize()*0.8)
             if video_track._queue.qsize()>=5:
-                print('sleep qsize=',video_track._queue.qsize())
+                logger.debug('sleep qsize=%d',video_track._queue.qsize())
                 time.sleep(0.04*video_track._queue.qsize()*0.8)

             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
             #     time.sleep(delay)

         #self.render_event.clear() #end infer process render
-        print('lipreal thread stop')
+        logger.info('lipreal thread stop')
\ No newline at end of file
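The inference loops in lightreal.py and lipreal.py above report throughput the same way: accumulate per-batch wall time, log an average once 100 frames have been produced, then reset the counters. The same arithmetic factored into a small stand-alone helper, purely for illustration (the patch itself keeps the inline counters):

import time

class FpsMeter:
    """Report an average FPS once per window of frames, then reset."""
    def __init__(self, window: int = 100):
        self.window, self.count, self.elapsed = window, 0, 0.0

    def add_batch(self, frames: int, seconds: float):
        self.count += frames
        self.elapsed += seconds
        if self.count >= self.window:
            fps = self.count / self.elapsed   # same count/counttime ratio as inference()
            self.count, self.elapsed = 0, 0.0
            return fps
        return None

# usage: fps = meter.add_batch(batch_size, time.perf_counter() - t); log it when not None
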
diff --git a/llm.py b/llm.py
new file mode 100644
index 0000000..857084d
--- /dev/null
+++ b/llm.py
@@ -0,0 +1,48 @@
+import time
+import os
+from basereal import BaseReal
+from logger import logger
+
+def llm_response(message,nerfreal:BaseReal):
+    start = time.perf_counter()
+    from openai import OpenAI
+    client = OpenAI(
+        # If the API key is not set as an environment variable, replace it here with your key
+        api_key=os.getenv("DASHSCOPE_API_KEY"),
+        # base_url of the DashScope SDK
+        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+    )
+    end = time.perf_counter()
+    logger.info(f"llm Time init: {end-start}s")
+    completion = client.chat.completions.create(
+        model="qwen-plus",
+        messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
+                  {'role': 'user', 'content': message}],
+        stream=True,
+        # with this option the last line of the stream also reports token usage
+        stream_options={"include_usage": True}
+    )
+    result=""
+    first = True
+    for chunk in completion:
+        if len(chunk.choices)>0:
+            #print(chunk.choices[0].delta.content)
+            if first:
+                end = time.perf_counter()
+                logger.info(f"llm Time to first chunk: {end-start}s")
+                first = False
+            msg = chunk.choices[0].delta.content
+            lastpos=0
+            #msglist = re.split('[,.!;:,。!?]',msg)
+            for i, char in enumerate(msg):
+                if char in ",.!;:,。!?:;" :
+                    result = result+msg[lastpos:i+1]
+                    lastpos = i+1
+                    if len(result)>10:
+                        logger.info(result)
+                        nerfreal.put_msg_txt(result)
+                        result=""
+            result = result+msg[lastpos:]
+    end = time.perf_counter()
+    logger.info(f"llm Time to last chunk: {end-start}s")
+    nerfreal.put_msg_txt(result)
\ No newline at end of file
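The new llm.py depends on nothing of BaseReal beyond put_msg_txt, so it can be exercised with any object exposing that method; the stub below is illustrative and requires DASHSCOPE_API_KEY to be set. One hardening worth considering (not applied in this patch): some streaming responses deliver chunks whose delta.content is None, which the character loop would trip over, so a small guard before iterating msg may be useful.

# Illustrative harness; FakeSession is not part of the codebase.
from llm import llm_response

class FakeSession:
    def put_msg_txt(self, text, eventpoint=None):
        # llm_response flushes roughly sentence-sized pieces here as they arrive
        print('TTS <<', text)

if __name__ == '__main__':
    llm_response('Introduce yourself in one sentence.', FakeSession())
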
diff --git a/logger.py b/logger.py
new file mode 100644
index 0000000..9887daa
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,16 @@
+import logging
+
+# configure the shared logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+fhandler = logging.FileHandler('livetalking.log')  # could instead be a StreamHandler for console output, or several handlers combined
+fhandler.setFormatter(formatter)
+fhandler.setLevel(logging.INFO)
+logger.addHandler(fhandler)
+
+handler = logging.StreamHandler()
+handler.setLevel(logging.DEBUG)
+sformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+handler.setFormatter(sformatter)
+logger.addHandler(handler)
\ No newline at end of file
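Every module touched by this patch picks the instance up with "from logger import logger", so one process-wide configuration applies everywhere: records at DEBUG and above reach the console handler, while only INFO and above are written to livetalking.log. A quick check of that routing (messages are just for illustration):

from logger import logger

logger.debug("console only: the file handler is filtered at INFO")
logger.info("goes to the console and to livetalking.log")
logger.error("also both, at ERROR severity")
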
diff --git a/musereal.py b/musereal.py
index ed88ec5..bfbb155 100644
--- a/musereal.py
+++ b/musereal.py
@@ -46,6 +46,7 @@
 from av import AudioFrame, VideoFrame
 from basereal import BaseReal
 from tqdm import tqdm
+from logger import logger

 def load_model():
     # load model weights
@@ -92,7 +93,7 @@ def load_avatar(avatar_id):

 @torch.no_grad()
 def warm_up(batch_size,model): # 预热函数
-    print('warmup model...')
+    logger.info('warmup model...')
     vae, unet, pe, timesteps, audio_processor = model
     #batch_size = 16
     #timesteps = torch.tensor([0], device=unet.device)
@@ -110,7 +111,7 @@ def warm_up(batch_size,model):

 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -140,7 +141,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
     index = 0
     count=0
     counttime=0
-    print('start inference')
+    logger.info('start inference')
     while render_event.is_set():
         starttime=time.perf_counter()
         try:
@@ -195,7 +196,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
             count += batch_size
             #_totalframe += 1
             if count>=100:
-                print(f"------actual avg infer fps:{count/counttime:.4f}")
+                logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
                 count=0
                 counttime=0
             for i,res_frame in enumerate(recon):
@@ -203,7 +204,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
                 res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                 index = index + 1
             #print('total batch time:',time.perf_counter()-starttime)
-    print('musereal inference processor stop')
+    logger.info('musereal inference processor stop')

 class MuseReal(BaseReal):
     @torch.no_grad()
@@ -229,7 +230,7 @@ class MuseReal(BaseReal):
         self.render_event = mp.Event()

     def __del__(self):
-        print(f'musereal({self.sessionid}) delete')
+        logger.info(f'musereal({self.sessionid}) delete')

     def __mirror_index(self, index):

@@ -251,7 +252,7 @@ class MuseReal(BaseReal):
                 latent = self.input_latent_list_cycle[idx]
                 latent_batch.append(latent)
             latent_batch = torch.cat(latent_batch, dim=0)
-            print('infer=======')
+            logger.info('infer=======')
             # for i, (whisper_batch,latent_batch) in enumerate(gen):
             audio_feature_batch = torch.from_numpy(whisper_batch)
             audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
@@ -317,7 +318,7 @@ class MuseReal(BaseReal):
                     self.record_audio_data(frame)
             #self.notify(eventpoint)
             #self.recordq_audio.put(new_frame)
-        print('musereal process_frames thread stop')
+        logger.info('musereal process_frames thread stop')

     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -349,7 +350,7 @@ class MuseReal(BaseReal):
             #     count=0
             #     totaltime=0
             if video_track._queue.qsize()>=1.5*self.opt.batch_size:
-                print('sleep qsize=',video_track._queue.qsize())
+                logger.debug('sleep qsize=%d',video_track._queue.qsize())
                 time.sleep(0.04*video_track._queue.qsize()*0.8)
             # if video_track._queue.qsize()>=5:
             #     print('sleep qsize=',video_track._queue.qsize())
@@ -359,5 +360,5 @@ class MuseReal(BaseReal):
             # if delay > 0:
             #     time.sleep(delay)
         self.render_event.clear() #end infer process render
-        print('musereal thread stop')
+        logger.info('musereal thread stop')
diff --git a/nerfreal.py b/nerfreal.py
index f75da52..6ae0c63 100644
--- a/nerfreal.py
+++ b/nerfreal.py
@@ -38,10 +38,11 @@
 from ernerf.nerf_triplane.utils import *
 from ernerf.nerf_triplane.network import NeRFNetwork
 from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
+from logger import logger
 from tqdm import tqdm
 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -74,21 +75,21 @@ def load_model(opt):
     # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
     assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
     seed_everything(opt.seed)
-    print(opt)
+    logger.info(opt)

     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model = NeRFNetwork(opt)

     criterion = torch.nn.MSELoss(reduction='none')
     metrics = [] # use no metric in GUI for faster initialization...
-    print(model)
+    logger.info(model)
     trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

     test_loader = NeRFDataset_Test(opt, device=device).dataloader()
     model.aud_features = test_loader._data.auds
     model.eye_areas = test_loader._data.eye_area

-    print(f'[INFO] loading ASR model {opt.asr_model}...')
+    logger.info(f'[INFO] loading ASR model {opt.asr_model}...')
     if 'hubert' in opt.asr_model:
         audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
         audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
@@ -197,7 +198,7 @@ class NeRFReal(BaseReal):
     '''

     def __del__(self):
-        print(f'nerfreal({self.sessionid}) delete')
+        logger.info(f'nerfreal({self.sessionid}) delete')

     def __enter__(self):
         return self
@@ -365,7 +366,7 @@ class NeRFReal(BaseReal):
                 count += 1
                 _totalframe += 1
                 if count==100:
-                    print(f"------actual avg infer fps:{count/totaltime:.4f}")
+                    logger.info(f"------actual avg infer fps:{count/totaltime:.4f}")
                     count=0
                     totaltime=0
                 if self.opt.transport=='rtmp':
@@ -376,6 +377,6 @@ class NeRFReal(BaseReal):
                 if video_track._queue.qsize()>=5:
                     #print('sleep qsize=',video_track._queue.qsize())
                     time.sleep(0.04*video_track._queue.qsize()*0.8)

-        print('nerfreal thread stop')
+        logger.info('nerfreal thread stop')
\ No newline at end of file
{end-start}s") if res.status_code != 200: - print("Error:", res.text) + logger.error("Error:%s", res.text) return first = True for chunk in res.iter_content(chunk_size=None): #12800 1280 32K*20ms*2 - print('chunk len:',len(chunk)) + logger.info('chunk len:%d',len(chunk)) if first: end = time.perf_counter() - print(f"gpt_sovits Time to first chunk: {end-start}s") + logger.info(f"gpt_sovits Time to first chunk: {end-start}s") first = False if chunk and self.state==State.RUNNING: yield chunk #print("gpt_sovits response.elapsed:", res.elapsed) except Exception as e: - print(e) + logger.exception('sovits') def __create_bytes_stream(self,byte_stream): #byte_stream=BytesIO(buffer) stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 - print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') + logger.info(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') stream = stream.astype(np.float32) if stream.ndim > 1: - print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') stream = stream[:, 0] if sample_rate != self.sample_rate and stream.shape[0]>0: - print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') + logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) return stream @@ -338,10 +343,10 @@ class CosyVoiceTTS(BaseTTS): res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True) end = time.perf_counter() - print(f"cosy_voice Time to make POST: {end-start}s") + logger.info(f"cosy_voice Time to make POST: {end-start}s") if res.status_code != 200: - print("Error:", res.text) + logger.error("Error:%s", res.text) return first = True @@ -349,12 +354,12 @@ class CosyVoiceTTS(BaseTTS): for chunk in res.iter_content(chunk_size=8820): # 882 22.05K*20ms*2 if first: end = time.perf_counter() - print(f"cosy_voice Time to first chunk: {end-start}s") + logger.info(f"cosy_voice Time to first chunk: {end-start}s") first = False if chunk and self.state==State.RUNNING: yield chunk except Exception as e: - print(e) + logger.exception('cosyvoice') def stream_tts(self,audio_stream,msg): text,textevent = msg @@ -414,7 +419,7 @@ class XTTS(BaseTTS): stream=True, ) end = time.perf_counter() - print(f"xtts Time to make POST: {end-start}s") + logger.info(f"xtts Time to make POST: {end-start}s") if res.status_code != 200: print("Error:", res.text) @@ -425,7 +430,7 @@ class XTTS(BaseTTS): for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2 if first: end = time.perf_counter() - print(f"xtts Time to first chunk: {end-start}s") + logger.info(f"xtts Time to first chunk: {end-start}s") first = False if chunk: yield chunk diff --git a/webrtc.py b/webrtc.py index 81779e8..7cd3334 100644 --- a/webrtc.py +++ b/webrtc.py @@ -40,8 +40,9 @@ from aiortc import ( MediaStreamTrack, ) -logging.basicConfig() -logger = logging.getLogger(__name__) +#logging.basicConfig() +#logger = logging.getLogger(__name__) +from logger import logger class PlayerStreamTrack(MediaStreamTrack): @@ -82,7 +83,7 @@ class PlayerStreamTrack(MediaStreamTrack): self._start = time.time() self._timestamp = 0 self.timelist.append(self._start) - print('video start:',self._start) + logger.info('video start:%f',self._start) return self._timestamp, VIDEO_TIME_BASE else: #audio if hasattr(self, "_timestamp"): @@ -100,7 +101,7 @@ class 
diff --git a/webrtc.py b/webrtc.py
index 81779e8..7cd3334 100644
--- a/webrtc.py
+++ b/webrtc.py
@@ -40,8 +40,9 @@ from aiortc import (
     MediaStreamTrack,
 )

-logging.basicConfig()
-logger = logging.getLogger(__name__)
+#logging.basicConfig()
+#logger = logging.getLogger(__name__)
+from logger import logger

 class PlayerStreamTrack(MediaStreamTrack):

@@ -82,7 +83,7 @@ class PlayerStreamTrack(MediaStreamTrack):
                 self._start = time.time()
                 self._timestamp = 0
                 self.timelist.append(self._start)
-                print('video start:',self._start)
+                logger.info('video start:%f',self._start)
             return self._timestamp, VIDEO_TIME_BASE
         else: #audio
             if hasattr(self, "_timestamp"):
@@ -100,7 +101,7 @@ class PlayerStreamTrack(MediaStreamTrack):
                 self._start = time.time()
                 self._timestamp = 0
                 self.timelist.append(self._start)
-                print('audio start:',self._start)
+                logger.info('audio start:%f',self._start)
             return self._timestamp, AUDIO_TIME_BASE

     async def recv(self) -> Union[Frame, Packet]:
@@ -136,7 +137,7 @@ class PlayerStreamTrack(MediaStreamTrack):
             self.framecount += 1
             self.lasttime = time.perf_counter()
             if self.framecount==100:
-                print(f"------actual avg final fps:{self.framecount/self.totaltime:.4f}")
+                logger.info(f"------actual avg final fps:{self.framecount/self.totaltime:.4f}")
                 self.framecount = 0
                 self.totaltime=0
         return frame
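webrtc.py previously configured logging itself with logging.basicConfig() and a module-level logger; with the shared instance imported instead, there is a single set of handlers, which avoids the duplicated lines that appear when records propagate to a root logger that basicConfig has also given a handler. A sketch of the pitfall and the guard, with illustrative names:

import logging

shared = logging.getLogger("livetalking")      # stand-in for logger.py's logger
shared.addHandler(logging.StreamHandler())
shared.setLevel(logging.DEBUG)

# If some module also called logging.basicConfig(), records logged on shared
# would propagate to the root logger and be printed a second time.
shared.propagate = False    # optional extra guard; logger.py itself does not set this
shared.info("printed exactly once")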