add log and typing

main
lipku 5 months ago
parent 4375948481
commit 367477797b

@@ -19,12 +19,10 @@
 from flask import Flask, render_template,send_from_directory,request, jsonify
 from flask_sockets import Sockets
 import base64
-import time
 import json
 #import gevent
 #from gevent import pywsgi
 #from geventwebsocket.handler import WebSocketHandler
-import os
 import re
 import numpy as np
 from threading import Thread,Event
@@ -37,86 +35,36 @@ import aiohttp_cors
 from aiortc import RTCPeerConnection, RTCSessionDescription
 from aiortc.rtcrtpsender import RTCRtpSender
 from webrtc import HumanPlayer
+from basereal import BaseReal
+from llm import llm_response
 import argparse
 import random
 import shutil
 import asyncio
 import torch
+from typing import Dict
+from logger import logger
 app = Flask(__name__)
 #sockets = Sockets(app)
-nerfreals = {}
+nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
 opt = None
 model = None
 avatar = None
-# def llm_response(message):
-#     from llm.LLM import LLM
-#     # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
-#     # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
-#     llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
-#     response = llm.chat(message)
-#     print(response)
-#     return response
-def llm_response(message,nerfreal):
-    start = time.perf_counter()
-    from openai import OpenAI
-    client = OpenAI(
-        # If you have not set the environment variable, replace this with your API Key
-        api_key=os.getenv("DASHSCOPE_API_KEY"),
-        # base_url for the DashScope SDK
-        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
-    )
-    end = time.perf_counter()
-    print(f"llm Time init: {end-start}s")
-    completion = client.chat.completions.create(
-        model="qwen-plus",
-        messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
-                  {'role': 'user', 'content': message}],
-        stream=True,
-        # With this setting, token usage is reported in the last line of the streamed output
-        stream_options={"include_usage": True}
-    )
-    result=""
-    first = True
-    for chunk in completion:
-        if len(chunk.choices)>0:
-            #print(chunk.choices[0].delta.content)
-            if first:
-                end = time.perf_counter()
-                print(f"llm Time to first chunk: {end-start}s")
-                first = False
-            msg = chunk.choices[0].delta.content
-            lastpos=0
-            #msglist = re.split('[,.!;:,。!?]',msg)
-            for i, char in enumerate(msg):
-                if char in ",.!;:,。!?:;" :
-                    result = result+msg[lastpos:i+1]
-                    lastpos = i+1
-                    if len(result)>10:
-                        print(result)
-                        nerfreal.put_msg_txt(result)
-                        result=""
-            result = result+msg[lastpos:]
-    end = time.perf_counter()
-    print(f"llm Time to last chunk: {end-start}s")
-    nerfreal.put_msg_txt(result)
 #####webrtc###############################
 pcs = set()
-def randN(N):
+def randN(N)->int:
     '''Generate a random number with N digits'''
     min = pow(10, N - 1)
     max = pow(10, N)
     return random.randint(min, max - 1)
-def build_nerfreal(sessionid):
+def build_nerfreal(sessionid:int)->BaseReal:
     opt.sessionid=sessionid
     if opt.model == 'wav2lip':
         from lipreal import LipReal
@@ -138,10 +86,10 @@ async def offer(request):
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
     if len(nerfreals) >= opt.max_session:
-        print('reach max session')
+        logger.info('reach max session')
         return -1
     sessionid = randN(6) #len(nerfreals)
-    print('sessionid=',sessionid)
+    logger.info('sessionid=%d',sessionid)
     nerfreals[sessionid] = None
     nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
     nerfreals[sessionid] = nerfreal
@@ -151,7 +99,7 @@ async def offer(request):
     @pc.on("connectionstatechange")
     async def on_connectionstatechange():
-        print("Connection state is %s" % pc.connectionState)
+        logger.info("Connection state is %s" % pc.connectionState)
         if pc.connectionState == "failed":
             await pc.close()
             pcs.discard(pc)
@@ -280,7 +228,7 @@ async def post(url,data):
            async with session.post(url,data=data) as response:
                return await response.text()
    except aiohttp.ClientError as e:
-        logger.info(f'Error: {e}')
+        logger.info(f'Error: {e}')
 async def run(push_url,sessionid):
     nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
@@ -291,7 +239,7 @@ async def run(push_url,sessionid):
     @pc.on("connectionstatechange")
     async def on_connectionstatechange():
-        print("Connection state is %s" % pc.connectionState)
+        logger.info("Connection state is %s" % pc.connectionState)
         if pc.connectionState == "failed":
             await pc.close()
             pcs.discard(pc)
@@ -465,7 +413,7 @@ if __name__ == '__main__':
         # nerfreals.append(nerfreal)
     elif opt.model == 'musetalk':
         from musereal import MuseReal,load_model,load_avatar,warm_up
-        print(opt)
+        logger.info(opt)
         model = load_model()
         avatar = load_avatar(opt.avatar_id)
         warm_up(opt.batch_size,model)
@@ -475,7 +423,7 @@ if __name__ == '__main__':
         # nerfreals.append(nerfreal)
     elif opt.model == 'wav2lip':
         from lipreal import LipReal,load_model,load_avatar,warm_up
-        print(opt)
+        logger.info(opt)
         model = load_model("./models/wav2lip.pth")
         avatar = load_avatar(opt.avatar_id)
         warm_up(opt.batch_size,model,256)
@@ -485,7 +433,7 @@ if __name__ == '__main__':
         # nerfreals.append(nerfreal)
     elif opt.model == 'ultralight':
         from lightreal import LightReal,load_model,load_avatar,warm_up
-        print(opt)
+        logger.info(opt)
         model = load_model(opt)
         avatar = load_avatar(opt.avatar_id)
         warm_up(opt.batch_size,avatar,160)
@@ -524,7 +472,7 @@ if __name__ == '__main__':
         pagename='echoapi.html'
     elif opt.transport=='rtcpush':
         pagename='rtcpushapi.html'
-    print('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
+    logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
     def run_server(runner):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
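
Note (not part of the commit): a minimal, self-contained sketch of the session-registry pattern that the typed `nerfreals:Dict[int, BaseReal]` map and the `randN(6)` session id implement; `FakeSession` and `MAX_SESSION` are illustrative stand-ins for `BaseReal` and `opt.max_session`.

```python
# Illustrative sketch only (FakeSession and MAX_SESSION are stand-ins, not repo names).
import random
from typing import Dict


class FakeSession:
    """Stand-in for BaseReal, just to make the sketch runnable."""
    def __init__(self, sessionid: int):
        self.sessionid = sessionid


def randN(N: int) -> int:
    """Generate a random number with N digits."""
    lo = pow(10, N - 1)
    hi = pow(10, N)
    return random.randint(lo, hi - 1)


sessions: Dict[int, FakeSession] = {}  # sessionid -> session object
MAX_SESSION = 10

if len(sessions) < MAX_SESSION:
    sessionid = randN(6)
    sessions[sessionid] = FakeSession(sessionid)   # register the new session
    print(f"created session {sessionid}, total={len(sessions)}")
```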

@@ -22,9 +22,11 @@ import queue
 from queue import Queue
 import torch.multiprocessing as mp
+from basereal import BaseReal
 class BaseASR:
-    def __init__(self, opt, parent=None):
+    def __init__(self, opt, parent:BaseReal|None = None):
         self.opt = opt
         self.parent = parent
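
Note (not part of the commit): the new `parent:BaseReal|None` annotation uses PEP 604 union syntax, which is evaluated when the function is defined and therefore needs Python 3.10+ unless postponed annotations are enabled (as ttsreal.py does further down with `from __future__ import annotations`). A small sketch of the two backward-compatible spellings, with hypothetical class names:

```python
# Illustrative sketch; Parent/Child are hypothetical names.
from __future__ import annotations   # annotations stay unevaluated strings
from typing import Optional


class Parent:
    pass


class ChildA:
    # PEP 604 union: needs Python 3.10+ at runtime unless the __future__ import above is present
    def __init__(self, opt, parent: Parent | None = None):
        self.opt = opt
        self.parent = parent


class ChildB:
    # typing.Optional also works on older interpreters, even without postponed annotations
    def __init__(self, opt, parent: Optional[Parent] = None):
        self.opt = opt
        self.parent = parent
```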

@@ -36,11 +36,12 @@ import av
 from fractions import Fraction
 from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS,FishTTS
+from logger import logger
 from tqdm import tqdm
 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -98,15 +99,15 @@ class BaseReal:
     def __create_bytes_stream(self,byte_stream):
         #byte_stream=BytesIO(buffer)
         stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
-        print(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
+        logger.info(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
         stream = stream.astype(np.float32)
         if stream.ndim > 1:
-            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
+            logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
             stream = stream[:, 0]
         if sample_rate != self.sample_rate and stream.shape[0]>0:
-            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
+            logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
             stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
         return stream
@@ -120,7 +121,7 @@ class BaseReal:
     def __loadcustom(self):
         for item in self.opt.customopt:
-            print(item)
+            logger.info(item)
             input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]'))
             input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
             self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list)
@@ -137,7 +138,7 @@ class BaseReal:
             self.custom_index[key]=0
     def notify(self,eventpoint):
-        print("notify:",eventpoint)
+        logger.info("notify:%s",eventpoint)
     def start_recording(self):
         """Start recording video"""

@@ -54,11 +54,11 @@ from transformers import Wav2Vec2Processor, HubertModel
 from torch.utils.data import DataLoader
 from ultralight.unet import Model
 from ultralight.audio2feature import Audio2Feature
+from logger import logger
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print('Using {} for inference.'.format(device))
+logger.info('Using {} for inference.'.format(device))
 def load_model(opt):
@@ -89,7 +89,7 @@ def load_avatar(avatar_id):
 @torch.no_grad()
 def warm_up(batch_size,avatar,modelres):
-    print('warmup model...')
+    logger.info('warmup model...')
     model,_,_,_ = avatar
     img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
     mel_batch = torch.ones(batch_size, 32, 32, 32).to(device)
@@ -97,7 +97,7 @@ def warm_up(batch_size,avatar,modelres):
 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -124,7 +124,7 @@ def get_audio_features(features, index):
 def read_lms(lms_list):
     land_marks = []
-    print('reading lms...')
+    logger.info('reading lms...')
     for lms_path in tqdm(lms_list):
         file_landmarks = []  # Store landmarks for this file
         with open(lms_path, "r") as f:
@@ -152,7 +152,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
     index = 0
     count = 0
     counttime = 0
-    print('start inference')
+    logger.info('start inference')
     while not quit_event.is_set():
         starttime=time.perf_counter()
@@ -206,7 +206,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
             counttime += (time.perf_counter() - t)
             count += batch_size
             if count >= 100:
-                print(f"------actual avg infer fps:{count / counttime:.4f}")
+                logger.info(f"------actual avg infer fps:{count / counttime:.4f}")
                 count = 0
                 counttime = 0
             for i,res_frame in enumerate(pred):
@@ -221,7 +221,7 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
         #print('total batch time:', time.perf_counter() - starttime)
-    print('lightreal inference processor stop')
+    logger.info('lightreal inference processor stop')
 class LightReal(BaseReal):
@@ -248,7 +248,7 @@ class LightReal(BaseReal):
         self.render_event = mp.Event()
     def __del__(self):
-        print(f'lightreal({self.sessionid}) delete')
+        logger.info(f'lightreal({self.sessionid}) delete')
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@@ -302,7 +302,7 @@ class LightReal(BaseReal):
                 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
                 self.record_audio_data(frame)
                 #self.notify(eventpoint)
-        print('lightreal process_frames thread stop')
+        logger.info('lightreal process_frames thread stop')
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -331,13 +331,13 @@ class LightReal(BaseReal):
            # print('sleep qsize=',video_track._queue.qsize())
            # time.sleep(0.04*video_track._queue.qsize()*0.8)
            if video_track._queue.qsize()>=5:
-                print('sleep qsize=',video_track._queue.qsize())
+                logger.debug('sleep qsize=%d',video_track._queue.qsize())
                time.sleep(0.04*video_track._queue.qsize()*0.8)
            # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
            # if delay > 0:
            #     time.sleep(delay)
        #self.render_event.clear() #end infer process render
-        print('lightreal thread stop')
+        logger.info('lightreal thread stop')

@@ -42,9 +42,10 @@ from basereal import BaseReal
 #from imgcache import ImgCache
 from tqdm import tqdm
+from logger import logger
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print('Using {} for inference.'.format(device))
+logger.info('Using {} for inference.'.format(device))
 def _load(checkpoint_path):
     if device == 'cuda':
@@ -56,7 +57,7 @@ def _load(checkpoint_path):
 def load_model(path):
     model = Wav2Lip()
-    print("Load checkpoint from: {}".format(path))
+    logger.info("Load checkpoint from: {}".format(path))
     checkpoint = _load(path)
     s = checkpoint["state_dict"]
     new_s = {}
@@ -88,14 +89,14 @@ def load_avatar(avatar_id):
 @torch.no_grad()
 def warm_up(batch_size,model,modelres):
     # warm-up function
-    print('warmup model...')
+    logger.info('warmup model...')
     img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
     mel_batch = torch.ones(batch_size, 1, 80, 16).to(device)
     model(mel_batch, img_batch)
 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -122,7 +123,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
     index = 0
     count=0
     counttime=0
-    print('start inference')
+    logger.info('start inference')
     while not quit_event.is_set():
         starttime=time.perf_counter()
         mel_batch = []
@@ -170,7 +171,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
             count += batch_size
             #_totalframe += 1
             if count>=100:
-                print(f"------actual avg infer fps:{count/counttime:.4f}")
+                logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
                 count=0
                 counttime=0
             for i,res_frame in enumerate(pred):
@@ -178,7 +179,7 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
                 res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                 index = index + 1
         #print('total batch time:',time.perf_counter()-starttime)
-    print('lipreal inference processor stop')
+    logger.info('lipreal inference processor stop')
 class LipReal(BaseReal):
     @torch.no_grad()
@@ -203,7 +204,7 @@ class LipReal(BaseReal):
         self.render_event = mp.Event()
     def __del__(self):
-        print(f'lipreal({self.sessionid}) delete')
+        logger.info(f'lipreal({self.sessionid}) delete')
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@@ -256,7 +257,7 @@ class LipReal(BaseReal):
                 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
                 self.record_audio_data(frame)
                 #self.notify(eventpoint)
-        print('lipreal process_frames thread stop')
+        logger.info('lipreal process_frames thread stop')
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -286,12 +287,12 @@ class LipReal(BaseReal):
            # print('sleep qsize=',video_track._queue.qsize())
            # time.sleep(0.04*video_track._queue.qsize()*0.8)
            if video_track._queue.qsize()>=5:
-                print('sleep qsize=',video_track._queue.qsize())
+                logger.debug('sleep qsize=%d',video_track._queue.qsize())
                time.sleep(0.04*video_track._queue.qsize()*0.8)
            # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
            # if delay > 0:
            #     time.sleep(delay)
        #self.render_event.clear() #end infer process render
-        print('lipreal thread stop')
+        logger.info('lipreal thread stop')

@@ -0,0 +1,48 @@
+import time
+import os
+from basereal import BaseReal
+from logger import logger
+
+def llm_response(message,nerfreal:BaseReal):
+    start = time.perf_counter()
+    from openai import OpenAI
+    client = OpenAI(
+        # If you have not set the environment variable, replace this with your API Key
+        api_key=os.getenv("DASHSCOPE_API_KEY"),
+        # base_url for the DashScope SDK
+        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+    )
+    end = time.perf_counter()
+    logger.info(f"llm Time init: {end-start}s")
+    completion = client.chat.completions.create(
+        model="qwen-plus",
+        messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
+                  {'role': 'user', 'content': message}],
+        stream=True,
+        # With this setting, token usage is reported in the last line of the streamed output
+        stream_options={"include_usage": True}
+    )
+    result=""
+    first = True
+    for chunk in completion:
+        if len(chunk.choices)>0:
+            #print(chunk.choices[0].delta.content)
+            if first:
+                end = time.perf_counter()
+                logger.info(f"llm Time to first chunk: {end-start}s")
+                first = False
+            msg = chunk.choices[0].delta.content
+            lastpos=0
+            #msglist = re.split('[,.!;:,。!?]',msg)
+            for i, char in enumerate(msg):
+                if char in ",.!;:,。!?:;" :
+                    result = result+msg[lastpos:i+1]
+                    lastpos = i+1
+                    if len(result)>10:
+                        logger.info(result)
+                        nerfreal.put_msg_txt(result)
+                        result=""
+            result = result+msg[lastpos:]
+    end = time.perf_counter()
+    logger.info(f"llm Time to last chunk: {end-start}s")
+    nerfreal.put_msg_txt(result)
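
Note (not part of the commit): `llm_response` blocks while it streams the completion and pushes sentence-sized fragments into `nerfreal.put_msg_txt`, so a caller running on an asyncio loop would typically hand it to an executor, mirroring how `build_nerfreal` is dispatched in app.py. An illustrative sketch; the handler name and arguments are assumptions, not the repo's API:

```python
# Illustrative sketch; handle_chat and its arguments are assumptions, not repo API.
import asyncio

from llm import llm_response   # assumes the repo root is on sys.path


async def handle_chat(sessionid: int, text: str, nerfreals: dict):
    nerfreal = nerfreals[sessionid]
    # llm_response blocks while streaming, so keep it off the event loop,
    # mirroring how build_nerfreal is dispatched in app.py
    await asyncio.get_event_loop().run_in_executor(None, llm_response, text, nerfreal)
```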

@@ -0,0 +1,16 @@
+import logging
+
+# configure the logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+fhandler = logging.FileHandler('livetalking.log')  # could also be a StreamHandler for console output, or several handlers combined
+fhandler.setFormatter(formatter)
+fhandler.setLevel(logging.INFO)
+logger.addHandler(fhandler)
+
+handler = logging.StreamHandler()
+handler.setLevel(logging.DEBUG)
+sformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+handler.setFormatter(sformatter)
+logger.addHandler(handler)
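
Note (not part of the commit): with this configuration the file handler only records INFO and above, while the console handler also shows DEBUG, so the new `logger.debug('sleep qsize=...')` calls reach the console but not livetalking.log. A quick check:

```python
# Quick check of how the two handlers split output (run from the repo root).
from logger import logger

logger.debug("console only: the FileHandler level is INFO")
logger.info("console and livetalking.log")
logger.error("console and livetalking.log, at ERROR level")
```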

@@ -46,6 +46,7 @@ from av import AudioFrame, VideoFrame
 from basereal import BaseReal
 from tqdm import tqdm
+from logger import logger
 def load_model():
     # load model weights
@@ -92,7 +93,7 @@ def load_avatar(avatar_id):
 @torch.no_grad()
 def warm_up(batch_size,model):
     # warm-up function
-    print('warmup model...')
+    logger.info('warmup model...')
     vae, unet, pe, timesteps, audio_processor = model
     #batch_size = 16
     #timesteps = torch.tensor([0], device=unet.device)
@@ -110,7 +111,7 @@ def warm_up(batch_size,model):
 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -140,7 +141,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
     index = 0
     count=0
     counttime=0
-    print('start inference')
+    logger.info('start inference')
     while render_event.is_set():
         starttime=time.perf_counter()
         try:
@@ -195,7 +196,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
             count += batch_size
             #_totalframe += 1
             if count>=100:
-                print(f"------actual avg infer fps:{count/counttime:.4f}")
+                logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
                 count=0
                 counttime=0
             for i,res_frame in enumerate(recon):
@@ -203,7 +204,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
                 res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                 index = index + 1
         #print('total batch time:',time.perf_counter()-starttime)
-    print('musereal inference processor stop')
+    logger.info('musereal inference processor stop')
 class MuseReal(BaseReal):
     @torch.no_grad()
@@ -229,7 +230,7 @@ class MuseReal(BaseReal):
         self.render_event = mp.Event()
     def __del__(self):
-        print(f'musereal({self.sessionid}) delete')
+        logger.info(f'musereal({self.sessionid}) delete')
     def __mirror_index(self, index):
@@ -251,7 +252,7 @@ class MuseReal(BaseReal):
             latent = self.input_latent_list_cycle[idx]
             latent_batch.append(latent)
         latent_batch = torch.cat(latent_batch, dim=0)
-        print('infer=======')
+        logger.info('infer=======')
         # for i, (whisper_batch,latent_batch) in enumerate(gen):
         audio_feature_batch = torch.from_numpy(whisper_batch)
         audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
@@ -317,7 +318,7 @@ class MuseReal(BaseReal):
                 self.record_audio_data(frame)
                 #self.notify(eventpoint)
                 #self.recordq_audio.put(new_frame)
-        print('musereal process_frames thread stop')
+        logger.info('musereal process_frames thread stop')
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -349,7 +350,7 @@ class MuseReal(BaseReal):
            # count=0
            # totaltime=0
            if video_track._queue.qsize()>=1.5*self.opt.batch_size:
-                print('sleep qsize=',video_track._queue.qsize())
+                logger.debug('sleep qsize=%d',video_track._queue.qsize())
                time.sleep(0.04*video_track._queue.qsize()*0.8)
            # if video_track._queue.qsize()>=5:
            #     print('sleep qsize=',video_track._queue.qsize())
@@ -359,5 +360,5 @@ class MuseReal(BaseReal):
            # if delay > 0:
            #     time.sleep(delay)
        self.render_event.clear() #end infer process render
-        print('musereal thread stop')
+        logger.info('musereal thread stop')

@@ -38,10 +38,11 @@ from ernerf.nerf_triplane.utils import *
 from ernerf.nerf_triplane.network import NeRFNetwork
 from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
+from logger import logger
 from tqdm import tqdm
 def read_imgs(img_list):
     frames = []
-    print('reading images...')
+    logger.info('reading images...')
     for img_path in tqdm(img_list):
         frame = cv2.imread(img_path)
         frames.append(frame)
@@ -74,21 +75,21 @@ def load_model(opt):
     # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
     assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
     seed_everything(opt.seed)
-    print(opt)
+    logger.info(opt)
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model = NeRFNetwork(opt)
     criterion = torch.nn.MSELoss(reduction='none')
     metrics = [] # use no metric in GUI for faster initialization...
-    print(model)
+    logger.info(model)
     trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
     test_loader = NeRFDataset_Test(opt, device=device).dataloader()
     model.aud_features = test_loader._data.auds
     model.eye_areas = test_loader._data.eye_area
-    print(f'[INFO] loading ASR model {opt.asr_model}...')
+    logger.info(f'[INFO] loading ASR model {opt.asr_model}...')
     if 'hubert' in opt.asr_model:
         audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
         audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
@@ -197,7 +198,7 @@ class NeRFReal(BaseReal):
     '''
     def __del__(self):
-        print(f'nerfreal({self.sessionid}) delete')
+        logger.info(f'nerfreal({self.sessionid}) delete')
     def __enter__(self):
         return self
@@ -365,7 +366,7 @@ class NeRFReal(BaseReal):
             count += 1
             _totalframe += 1
             if count==100:
-                print(f"------actual avg infer fps:{count/totaltime:.4f}")
+                logger.info(f"------actual avg infer fps:{count/totaltime:.4f}")
                 count=0
                 totaltime=0
             if self.opt.transport=='rtmp':
@@ -376,6 +377,6 @@ class NeRFReal(BaseReal):
             if video_track._queue.qsize()>=5:
                 #print('sleep qsize=',video_track._queue.qsize())
                 time.sleep(0.04*video_track._queue.qsize()*0.8)
-        print('nerfreal thread stop')
+        logger.info('nerfreal thread stop')

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 ###############################################################################
-import os
+from __future__ import annotations
 import time
 import numpy as np
 import soundfile as sf
@@ -32,12 +32,17 @@ from io import BytesIO
 from threading import Thread, Event
 from enum import Enum
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from basereal import BaseReal
+from logger import logger
 class State(Enum):
     RUNNING=0
     PAUSE=1
 class BaseTTS:
-    def __init__(self, opt, parent):
+    def __init__(self, opt, parent:BaseReal):
         self.opt=opt
         self.parent = parent
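
Note (not part of the commit): the `from __future__ import annotations` plus `if TYPE_CHECKING:` combination lets ttsreal.py annotate `parent:BaseReal` without importing basereal at runtime, which would otherwise be circular since basereal.py imports the TTS classes from ttsreal (see the basereal.py hunk above). A minimal, runnable illustration of the pattern with a hypothetical module name:

```python
# Minimal, runnable illustration; some_module/Parent are hypothetical names.
from __future__ import annotations   # annotations are stored as plain strings
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers (mypy, pyright) follow this import;
    # at runtime the block is skipped, so it cannot create an import cycle.
    from some_module import Parent


class Child:
    def __init__(self, parent: Parent) -> None:   # "Parent" is never evaluated at runtime
        self.parent = parent


print(TYPE_CHECKING)                     # False when executed
print(Child.__init__.__annotations__)    # {'parent': 'Parent', 'return': 'None'}
```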
@@ -53,7 +58,7 @@ class BaseTTS:
         self.msgqueue.queue.clear()
         self.state = State.PAUSE
-    def put_msg_txt(self,msg,eventpoint=None):
+    def put_msg_txt(self,msg:str,eventpoint=None):
         if len(msg)>0:
             self.msgqueue.put((msg,eventpoint))
@@ -69,7 +74,7 @@ class BaseTTS:
             except queue.Empty:
                 continue
             self.txt_to_audio(msg)
-        print('ttsreal thread stop')
+        logger.info('ttsreal thread stop')
     def txt_to_audio(self,msg):
         pass
@@ -82,9 +87,9 @@ class EdgeTTS(BaseTTS):
         text,textevent = msg
         t = time.time()
         asyncio.new_event_loop().run_until_complete(self.__main(voicename,text))
-        print(f'-------edge tts time:{time.time()-t:.4f}s')
+        logger.info(f'-------edge tts time:{time.time()-t:.4f}s')
         if self.input_stream.getbuffer().nbytes<=0: #edgetts err
-            print('edgetts err!!!!!')
+            logger.error('edgetts err!!!!!')
             return
         self.input_stream.seek(0)
@@ -108,15 +113,15 @@ class EdgeTTS(BaseTTS):
     def __create_bytes_stream(self,byte_stream):
         #byte_stream=BytesIO(buffer)
         stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
-        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
+        logger.info(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
         stream = stream.astype(np.float32)
         if stream.ndim > 1:
-            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
+            logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
             stream = stream[:, 0]
         if sample_rate != self.sample_rate and stream.shape[0]>0:
-            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
+            logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
             stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
         return stream
@@ -137,7 +142,7 @@ class EdgeTTS(BaseTTS):
                 elif chunk["type"] == "WordBoundary":
                     pass
         except Exception as e:
-            print(e)
+            logger.exception('edgetts')
 ###########################################################################################
 class FishTTS(BaseTTS):
@@ -173,10 +178,10 @@ class FishTTS(BaseTTS):
                 },
             )
             end = time.perf_counter()
-            print(f"fish_speech Time to make POST: {end-start}s")
+            logger.info(f"fish_speech Time to make POST: {end-start}s")
            if res.status_code != 200:
-                print("Error:", res.text)
+                logger.error("Error:%s", res.text)
                return
            first = True
@@ -185,13 +190,13 @@ class FishTTS(BaseTTS):
                #print('chunk len:',len(chunk))
                if first:
                    end = time.perf_counter()
-                    print(f"fish_speech Time to first chunk: {end-start}s")
+                    logger.info(f"fish_speech Time to first chunk: {end-start}s")
                    first = False
                if chunk and self.state==State.RUNNING:
                    yield chunk
            #print("gpt_sovits response.elapsed:", res.elapsed)
        except Exception as e:
-            print(e)
+            logger.exception('fishtts')
     def stream_tts(self,audio_stream,msg):
         text,textevent = msg
@@ -254,38 +259,38 @@ class VoitsTTS(BaseTTS):
                stream=True,
            )
            end = time.perf_counter()
-            print(f"gpt_sovits Time to make POST: {end-start}s")
+            logger.info(f"gpt_sovits Time to make POST: {end-start}s")
            if res.status_code != 200:
-                print("Error:", res.text)
+                logger.error("Error:%s", res.text)
                return
            first = True
            for chunk in res.iter_content(chunk_size=None): #12800 1280 32K*20ms*2
-                print('chunk len:',len(chunk))
+                logger.info('chunk len:%d',len(chunk))
                if first:
                    end = time.perf_counter()
-                    print(f"gpt_sovits Time to first chunk: {end-start}s")
+                    logger.info(f"gpt_sovits Time to first chunk: {end-start}s")
                    first = False
                if chunk and self.state==State.RUNNING:
                    yield chunk
            #print("gpt_sovits response.elapsed:", res.elapsed)
        except Exception as e:
-            print(e)
+            logger.exception('sovits')
     def __create_bytes_stream(self,byte_stream):
         #byte_stream=BytesIO(buffer)
         stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
-        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
+        logger.info(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
         stream = stream.astype(np.float32)
         if stream.ndim > 1:
-            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
+            logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
             stream = stream[:, 0]
         if sample_rate != self.sample_rate and stream.shape[0]>0:
-            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
+            logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
             stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
         return stream
@@ -338,10 +343,10 @@ class CosyVoiceTTS(BaseTTS):
            res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True)
            end = time.perf_counter()
-            print(f"cosy_voice Time to make POST: {end-start}s")
+            logger.info(f"cosy_voice Time to make POST: {end-start}s")
            if res.status_code != 200:
-                print("Error:", res.text)
+                logger.error("Error:%s", res.text)
                return
            first = True
@@ -349,12 +354,12 @@ class CosyVoiceTTS(BaseTTS):
            for chunk in res.iter_content(chunk_size=8820): # 882 22.05K*20ms*2
                if first:
                    end = time.perf_counter()
-                    print(f"cosy_voice Time to first chunk: {end-start}s")
+                    logger.info(f"cosy_voice Time to first chunk: {end-start}s")
                    first = False
                if chunk and self.state==State.RUNNING:
                    yield chunk
        except Exception as e:
-            print(e)
+            logger.exception('cosyvoice')
     def stream_tts(self,audio_stream,msg):
         text,textevent = msg
@@ -414,7 +419,7 @@ class XTTS(BaseTTS):
                stream=True,
            )
            end = time.perf_counter()
-            print(f"xtts Time to make POST: {end-start}s")
+            logger.info(f"xtts Time to make POST: {end-start}s")
            if res.status_code != 200:
                print("Error:", res.text)
@@ -425,7 +430,7 @@ class XTTS(BaseTTS):
            for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2
                if first:
                    end = time.perf_counter()
-                    print(f"xtts Time to first chunk: {end-start}s")
+                    logger.info(f"xtts Time to first chunk: {end-start}s")
                    first = False
                if chunk:
                    yield chunk

@@ -40,8 +40,9 @@ from aiortc import (
     MediaStreamTrack,
 )
-logging.basicConfig()
-logger = logging.getLogger(__name__)
+#logging.basicConfig()
+#logger = logging.getLogger(__name__)
+from logger import logger
 class PlayerStreamTrack(MediaStreamTrack):
@@ -82,7 +83,7 @@ class PlayerStreamTrack(MediaStreamTrack):
                 self._start = time.time()
                 self._timestamp = 0
                 self.timelist.append(self._start)
-                print('video start:',self._start)
+                logger.info('video start:%f',self._start)
             return self._timestamp, VIDEO_TIME_BASE
         else: #audio
             if hasattr(self, "_timestamp"):
@@ -100,7 +101,7 @@ class PlayerStreamTrack(MediaStreamTrack):
                 self._start = time.time()
                 self._timestamp = 0
                 self.timelist.append(self._start)
-                print('audio start:',self._start)
+                logger.info('audio start:%f',self._start)
             return self._timestamp, AUDIO_TIME_BASE
     async def recv(self) -> Union[Frame, Packet]:
@@ -136,7 +137,7 @@ class PlayerStreamTrack(MediaStreamTrack):
             self.framecount += 1
             self.lasttime = time.perf_counter()
             if self.framecount==100:
-                print(f"------actual avg final fps:{self.framecount/self.totaltime:.4f}")
+                logger.info(f"------actual avg final fps:{self.framecount/self.totaltime:.4f}")
                 self.framecount = 0
                 self.totaltime=0
             return frame
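
Note (not part of the commit): some of the converted calls use logging's lazy %-style arguments (e.g. `logger.info('video start:%f',self._start)`) while others keep f-strings; both work, the %-style just defers formatting until the record is actually emitted. A small sketch:

```python
# Both call styles used in this commit, side by side.
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("demo")

start = 12.34
log.info('video start:%f', start)       # lazy: formatted only if the record is emitted
log.info(f'video start:{start:.4f}')    # eager: the f-string is built before the call
```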
