Implement process_frames in the parent class BaseReal and add virtual camera output

lipku · branch main · parent 8ab089f3a0 · commit 9e5a56678b

README.md
@@ -13,27 +13,28 @@
 - 2025.3.2 Added the Tencent text-to-speech service
 - 2025.3.16 Added Mac GPU inference support, thanks to [@GcsSloop](https://github.com/GcsSloop)
 - 2025.5.1 Simplified run parameters; the ernerf model moved to the git branch ernerf-rtmp
+- 2025.6.7 Added virtual camera output
 ## Features
 1. Supports multiple digital-human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
 2. Supports voice cloning
 3. Supports interrupting the digital human while it is speaking
 4. Supports full-body video compositing
-5. Supports rtmp and webrtc
+5. Supports webrtc and virtual camera output
-6. Supports video orchestration: plays a custom video when not speaking
+6. Supports action orchestration: plays a custom video when not speaking
 7. Supports multiple concurrent sessions
 ## 1. Installation
-Tested on Ubuntu 20.04, Python3.10, Pytorch 1.12 and CUDA 11.3
+Tested on Ubuntu 24.04, Python3.10, Pytorch 2.5.0 and CUDA 12.4
 ### 1.1 Install dependency
 ```bash
 conda create -n nerfstream python=3.10
 conda activate nerfstream
-#If your CUDA version is not 11.3 (run nvidia-smi to check), install the matching PyTorch build from <https://pytorch.org/get-started/previous-versions/>
-conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
+#If your CUDA version is not 12.4 (run nvidia-smi to check), install the matching PyTorch build from <https://pytorch.org/get-started/previous-versions/>
+conda install pytorch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 pytorch-cuda=12.4 -c pytorch -c nvidia
 pip install -r requirements.txt
 #Install the libraries below only if you need to train the ernerf model
 # pip install "git+https://github.com/facebookresearch/pytorch3d.git"
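A quick sanity check (not part of this commit) to confirm the installed PyTorch build matches the CUDA toolkit the updated README expects; the printed values are assumptions based on the install line above:

```python
# Verify the PyTorch / CUDA pairing after the conda install above.
import torch

print(torch.__version__)          # expected to start with 2.5.0
print(torch.version.cuda)         # expected 12.4 on a CUDA build
print(torch.cuda.is_available())  # True on a working NVIDIA setup; False on CPU or Mac
```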

app.py
@@ -284,7 +284,7 @@ if __name__ == '__main__':
     parser.add_argument('--model', type=str, default='musetalk') #musetalk wav2lip ultralight
-    parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
+    parser.add_argument('--transport', type=str, default='rtcpush') #webrtc rtcpush virtualcam
     parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
     parser.add_argument('--max_session', type=int, default=1) #multi session count
@@ -326,6 +326,11 @@ if __name__ == '__main__':
     # nerfreals[0] = build_nerfreal(0)
     # rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
     # rendthrd.start()
+    if opt.transport=='virtualcam':
+        thread_quit = Event()
+        nerfreals[0] = build_nerfreal(0)
+        rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
+        rendthrd.start()
     #############################################################################
     appasync = web.Application(client_max_size=1024**2*100)
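With `--transport virtualcam`, the render thread is started up front instead of waiting for a WebRTC session, and frames go to a virtual camera device. A minimal consumer-side check (not part of this commit), assuming a virtual-camera backend such as OBS Virtual Camera or v4l2loopback is installed and shows up as capture device 0:

```python
# Open the virtual camera like an ordinary webcam and grab one frame.
# The device index is machine-dependent; adjust 0 as needed.
import cv2

cap = cv2.VideoCapture(0)
ok, frame = cap.read()
print("got frame:", ok, frame.shape if ok else None)
cap.release()
```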

basereal.py
@@ -32,6 +32,9 @@ from threading import Thread, Event
 from io import BytesIO
 import soundfile as sf
+import asyncio
+from av import AudioFrame, VideoFrame
 import av
 from fractions import Fraction
@@ -47,6 +50,23 @@ def read_imgs(img_list):
         frames.append(frame)
     return frames
 
+def play_audio(quit_event,queue):
+    import pyaudio
+    p = pyaudio.PyAudio()
+    stream = p.open(
+        rate=16000,
+        channels=1,
+        format=8,  # pyaudio.paInt16
+        output=True,
+        output_device_index=1,
+    )
+    stream.start_stream()
+    # while queue.qsize() <= 0:
+    #     time.sleep(0.1)
+    while not quit_event.is_set():
+        stream.write(queue.get(block=True))
+    stream.close()
+
 class BaseReal:
     def __init__(self, opt):
         self.opt = opt
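`play_audio` opens a 16 kHz mono PyAudio output stream (format 8 is `pyaudio.paInt16`) and plays whatever byte chunks arrive on the queue, but the hard-coded `output_device_index=1` is machine-specific. A small sketch (not part of this commit) for listing playback devices and picking a suitable index:

```python
# Enumerate PyAudio output devices to find the right output_device_index.
import pyaudio

p = pyaudio.PyAudio()
for i in range(p.get_device_count()):
    info = p.get_device_info_by_index(i)
    if info.get("maxOutputChannels", 0) > 0:   # keep only playback-capable devices
        print(i, info["name"])
p.terminate()
```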
@@ -269,6 +289,109 @@ class BaseReal:
             self.custom_audio_index[audiotype] = 0
             self.custom_index[audiotype] = 0
 
+    def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
+        enable_transition = False  # set to True to enable the speaking/silent transition effect
+        if enable_transition:
+            _last_speaking = False
+            _transition_start = time.time()
+            _transition_duration = 0.1  # transition duration in seconds
+            _last_silent_frame = None   # cached silent frame
+            _last_speaking_frame = None # cached speaking frame
+
+        if self.opt.transport=='virtualcam':
+            import pyvirtualcam
+            vircam = None
+
+            audio_tmp = queue.Queue(maxsize=3000)
+            audio_thread = Thread(target=play_audio, args=(quit_event,audio_tmp,), daemon=True, name="pyaudio_stream")
+            audio_thread.start()
+
+        while not quit_event.is_set():
+            try:
+                res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
+            except queue.Empty:
+                continue
+
+            if enable_transition:
+                # detect a speaking/silent state change
+                current_speaking = not (audio_frames[0][1]!=0 and audio_frames[1][1]!=0)
+                if current_speaking != _last_speaking:
+                    logger.info(f"state change: {'speaking' if _last_speaking else 'silent'} → {'speaking' if current_speaking else 'silent'}")
+                    _transition_start = time.time()
+                _last_speaking = current_speaking
+
+            if audio_frames[0][1]!=0 and audio_frames[1][1]!=0:  # both chunks are silent, just use the full image
+                self.speaking = False
+                audiotype = audio_frames[0][1]
+                if self.custom_index.get(audiotype) is not None:  # a custom video is configured
+                    mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
+                    target_frame = self.custom_img_cycle[audiotype][mirindex]
+                    self.custom_index[audiotype] += 1
+                else:
+                    target_frame = self.frame_list_cycle[idx]
+
+                if enable_transition:
+                    # speaking → silent transition
+                    if time.time() - _transition_start < _transition_duration and _last_speaking_frame is not None:
+                        alpha = min(1.0, (time.time() - _transition_start) / _transition_duration)
+                        combine_frame = cv2.addWeighted(_last_speaking_frame, 1-alpha, target_frame, alpha, 0)
+                    else:
+                        combine_frame = target_frame
+                    # cache the silent frame
+                    _last_silent_frame = combine_frame.copy()
+                else:
+                    combine_frame = target_frame
+            else:
+                self.speaking = True
+                try:
+                    current_frame = self.paste_back_frame(res_frame,idx)
+                except Exception as e:
+                    logger.warning(f"paste_back_frame error: {e}")
+                    continue
+                if enable_transition:
+                    # silent → speaking transition
+                    if time.time() - _transition_start < _transition_duration and _last_silent_frame is not None:
+                        alpha = min(1.0, (time.time() - _transition_start) / _transition_duration)
+                        combine_frame = cv2.addWeighted(_last_silent_frame, 1-alpha, current_frame, alpha, 0)
+                    else:
+                        combine_frame = current_frame
+                    # cache the speaking frame
+                    _last_speaking_frame = combine_frame.copy()
+                else:
+                    combine_frame = current_frame
+
+            if self.opt.transport=='virtualcam':
+                if vircam==None:
+                    height, width,_= combine_frame.shape
+                    vircam = pyvirtualcam.Camera(width=width, height=height, fps=25, fmt=pyvirtualcam.PixelFormat.BGR,print_fps=True)
+                vircam.send(combine_frame)
+            else: #webrtc
+                image = combine_frame
+                image[0,:] &= 0xFE
+                new_frame = VideoFrame.from_ndarray(image, format="bgr24")
+                asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
+            self.record_video_data(combine_frame)
+
+            for audio_frame in audio_frames:
+                frame,type,eventpoint = audio_frame
+                frame = (frame * 32767).astype(np.int16)
+                if self.opt.transport=='virtualcam':
+                    audio_tmp.put(frame.tobytes()) #TODO
+                else: #webrtc
+                    new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
+                    new_frame.planes[0].update(frame.tobytes())
+                    new_frame.sample_rate=16000
+                    asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
+                self.record_audio_data(frame)
+            if self.opt.transport=='virtualcam':
+                vircam.sleep_until_next_frame()
+        if self.opt.transport=='virtualcam':
+            audio_thread.join()
+            vircam.close()
+        logger.info('basereal process_frames thread stop')
+
     # def process_custom(self,audiotype:int,idx:int):
     #     if self.curr_state!=audiotype:  # switching from inference to scripted speech
     #         if idx in self.switch_pos:  # switching is only allowed at a cut point
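The shared loop above calls `self.paste_back_frame(res_frame, idx)`, so each model class now only has to composite its predicted crop back onto the cached full frame and return a BGR image. A schematic sketch of that contract (the subclass name and the fixed crop box below are illustrative, not from this commit):

```python
import numpy as np

class DemoReal(BaseReal):  # hypothetical subclass, for illustration only
    def paste_back_frame(self, pred_frame, idx: int):
        # Start from the cached full-body frame for this index...
        combine_frame = self.frame_list_cycle[idx].copy()
        # ...and paste the model's predicted face crop into an (illustrative) box.
        y1, y2, x1, x2 = 100, 260, 80, 240
        combine_frame[y1:y2, x1:x2] = pred_frame.astype(np.uint8)
        return combine_frame  # BGR uint8, consumed by BaseReal.process_frames
```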

lightreal.py
@@ -248,28 +248,7 @@ class LightReal(BaseReal):
     def __del__(self):
         logger.info(f'lightreal({self.sessionid}) delete')
 
+    def paste_back_frame(self,pred_frame,idx:int):
-    def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
-        while not quit_event.is_set():
-            try:
-                res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
-            except queue.Empty:
-                continue
-            if audio_frames[0][1]!=0 and audio_frames[1][1]!=0:  # both chunks are silent, just use the full image
-                self.speaking = False
-                audiotype = audio_frames[0][1]
-                if self.custom_index.get(audiotype) is not None:  # a custom video is configured
-                    mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
-                    combine_frame = self.custom_img_cycle[audiotype][mirindex]
-                    self.custom_index[audiotype] += 1
-                    # if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):
-                    #     self.curr_state = 1  # the current video does not loop; switch back to the silent state
-                else:
-                    combine_frame = self.frame_list_cycle[idx]
-                    #combine_frame = self.imagecache.get_img(idx)
-            else:
-                self.speaking = True
         bbox = self.coord_list_cycle[idx]
         combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
         x1, y1, x2, y2 = bbox
@@ -277,31 +256,11 @@ class LightReal(BaseReal):
         crop_img = self.face_list_cycle[idx]
         crop_img_ori = crop_img.copy()
         #res_frame = np.array(res_frame, dtype=np.uint8)
-        try:
-            crop_img_ori[4:164, 4:164] = res_frame.astype(np.uint8)
+        crop_img_ori[4:164, 4:164] = pred_frame.astype(np.uint8)
         crop_img_ori = cv2.resize(crop_img_ori, (x2-x1,y2-y1))
-        except:
-            continue
         combine_frame[y1:y2, x1:x2] = crop_img_ori
-        #print('blending time:',time.perf_counter()-t)
+        return combine_frame
-            combine_frame[0,:] &= 0xFE
-            new_frame = VideoFrame.from_ndarray(combine_frame, format="bgr24")
-            asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
-            self.record_video_data(combine_frame)
-            for audio_frame in audio_frames:
-                frame,type_,eventpoint = audio_frame
-                frame = (frame * 32767).astype(np.int16)
-                new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
-                new_frame.planes[0].update(frame.tobytes())
-                new_frame.sample_rate=16000
-                # if audio_track._queue.qsize()>10:
-                #     time.sleep(0.1)
-                asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
-                self.record_audio_data(frame)
-                #self.notify(eventpoint)
-        logger.info('lightreal process_frames thread stop')
 
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -329,7 +288,7 @@ class LightReal(BaseReal):
             # if video_track._queue.qsize()>=2*self.opt.batch_size:
             #     print('sleep qsize=',video_track._queue.qsize())
             #     time.sleep(0.04*video_track._queue.qsize()*0.8)
-            if video_track._queue.qsize()>=5:
+            if video_track and video_track._queue.qsize()>=5:
                 logger.debug('sleep qsize=%d',video_track._queue.qsize())
                 time.sleep(0.04*video_track._queue.qsize()*0.8)
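The new `if video_track and ...` guard (here and in the classes below) is needed because in virtualcam mode no WebRTC video track exists; when one does exist, the sleep throttles rendering to roughly 80% of the queued playback time at 25 fps. A small illustration (the queue length is just an example, not from this commit):

```python
# Each queued frame corresponds to 0.04 s at 25 fps.
qsize = 8                      # example backlog
sleep_s = 0.04 * qsize * 0.8   # back off by ~80% of the queued playback time
print(f"sleep {sleep_s:.2f}s for {qsize} queued frames")  # sleep 0.26s for 8 queued frames
```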

lipreal.py
@@ -206,59 +206,16 @@ class LipReal(BaseReal):
     def __del__(self):
         logger.info(f'lipreal({self.sessionid}) delete')
 
+    def paste_back_frame(self,pred_frame,idx:int):
-    def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
-        while not quit_event.is_set():
-            try:
-                res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
-            except queue.Empty:
-                continue
-            if audio_frames[0][1]!=0 and audio_frames[1][1]!=0:  # both chunks are silent, just use the full image
-                self.speaking = False
-                audiotype = audio_frames[0][1]
-                if self.custom_index.get(audiotype) is not None:  # a custom video is configured
-                    mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
-                    combine_frame = self.custom_img_cycle[audiotype][mirindex]
-                    self.custom_index[audiotype] += 1
-                    # if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):
-                    #     self.curr_state = 1  # the current video does not loop; switch back to the silent state
-                else:
-                    combine_frame = self.frame_list_cycle[idx]
-                    #combine_frame = self.imagecache.get_img(idx)
-            else:
-                self.speaking = True
         bbox = self.coord_list_cycle[idx]
         combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
         #combine_frame = copy.deepcopy(self.imagecache.get_img(idx))
         y1, y2, x1, x2 = bbox
-        try:
-            res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
+        res_frame = cv2.resize(pred_frame.astype(np.uint8),(x2-x1,y2-y1))
-        except:
-            continue
         #combine_frame = get_image(ori_frame,res_frame,bbox)
         #t=time.perf_counter()
         combine_frame[y1:y2, x1:x2] = res_frame
-        #print('blending time:',time.perf_counter()-t)
+        return combine_frame
-            image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
-            image[0,:] &= 0xFE
-            new_frame = VideoFrame.from_ndarray(image, format="bgr24")
-            asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
-            self.record_video_data(image)
-            for audio_frame in audio_frames:
-                frame,type,eventpoint = audio_frame
-                frame = (frame * 32767).astype(np.int16)
-                new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
-                new_frame.planes[0].update(frame.tobytes())
-                new_frame.sample_rate=16000
-                # if audio_track._queue.qsize()>10:
-                #     time.sleep(0.1)
-                asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
-                self.record_audio_data(frame)
-                #self.notify(eventpoint)
-        logger.info('lipreal process_frames thread stop')
 
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -287,7 +244,7 @@ class LipReal(BaseReal):
             # if video_track._queue.qsize()>=2*self.opt.batch_size:
             #     print('sleep qsize=',video_track._queue.qsize())
             #     time.sleep(0.04*video_track._queue.qsize()*0.8)
-            if video_track._queue.qsize()>=5:
+            if video_track and video_track._queue.qsize()>=5:
                 logger.debug('sleep qsize=%d',video_track._queue.qsize())
                 time.sleep(0.04*video_track._queue.qsize()*0.8)

musereal.py
@@ -266,92 +266,17 @@ class MuseReal(BaseReal):
         recon = self.vae.decode_latents(pred_latents)
 
+    def paste_back_frame(self,pred_frame,idx:int):
-    def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
-        enable_transition = True  # set to False to disable the transition effect, True to enable it
-        if enable_transition:
-            self.last_speaking = False
-            self.transition_start = time.time()
-            self.transition_duration = 0.1  # transition duration in seconds
-            self.last_silent_frame = None   # cached silent frame
-            self.last_speaking_frame = None # cached speaking frame
-        while not quit_event.is_set():
-            try:
-                res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
-            except queue.Empty:
-                continue
-            if enable_transition:
-                # detect a speaking/silent state change
-                current_speaking = not (audio_frames[0][1]!=0 and audio_frames[1][1]!=0)
-                if current_speaking != self.last_speaking:
-                    logger.info(f"state change: {'speaking' if self.last_speaking else 'silent'} → {'speaking' if current_speaking else 'silent'}")
-                    self.transition_start = time.time()
-                self.last_speaking = current_speaking
-            if audio_frames[0][1]!=0 and audio_frames[1][1]!=0:
-                self.speaking = False
-                audiotype = audio_frames[0][1]
-                if self.custom_index.get(audiotype) is not None:
-                    mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
-                    target_frame = self.custom_img_cycle[audiotype][mirindex]
-                    self.custom_index[audiotype] += 1
-                else:
-                    target_frame = self.frame_list_cycle[idx]
-                if enable_transition:
-                    # speaking → silent transition
-                    if time.time() - self.transition_start < self.transition_duration and self.last_speaking_frame is not None:
-                        alpha = min(1.0, (time.time() - self.transition_start) / self.transition_duration)
-                        combine_frame = cv2.addWeighted(self.last_speaking_frame, 1-alpha, target_frame, alpha, 0)
-                    else:
-                        combine_frame = target_frame
-                    # cache the silent frame
-                    self.last_silent_frame = combine_frame.copy()
-                else:
-                    combine_frame = target_frame
-            else:
-                self.speaking = True
         bbox = self.coord_list_cycle[idx]
         ori_frame = copy.deepcopy(self.frame_list_cycle[idx])
         x1, y1, x2, y2 = bbox
-        try:
-            res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
+        res_frame = cv2.resize(pred_frame.astype(np.uint8),(x2-x1,y2-y1))
-        except Exception as e:
-            logger.warning(f"resize error: {e}")
-            continue
         mask = self.mask_list_cycle[idx]
         mask_crop_box = self.mask_coords_list_cycle[idx]
-        current_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
+        combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
-        if enable_transition:
+        return combine_frame
-            # silent → speaking transition
-            if time.time() - self.transition_start < self.transition_duration and self.last_silent_frame is not None:
-                alpha = min(1.0, (time.time() - self.transition_start) / self.transition_duration)
-                combine_frame = cv2.addWeighted(self.last_silent_frame, 1-alpha, current_frame, alpha, 0)
-            else:
-                combine_frame = current_frame
-            # cache the speaking frame
-            self.last_speaking_frame = combine_frame.copy()
-        else:
-            combine_frame = current_frame
-            image = combine_frame
-            image[0,:] &= 0xFE
-            new_frame = VideoFrame.from_ndarray(image, format="bgr24")
-            asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
-            self.record_video_data(image)
-            for audio_frame in audio_frames:
-                frame,type,eventpoint = audio_frame
-                frame = (frame * 32767).astype(np.int16)
-                new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
-                new_frame.planes[0].update(frame.tobytes())
-                new_frame.sample_rate=16000
-                asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
-                self.record_audio_data(frame)
-        logger.info('musereal process_frames thread stop')
 
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -382,7 +307,7 @@ class MuseReal(BaseReal):
            #     print(f"------actual avg infer fps:{count/totaltime:.4f}")
            #     count=0
            #     totaltime=0
-            if video_track._queue.qsize()>=1.5*self.opt.batch_size:
+            if video_track and video_track._queue.qsize()>=1.5*self.opt.batch_size:
                 logger.debug('sleep qsize=%d',video_track._queue.qsize())
                 time.sleep(0.04*video_track._queue.qsize()*0.8)
            # if video_track._queue.qsize()>=5:
