You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

384 lines
13 KiB
Python

###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
2 years ago
# server.py
1 year ago
from flask import Flask, render_template,send_from_directory,request, jsonify
2 years ago
from flask_sockets import Sockets
import base64
import json
#import gevent
#from gevent import pywsgi
#from geventwebsocket.handler import WebSocketHandler
2 years ago
import re
import numpy as np
1 year ago
from threading import Thread,Event
#import multiprocessing
import torch.multiprocessing as mp
2 years ago
1 year ago
from aiohttp import web
import aiohttp
import aiohttp_cors
1 year ago
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.rtcrtpsender import RTCRtpSender
1 year ago
from webrtc import HumanPlayer
from basereal import BaseReal
from llm import llm_response
1 year ago
2 years ago
import argparse
import random
2 years ago
import shutil
import asyncio
import torch
from typing import Dict
from logger import logger
2 years ago
# Flask application object; the active aiohttp server code in this file does
# not route through it (presumably kept for compatibility — verify before removing).
app = Flask(__name__)

# sessionid -> renderer. A slot may briefly hold None while its renderer is
# still being constructed by the /offer handler.
nerfreals: Dict[int, BaseReal] = {}

# Populated in the __main__ block: parsed CLI options, loaded model, avatar data.
opt = model = avatar = None

# webrtc: every RTCPeerConnection currently tracked; closed by on_shutdown.
pcs = set()
def randN(N) -> int:
    """Return a random integer with exactly ``N`` decimal digits.

    For example ``randN(6)`` yields a value in ``100000..999999``.

    Args:
        N: number of digits; must be at least 1.

    Raises:
        ValueError: if ``N`` is less than 1 (the range would be empty or
            fractional).
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N!r}")
    # Local names chosen so the builtins min/max are not shadowed.
    lowest = 10 ** (N - 1)
    highest = 10 ** N - 1
    return random.randint(lowest, highest)
def build_nerfreal(sessionid:int)->BaseReal:
    """Construct the renderer for one session, chosen by ``opt.model``.

    Uses the module-level ``opt``, ``model`` and ``avatar`` set up at startup.
    ``sessionid`` is written onto the shared ``opt`` namespace before the
    renderer is constructed (presumably the renderer reads it from there).

    Args:
        sessionid: id of the session being created.

    Returns:
        A ``LipReal`` ('wav2lip'), ``MuseReal`` ('musetalk') or
        ``LightReal`` ('ultralight') instance.

    Raises:
        ValueError: if ``opt.model`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    opt.sessionid = sessionid
    if opt.model == 'wav2lip':
        from lipreal import LipReal
        return LipReal(opt, model, avatar)
    if opt.model == 'musetalk':
        from musereal import MuseReal
        return MuseReal(opt, model, avatar)
    if opt.model == 'ultralight':
        from lightreal import LightReal
        return LightReal(opt, model, avatar)
    # 'ernerf' (nerfreal.NeRFReal) support is currently disabled.
    raise ValueError(f"unsupported model: {opt.model!r}")
1 year ago
#@app.route('/offer', methods=['POST'])
async def offer(request):
    """Handle a WebRTC offer: create a new session and answer it.

    Expects a JSON body with ``sdp`` and ``type``. Builds a renderer for a new
    random session id, attaches its audio and video tracks to a new peer
    connection (video codec preference: H264, then VP8, then rtx) and replies
    with ``{"sdp", "type", "sessionid"}``.

    When ``opt.max_session`` sessions already exist, replies with
    ``{"code": -1, "msg": "reach max session"}`` instead.
    """
    params = await request.json()
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    if len(nerfreals) >= opt.max_session:
        logger.info('reach max session')
        # aiohttp handlers must return a Response; the former `return -1`
        # surfaced to the client as a 500 Internal Server Error.
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": -1, "msg": "reach max session"}),
        )

    # Re-draw on collision so an existing session is never overwritten.
    sessionid = randN(6)
    while sessionid in nerfreals:
        sessionid = randN(6)
    logger.info('sessionid=%d', sessionid)

    # Reserve the slot before the slow build so concurrent offers count it
    # against max_session; release it again if the build fails.
    nerfreals[sessionid] = None
    try:
        nerfreal = await asyncio.get_running_loop().run_in_executor(
            None, build_nerfreal, sessionid)
    except Exception:
        nerfreals.pop(sessionid, None)
        raise
    nerfreals[sessionid] = nerfreal

    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s", pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
        # After close() the state reads "closed", and a separate "closed"
        # event may also arrive; pop()/discard() keep cleanup idempotent
        # (the old double `del` raised KeyError).
        if pc.connectionState in ("failed", "closed"):
            pcs.discard(pc)
            nerfreals.pop(sessionid, None)

    try:
        player = HumanPlayer(nerfreal)
        pc.addTrack(player.audio)
        pc.addTrack(player.video)

        capabilities = RTCRtpSender.getCapabilities("video")
        preferences = [codec
                       for wanted in ("H264", "VP8", "rtx")
                       for codec in capabilities.codecs
                       if codec.name == wanted]
        # Transceiver 1 belongs to the video track (added second above).
        transceiver = pc.getTransceivers()[1]
        transceiver.setCodecPreferences(preferences)

        await pc.setRemoteDescription(offer)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
    except Exception:
        # Negotiation failed: free the session slot instead of leaking it.
        await pc.close()
        pcs.discard(pc)
        nerfreals.pop(sessionid, None)
        raise

    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid": sessionid}
        ),
    )
async def human(request):
    """Make a session speak text, either verbatim or via the LLM.

    JSON body:
        sessionid: target session id (default 0); coerced to int.
        type: 'echo' passes ``text`` to the session's ``put_msg_txt``;
            'chat' runs ``llm_response(text, session)`` in an executor.
        text: the text to speak / send to the LLM.
        interrupt: if truthy, the session's ``flush_talk()`` is called first
            (presumably discarding speech already queued).

    Returns:
        JSON ``{"code": 0, "data": "ok"}``, or ``{"code": -1, "msg": ...}``
        when the session does not exist or is still being constructed.
    """
    params = await request.json()
    # nerfreals is keyed by int; coerce as humanaudio does for form input.
    sessionid = int(params.get('sessionid', 0))
    nerfreal = nerfreals.get(sessionid)
    if nerfreal is None:
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": -1, "msg": "session not found"}),
        )
    if params.get('interrupt'):
        nerfreal.flush_talk()
    if params['type'] == 'echo':
        nerfreal.put_msg_txt(params['text'])
    elif params['type'] == 'chat':
        # Run in an executor to keep the event loop free (llm_response
        # presumably blocks); its return value is not used here.
        await asyncio.get_running_loop().run_in_executor(
            None, llm_response, params['text'], nerfreal)
    return web.Response(
        content_type="application/json",
        text=json.dumps({"code": 0, "data": "ok"}),
    )
1 year ago
10 months ago
async def humanaudio(request):
    """Feed an uploaded audio file to a session.

    Expects a multipart form with ``sessionid`` (default 0) and ``file``; the
    file's bytes are handed to the session's ``put_audio_file``.

    Returns:
        JSON ``{"code": 0, "msg": "ok"}`` on success, or
        ``{"code": -1, "msg": "err", "data": <error text>}`` on any failure.
    """
    try:
        form = await request.post()
        sessionid = int(form.get('sessionid', 0))
        fileobj = form["file"]
        filebytes = fileobj.file.read()
        nerfreals[sessionid].put_audio_file(filebytes)
    except Exception as e:
        logger.error('humanaudio failed: %r', e)
        # str(e) works for every exception; the former ""+e.args[0]+"" raised
        # TypeError/IndexError itself (e.g. KeyError(0) for an unknown id),
        # turning this error reply into a 500.
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": -1, "msg": "err", "data": str(e)}),
        )
    return web.Response(
        content_type="application/json",
        text=json.dumps({"code": 0, "msg": "ok"}),
    )
async def set_audiotype(request):
    """Forward ``audiotype`` and ``reinit`` to a session's ``set_custom_state``.

    JSON body: ``sessionid`` (default 0), ``audiotype``, ``reinit``.
    On success replies ``{"code": 0, "data": "ok"}``.
    """
    body = await request.json()
    session = nerfreals[body.get('sessionid', 0)]
    session.set_custom_state(body['audiotype'], body['reinit'])
    reply = {"code": 0, "data": "ok"}
    return web.Response(content_type="application/json", text=json.dumps(reply))
11 months ago
async def record(request):
    """Start or stop recording for a session.

    JSON body: ``sessionid`` (default 0) and ``type``. 'start_record' calls
    the session's ``start_recording()``, 'end_record' calls
    ``stop_recording()``, any other value does nothing.
    Replies ``{"code": 0, "data": "ok"}``.
    """
    body = await request.json()
    sessionid = body.get('sessionid', 0)
    action = body['type']
    if action == 'start_record':
        nerfreals[sessionid].start_recording()
    elif action == 'end_record':
        nerfreals[sessionid].stop_recording()
    reply = {"code": 0, "data": "ok"}
    return web.Response(content_type="application/json", text=json.dumps(reply))
10 months ago
async def is_speaking(request):
    """Report a session's ``is_speaking()`` result.

    JSON body: ``sessionid`` (default 0).
    Replies ``{"code": 0, "data": <value returned by is_speaking()>}``.
    """
    body = await request.json()
    speaking = nerfreals[body.get('sessionid', 0)].is_speaking()
    reply = {"code": 0, "data": speaking}
    return web.Response(content_type="application/json", text=json.dumps(reply))
1 year ago
async def on_shutdown(app):
    """aiohttp shutdown hook: close all tracked peer connections, then forget them."""
    await asyncio.gather(*(peer.close() for peer in pcs))
    pcs.clear()
async def post(url,data):
    """POST ``data`` to ``url`` and return the response body as text.

    Returns:
        The response text for a 2xx reply, or ``None`` (after logging) when
        the request fails or the server answers with a non-2xx status.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                # Without this an HTTP error page would be handed back to the
                # caller as if it were a valid body (e.g. an SDP answer).
                response.raise_for_status()
                return await response.text()
    except aiohttp.ClientError as e:
        logger.error('POST %s failed: %s', url, e)
        return None
async def run(push_url,sessionid):
    """Create session ``sessionid`` and push its media to ``push_url``.

    Builds the renderer, adds its audio/video tracks to a new peer
    connection, POSTs the local offer SDP to ``push_url`` and applies the
    returned text as the SDP answer.

    Raises:
        RuntimeError: if the push server returned no usable answer.
    """
    nerfreal = await asyncio.get_running_loop().run_in_executor(
        None, build_nerfreal, sessionid)
    nerfreals[sessionid] = nerfreal

    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s", pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreal)
    pc.addTrack(player.audio)
    pc.addTrack(player.video)

    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url, pc.localDescription.sdp)
    if answer is None:
        # post() already logged the cause; fail clearly instead of letting
        # aiortc choke on a None SDP, and don't leak the peer connection.
        await pc.close()
        pcs.discard(pc)
        raise RuntimeError(f'no SDP answer from push server {push_url}')
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
##########################################
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
2 years ago
if __name__ == '__main__':
    mp.set_start_method('spawn')

    parser = argparse.ArgumentParser()
    # audio FPS
    parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=8)
    parser.add_argument('-r', type=int, default=10)
    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")

    # musetalk options
    parser.add_argument('--avatar_id', type=str, default='avator_1', help="define which avatar in data/avatars")
    parser.add_argument('--batch_size', type=int, default=16, help="infer batch")

    parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")
    parser.add_argument('--tts', type=str, default='edgetts', help="tts service type")  # xtts gpt-sovits cosyvoice
    parser.add_argument('--REF_FILE', type=str, default="zh-CN-YunxiaNeural")
    parser.add_argument('--REF_TEXT', type=str, default=None)
    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880')  # http://localhost:9000
    # choices= rejects unknown values up front; previously an unknown model
    # left model/avatar as None and only failed when a session was built.
    parser.add_argument('--model', type=str, default='musetalk',
                        choices=['musetalk', 'wav2lip', 'ultralight'])
    parser.add_argument('--transport', type=str, default='rtcpush',
                        choices=['rtmp', 'webrtc', 'rtcpush'])
    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream')  # rtmp://localhost/live/livestream
    parser.add_argument('--max_session', type=int, default=1)  # multi session count
    parser.add_argument('--listenport', type=int, default=8010, help="web listen port")

    opt = parser.parse_args()

    opt.customopt = []
    if opt.customvideo_config != '':
        # Explicit UTF-8: the platform default encoding varies and the JSON
        # may contain non-ASCII text.
        with open(opt.customvideo_config, 'r', encoding='utf-8') as file:
            opt.customopt = json.load(file)

    # Load the shared model and avatar once; every session reuses them.
    logger.info(opt)
    if opt.model == 'musetalk':
        from musereal import MuseReal, load_model, load_avatar, warm_up
        model = load_model()
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size, model)
    elif opt.model == 'wav2lip':
        from lipreal import LipReal, load_model, load_avatar, warm_up
        model = load_model("./models/wav2lip.pth")
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size, model, 256)
    elif opt.model == 'ultralight':
        from lightreal import LightReal, load_model, load_avatar, warm_up
        model = load_model(opt)
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size, avatar, 160)

    #############################################################################
    appasync = web.Application()
    appasync.on_shutdown.append(on_shutdown)
    appasync.router.add_post("/offer", offer)
    appasync.router.add_post("/human", human)
    appasync.router.add_post("/humanaudio", humanaudio)
    appasync.router.add_post("/set_audiotype", set_audiotype)
    appasync.router.add_post("/record", record)
    appasync.router.add_post("/is_speaking", is_speaking)
    appasync.router.add_static('/', path='web')

    # Configure default CORS settings and apply them to every route.
    cors = aiohttp_cors.setup(appasync, defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        )
    })
    for route in list(appasync.router.routes()):
        cors.add(route)

    # Demo page to advertise for the selected transport.
    pagename = {'rtmp': 'echoapi.html', 'rtcpush': 'rtcpushapi.html'}.get(opt.transport, 'webrtcapi.html')
    logger.info(f'start http server; http://<serverip>:{opt.listenport}/{pagename}')
    logger.info(f'如果使用webrtc推荐访问webrtc集成前端: http://<serverip>:{opt.listenport}/dashboard.html')

    def run_server(runner):
        """Serve on 0.0.0.0:listenport; for rtcpush, start push sessions first."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
        loop.run_until_complete(site.start())
        if opt.transport == 'rtcpush':
            # Session k pushes to push_url with k appended; session 0 uses it as-is.
            for k in range(opt.max_session):
                push_url = opt.push_url if k == 0 else opt.push_url + str(k)
                loop.run_until_complete(run(push_url, k))
        loop.run_forever()

    run_server(web.AppRunner(appasync))
2 years ago