add log and typing

main
lipku 5 months ago
parent 4375948481
commit 367477797b

@ -19,12 +19,10 @@
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
import base64
import time
import json
#import gevent
#from gevent import pywsgi
#from geventwebsocket.handler import WebSocketHandler
import os
import re
import numpy as np
from threading import Thread,Event
@ -37,86 +35,36 @@ import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.rtcrtpsender import RTCRtpSender
from webrtc import HumanPlayer
from basereal import BaseReal
from llm import llm_response
import argparse
import random
import shutil
import asyncio
import torch
from typing import Dict
from logger import logger
app = Flask(__name__)
#sockets = Sockets(app)
nerfreals = {}
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
opt = None
model = None
avatar = None
# def llm_response(message):
# from llm.LLM import LLM
# # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
# # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
# llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
# response = llm.chat(message)
# print(response)
# return response
def llm_response(message,nerfreal):
start = time.perf_counter()
from openai import OpenAI
client = OpenAI(
# If you have not configured the environment variable, replace this with your API Key
api_key=os.getenv("DASHSCOPE_API_KEY"),
# base_url for the DashScope SDK
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
end = time.perf_counter()
print(f"llm Time init: {end-start}s")
completion = client.chat.completions.create(
model="qwen-plus",
messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
{'role': 'user', 'content': message}],
stream=True,
# with this setting, token usage info is shown in the last line of the streamed output
stream_options={"include_usage": True}
)
result=""
first = True
for chunk in completion:
if len(chunk.choices)>0:
#print(chunk.choices[0].delta.content)
if first:
end = time.perf_counter()
print(f"llm Time to first chunk: {end-start}s")
first = False
msg = chunk.choices[0].delta.content
lastpos=0
#msglist = re.split('[,.!;:,。!?]',msg)
for i, char in enumerate(msg):
if char in ",.!;:,。!?:;" :
result = result+msg[lastpos:i+1]
lastpos = i+1
if len(result)>10:
print(result)
nerfreal.put_msg_txt(result)
result=""
result = result+msg[lastpos:]
end = time.perf_counter()
print(f"llm Time to last chunk: {end-start}s")
nerfreal.put_msg_txt(result)
#####webrtc###############################
pcs = set()
def randN(N):
def randN(N)->int:
'''Generate a random integer with N digits'''
min = pow(10, N - 1)
max = pow(10, N)
return random.randint(min, max - 1)
def build_nerfreal(sessionid):
def build_nerfreal(sessionid:int)->BaseReal:
opt.sessionid=sessionid
if opt.model == 'wav2lip':
from lipreal import LipReal
@ -138,10 +86,10 @@ async def offer(request):
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
if len(nerfreals) >= opt.max_session:
print('reach max session')
logger.info('reach max session')
return -1
sessionid = randN(6) #len(nerfreals)
print('sessionid=',sessionid)
logger.info('sessionid=%d',sessionid)
nerfreals[sessionid] = None
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
nerfreals[sessionid] = nerfreal
@ -151,7 +99,7 @@ async def offer(request):
@pc.on("connectionstatechange")
async def on_connectionstatechange():
print("Connection state is %s" % pc.connectionState)
logger.info("Connection state is %s" % pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
@ -280,7 +228,7 @@ async def post(url,data):
async with session.post(url,data=data) as response:
return await response.text()
except aiohttp.ClientError as e:
print(f'Error: {e}')
logger.info(f'Error: {e}')
async def run(push_url,sessionid):
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
@ -291,7 +239,7 @@ async def run(push_url,sessionid):
@pc.on("connectionstatechange")
async def on_connectionstatechange():
print("Connection state is %s" % pc.connectionState)
logger.info("Connection state is %s" % pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
@ -465,7 +413,7 @@ if __name__ == '__main__':
# nerfreals.append(nerfreal)
elif opt.model == 'musetalk':
from musereal import MuseReal,load_model,load_avatar,warm_up
print(opt)
logger.info(opt)
model = load_model()
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,model)
@ -475,7 +423,7 @@ if __name__ == '__main__':
# nerfreals.append(nerfreal)
elif opt.model == 'wav2lip':
from lipreal import LipReal,load_model,load_avatar,warm_up
print(opt)
logger.info(opt)
model = load_model("./models/wav2lip.pth")
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,model,256)
@ -485,7 +433,7 @@ if __name__ == '__main__':
# nerfreals.append(nerfreal)
elif opt.model == 'ultralight':
from lightreal import LightReal,load_model,load_avatar,warm_up
print(opt)
logger.info(opt)
model = load_model(opt)
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,avatar,160)
@ -524,7 +472,7 @@ if __name__ == '__main__':
pagename='echoapi.html'
elif opt.transport=='rtcpush':
pagename='rtcpushapi.html'
print('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
def run_server(runner):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

@ -22,9 +22,11 @@ import queue
from queue import Queue
import torch.multiprocessing as mp
from basereal import BaseReal
class BaseASR:
def __init__(self, opt, parent=None):
def __init__(self, opt, parent:BaseReal|None = None):
self.opt = opt
self.parent = parent

@ -36,11 +36,12 @@ import av
from fractions import Fraction
from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS,FishTTS
from logger import logger
from tqdm import tqdm
def read_imgs(img_list):
frames = []
print('reading images...')
logger.info('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
@ -98,15 +99,15 @@ class BaseReal:
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
logger.info(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
@ -120,7 +121,7 @@ class BaseReal:
def __loadcustom(self):
for item in self.opt.customopt:
print(item)
logger.info(item)
input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list)
@ -137,7 +138,7 @@ class BaseReal:
self.custom_index[key]=0
def notify(self,eventpoint):
print("notify:",eventpoint)
logger.info("notify:%s",eventpoint)
def start_recording(self):
"""开始录制视频"""

@ -54,11 +54,11 @@ from transformers import Wav2Vec2Processor, HubertModel
from torch.utils.data import DataLoader
from ultralight.unet import Model
from ultralight.audio2feature import Audio2Feature
from logger import logger
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
logger.info('Using {} for inference.'.format(device))
def load_model(opt):
@ -89,7 +89,7 @@ def load_avatar(avatar_id):
@torch.no_grad()
def warm_up(batch_size,avatar,modelres):
print('warmup model...')
logger.info('warmup model...')
model,_,_,_ = avatar
img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
mel_batch = torch.ones(batch_size, 32, 32, 32).to(device)
@ -97,7 +97,7 @@ def warm_up(batch_size,avatar,modelres):
def read_imgs(img_list):
frames = []
print('reading images...')
logger.info('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
@ -124,7 +124,7 @@ def get_audio_features(features, index):
def read_lms(lms_list):
land_marks = []
print('reading lms...')
logger.info('reading lms...')
for lms_path in tqdm(lms_list):
file_landmarks = [] # Store landmarks for this file
with open(lms_path, "r") as f:
@ -152,7 +152,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
index = 0
count = 0
counttime = 0
print('start inference')
logger.info('start inference')
while not quit_event.is_set():
starttime=time.perf_counter()
@ -206,7 +206,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
counttime += (time.perf_counter() - t)
count += batch_size
if count >= 100:
print(f"------actual avg infer fps:{count / counttime:.4f}")
logger.info(f"------actual avg infer fps:{count / counttime:.4f}")
count = 0
counttime = 0
for i,res_frame in enumerate(pred):
@ -221,7 +221,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
#print('total batch time:', time.perf_counter() - starttime)
print('lightreal inference processor stop')
logger.info('lightreal inference processor stop')
class LightReal(BaseReal):
@ -248,7 +248,7 @@ class LightReal(BaseReal):
self.render_event = mp.Event()
def __del__(self):
print(f'lightreal({self.sessionid}) delete')
logger.info(f'lightreal({self.sessionid}) delete')
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@ -302,7 +302,7 @@ class LightReal(BaseReal):
asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
self.record_audio_data(frame)
#self.notify(eventpoint)
print('lightreal process_frames thread stop')
logger.info('lightreal process_frames thread stop')
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
#if self.opt.asr:
@ -331,13 +331,13 @@ class LightReal(BaseReal):
# print('sleep qsize=',video_track._queue.qsize())
# time.sleep(0.04*video_track._queue.qsize()*0.8)
if video_track._queue.qsize()>=5:
print('sleep qsize=',video_track._queue.qsize())
logger.debug('sleep qsize=%d',video_track._queue.qsize())
time.sleep(0.04*video_track._queue.qsize()*0.8)
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:
# time.sleep(delay)
#self.render_event.clear() #end infer process render
print('lightreal thread stop')
logger.info('lightreal thread stop')

@ -42,9 +42,10 @@ from basereal import BaseReal
#from imgcache import ImgCache
from tqdm import tqdm
from logger import logger
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
logger.info('Using {} for inference.'.format(device))
def _load(checkpoint_path):
if device == 'cuda':
@ -56,7 +57,7 @@ def _load(checkpoint_path):
def load_model(path):
model = Wav2Lip()
print("Load checkpoint from: {}".format(path))
logger.info("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
@ -88,14 +89,14 @@ def load_avatar(avatar_id):
@torch.no_grad()
def warm_up(batch_size,model,modelres):
# warm-up function
print('warmup model...')
logger.info('warmup model...')
img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
mel_batch = torch.ones(batch_size, 1, 80, 16).to(device)
model(mel_batch, img_batch)
def read_imgs(img_list):
frames = []
print('reading images...')
logger.info('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
@ -122,7 +123,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
index = 0
count=0
counttime=0
print('start inference')
logger.info('start inference')
while not quit_event.is_set():
starttime=time.perf_counter()
mel_batch = []
@ -170,7 +171,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
count += batch_size
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(pred):
@ -178,7 +179,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
#print('total batch time:',time.perf_counter()-starttime)
print('lipreal inference processor stop')
logger.info('lipreal inference processor stop')
class LipReal(BaseReal):
@torch.no_grad()
@ -203,7 +204,7 @@ class LipReal(BaseReal):
self.render_event = mp.Event()
def __del__(self):
print(f'lipreal({self.sessionid}) delete')
logger.info(f'lipreal({self.sessionid}) delete')
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@ -256,7 +257,7 @@ class LipReal(BaseReal):
asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
self.record_audio_data(frame)
#self.notify(eventpoint)
print('lipreal process_frames thread stop')
logger.info('lipreal process_frames thread stop')
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
#if self.opt.asr:
@ -286,12 +287,12 @@ class LipReal(BaseReal):
# print('sleep qsize=',video_track._queue.qsize())
# time.sleep(0.04*video_track._queue.qsize()*0.8)
if video_track._queue.qsize()>=5:
print('sleep qsize=',video_track._queue.qsize())
logger.debug('sleep qsize=%d',video_track._queue.qsize())
time.sleep(0.04*video_track._queue.qsize()*0.8)
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:
# time.sleep(delay)
#self.render_event.clear() #end infer process render
print('lipreal thread stop')
logger.info('lipreal thread stop')

@ -0,0 +1,48 @@
import time
import os
from basereal import BaseReal
from logger import logger
def llm_response(message,nerfreal:BaseReal):
start = time.perf_counter()
from openai import OpenAI
client = OpenAI(
# If you have not configured the environment variable, replace this with your API Key
api_key=os.getenv("DASHSCOPE_API_KEY"),
# base_url for the DashScope SDK
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
end = time.perf_counter()
logger.info(f"llm Time init: {end-start}s")
completion = client.chat.completions.create(
model="qwen-plus",
messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
{'role': 'user', 'content': message}],
stream=True,
# with this setting, token usage info is shown in the last line of the streamed output
stream_options={"include_usage": True}
)
result=""
first = True
for chunk in completion:
if len(chunk.choices)>0:
#print(chunk.choices[0].delta.content)
if first:
end = time.perf_counter()
logger.info(f"llm Time to first chunk: {end-start}s")
first = False
msg = chunk.choices[0].delta.content
lastpos=0
#msglist = re.split('[,.!;:,。!?]',msg)
for i, char in enumerate(msg):
if char in ",.!;:,。!?:;" :
result = result+msg[lastpos:i+1]
lastpos = i+1
if len(result)>10:
logger.info(result)
nerfreal.put_msg_txt(result)
result=""
result = result+msg[lastpos:]
end = time.perf_counter()
logger.info(f"llm Time to last chunk: {end-start}s")
nerfreal.put_msg_txt(result)
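A minimal usage sketch of the new llm.py (not part of the commit): llm_response only needs an object exposing put_msg_txt, so a hypothetical stand-in can be passed for a quick test, assuming DASHSCOPE_API_KEY is set in the environment:

from types import SimpleNamespace
from llm import llm_response

# hypothetical stand-in for a BaseReal session; only put_msg_txt is called by llm_response
fake_real = SimpleNamespace(put_msg_txt=lambda msg, eventpoint=None: print('TTS <-', msg))

llm_response('Tell me about yourself', fake_real)  # streamed sentence chunks arrive via put_msg_txt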

@ -0,0 +1,16 @@
import logging
# configure the logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fhandler = logging.FileHandler('livetalking.log') # could instead be a StreamHandler for console output, or several handlers combined, etc.
fhandler.setFormatter(formatter)
fhandler.setLevel(logging.INFO)
logger.addHandler(fhandler)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
sformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(sformatter)
logger.addHandler(handler)
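For reference, the modules touched by this commit consume the shared logger as in the sketch below (messages are illustrative): DEBUG records reach only the console handler, while INFO and above also land in livetalking.log:

from logger import logger

logger.debug('sleep qsize=%d', 5)     # console only; below the file handler's INFO level
logger.info('sessionid=%d', 123456)   # console and livetalking.log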

@ -46,6 +46,7 @@ from av import AudioFrame, VideoFrame
from basereal import BaseReal
from tqdm import tqdm
from logger import logger
def load_model():
# load model weights
@ -92,7 +93,7 @@ def load_avatar(avatar_id):
@torch.no_grad()
def warm_up(batch_size,model):
# warm-up function
print('warmup model...')
logger.info('warmup model...')
vae, unet, pe, timesteps, audio_processor = model
#batch_size = 16
#timesteps = torch.tensor([0], device=unet.device)
@ -110,7 +111,7 @@ def warm_up(batch_size,model):
def read_imgs(img_list):
frames = []
print('reading images...')
logger.info('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
@ -140,7 +141,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
index = 0
count=0
counttime=0
print('start inference')
logger.info('start inference')
while render_event.is_set():
starttime=time.perf_counter()
try:
@ -195,7 +196,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
count += batch_size
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(recon):
@ -203,7 +204,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
#print('total batch time:',time.perf_counter()-starttime)
print('musereal inference processor stop')
logger.info('musereal inference processor stop')
class MuseReal(BaseReal):
@torch.no_grad()
@ -229,7 +230,7 @@ class MuseReal(BaseReal):
self.render_event = mp.Event()
def __del__(self):
print(f'musereal({self.sessionid}) delete')
logger.info(f'musereal({self.sessionid}) delete')
def __mirror_index(self, index):
@ -251,7 +252,7 @@ class MuseReal(BaseReal):
latent = self.input_latent_list_cycle[idx]
latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0)
print('infer=======')
logger.info('infer=======')
# for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
@ -317,7 +318,7 @@ class MuseReal(BaseReal):
self.record_audio_data(frame)
#self.notify(eventpoint)
#self.recordq_audio.put(new_frame)
print('musereal process_frames thread stop')
logger.info('musereal process_frames thread stop')
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
#if self.opt.asr:
@ -349,7 +350,7 @@ class MuseReal(BaseReal):
# count=0
# totaltime=0
if video_track._queue.qsize()>=1.5*self.opt.batch_size:
print('sleep qsize=',video_track._queue.qsize())
logger.debug('sleep qsize=%d',video_track._queue.qsize())
time.sleep(0.04*video_track._queue.qsize()*0.8)
# if video_track._queue.qsize()>=5:
# print('sleep qsize=',video_track._queue.qsize())
@ -359,5 +360,5 @@ class MuseReal(BaseReal):
# if delay > 0:
# time.sleep(delay)
self.render_event.clear() #end infer process render
print('musereal thread stop')
logger.info('musereal thread stop')

@ -38,10 +38,11 @@ from ernerf.nerf_triplane.utils import *
from ernerf.nerf_triplane.network import NeRFNetwork
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
from logger import logger
from tqdm import tqdm
def read_imgs(img_list):
frames = []
print('reading images...')
logger.info('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
@ -74,21 +75,21 @@ def load_model(opt):
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
seed_everything(opt.seed)
print(opt)
logger.info(opt)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NeRFNetwork(opt)
criterion = torch.nn.MSELoss(reduction='none')
metrics = [] # use no metric in GUI for faster initialization...
print(model)
logger.info(model)
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
test_loader = NeRFDataset_Test(opt, device=device).dataloader()
model.aud_features = test_loader._data.auds
model.eye_areas = test_loader._data.eye_area
print(f'[INFO] loading ASR model {opt.asr_model}...')
logger.info(f'[INFO] loading ASR model {opt.asr_model}...')
if 'hubert' in opt.asr_model:
audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
@ -197,7 +198,7 @@ class NeRFReal(BaseReal):
'''
def __del__(self):
print(f'nerfreal({self.sessionid}) delete')
logger.info(f'nerfreal({self.sessionid}) delete')
def __enter__(self):
return self
@ -365,7 +366,7 @@ class NeRFReal(BaseReal):
count += 1
_totalframe += 1
if count==100:
print(f"------actual avg infer fps:{count/totaltime:.4f}")
logger.info(f"------actual avg infer fps:{count/totaltime:.4f}")
count=0
totaltime=0
if self.opt.transport=='rtmp':
@ -376,6 +377,6 @@ class NeRFReal(BaseReal):
if video_track._queue.qsize()>=5:
#print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.04*video_track._queue.qsize()*0.8)
print('nerfreal thread stop')
logger.info('nerfreal thread stop')

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import os
from __future__ import annotations
import time
import numpy as np
import soundfile as sf
@ -32,12 +32,17 @@ from io import BytesIO
from threading import Thread, Event
from enum import Enum
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from basereal import BaseReal
from logger import logger
class State(Enum):
RUNNING=0
PAUSE=1
class BaseTTS:
def __init__(self, opt, parent):
def __init__(self, opt, parent:BaseReal):
self.opt=opt
self.parent = parent
@ -53,7 +58,7 @@ class BaseTTS:
self.msgqueue.queue.clear()
self.state = State.PAUSE
def put_msg_txt(self,msg,eventpoint=None):
def put_msg_txt(self,msg:str,eventpoint=None):
if len(msg)>0:
self.msgqueue.put((msg,eventpoint))
@ -69,7 +74,7 @@ class BaseTTS:
except queue.Empty:
continue
self.txt_to_audio(msg)
print('ttsreal thread stop')
logger.info('ttsreal thread stop')
def txt_to_audio(self,msg):
pass
@ -82,9 +87,9 @@ class EdgeTTS(BaseTTS):
text,textevent = msg
t = time.time()
asyncio.new_event_loop().run_until_complete(self.__main(voicename,text))
print(f'-------edge tts time:{time.time()-t:.4f}s')
logger.info(f'-------edge tts time:{time.time()-t:.4f}s')
if self.input_stream.getbuffer().nbytes<=0: #edgetts err
print('edgetts err!!!!!')
logger.error('edgetts err!!!!!')
return
self.input_stream.seek(0)
@ -108,15 +113,15 @@ class EdgeTTS(BaseTTS):
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
logger.info(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
@ -137,7 +142,7 @@ class EdgeTTS(BaseTTS):
elif chunk["type"] == "WordBoundary":
pass
except Exception as e:
print(e)
logger.exception('edgetts')
###########################################################################################
class FishTTS(BaseTTS):
@ -173,10 +178,10 @@ class FishTTS(BaseTTS):
},
)
end = time.perf_counter()
print(f"fish_speech Time to make POST: {end-start}s")
logger.info(f"fish_speech Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
logger.error("Error:%s", res.text)
return
first = True
@ -185,13 +190,13 @@ class FishTTS(BaseTTS):
#print('chunk len:',len(chunk))
if first:
end = time.perf_counter()
print(f"fish_speech Time to first chunk: {end-start}s")
logger.info(f"fish_speech Time to first chunk: {end-start}s")
first = False
if chunk and self.state==State.RUNNING:
yield chunk
#print("gpt_sovits response.elapsed:", res.elapsed)
except Exception as e:
print(e)
logger.exception('fishtts')
def stream_tts(self,audio_stream,msg):
text,textevent = msg
@ -254,38 +259,38 @@ class VoitsTTS(BaseTTS):
stream=True,
)
end = time.perf_counter()
print(f"gpt_sovits Time to make POST: {end-start}s")
logger.info(f"gpt_sovits Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
logger.error("Error:%s", res.text)
return
first = True
for chunk in res.iter_content(chunk_size=None): #12800 1280 32K*20ms*2
print('chunk len:',len(chunk))
logger.info('chunk len:%d',len(chunk))
if first:
end = time.perf_counter()
print(f"gpt_sovits Time to first chunk: {end-start}s")
logger.info(f"gpt_sovits Time to first chunk: {end-start}s")
first = False
if chunk and self.state==State.RUNNING:
yield chunk
#print("gpt_sovits response.elapsed:", res.elapsed)
except Exception as e:
print(e)
logger.exception('sovits')
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
logger.info(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
@ -338,10 +343,10 @@ class CosyVoiceTTS(BaseTTS):
res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True)
end = time.perf_counter()
print(f"cosy_voice Time to make POST: {end-start}s")
logger.info(f"cosy_voice Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
logger.error("Error:%s", res.text)
return
first = True
@ -349,12 +354,12 @@ class CosyVoiceTTS(BaseTTS):
for chunk in res.iter_content(chunk_size=8820): # 882 22.05K*20ms*2
if first:
end = time.perf_counter()
print(f"cosy_voice Time to first chunk: {end-start}s")
logger.info(f"cosy_voice Time to first chunk: {end-start}s")
first = False
if chunk and self.state==State.RUNNING:
yield chunk
except Exception as e:
print(e)
logger.exception('cosyvoice')
def stream_tts(self,audio_stream,msg):
text,textevent = msg
@ -414,7 +419,7 @@ class XTTS(BaseTTS):
stream=True,
)
end = time.perf_counter()
print(f"xtts Time to make POST: {end-start}s")
logger.info(f"xtts Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
@ -425,7 +430,7 @@ class XTTS(BaseTTS):
for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2
if first:
end = time.perf_counter()
print(f"xtts Time to first chunk: {end-start}s")
logger.info(f"xtts Time to first chunk: {end-start}s")
first = False
if chunk:
yield chunk

@ -40,8 +40,9 @@ from aiortc import (
MediaStreamTrack,
)
logging.basicConfig()
logger = logging.getLogger(__name__)
#logging.basicConfig()
#logger = logging.getLogger(__name__)
from logger import logger
class PlayerStreamTrack(MediaStreamTrack):
@ -82,7 +83,7 @@ class PlayerStreamTrack(MediaStreamTrack):
self._start = time.time()
self._timestamp = 0
self.timelist.append(self._start)
print('video start:',self._start)
logger.info('video start:%f',self._start)
return self._timestamp, VIDEO_TIME_BASE
else: #audio
if hasattr(self, "_timestamp"):
@ -100,7 +101,7 @@ class PlayerStreamTrack(MediaStreamTrack):
self._start = time.time()
self._timestamp = 0
self.timelist.append(self._start)
print('audio start:',self._start)
logger.info('audio start:%f',self._start)
return self._timestamp, AUDIO_TIME_BASE
async def recv(self) -> Union[Frame, Packet]:
@ -136,7 +137,7 @@ class PlayerStreamTrack(MediaStreamTrack):
self.framecount += 1
self.lasttime = time.perf_counter()
if self.framecount==100:
print(f"------actual avg final fps:{self.framecount/self.totaltime:.4f}")
logger.info(f"------actual avg final fps:{self.framecount/self.totaltime:.4f}")
self.framecount = 0
self.totaltime=0
return frame
