You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

316 lines
11 KiB
Python

import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import argparse
import ssl
# Command-line interface of the server. Every option has a default, so the
# script can be started without arguments. The parsed values are kept in the
# module-level `args` and read by the model-loading and server-start code.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--host",
    type=str,
    default="0.0.0.0",
    help="host ip, localhost, 0.0.0.0",
)
parser.add_argument(
    "--port",
    type=int,
    default=10095,
    # The script serves websockets (see websockets.serve at the bottom of the
    # file), not gRPC, so the help text says so.
    help="websocket server port",
)
parser.add_argument(
    "--asr_model",
    type=str,
    default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    help="model from modelscope",
)
parser.add_argument(
    "--asr_model_revision",
    type=str,
    default="v2.0.4",
    help="model revision used when loading --asr_model",
)
parser.add_argument(
    "--asr_model_online",
    type=str,
    default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    help="model from modelscope",
)
parser.add_argument(
    "--asr_model_online_revision",
    type=str,
    default="v2.0.4",
    help="model revision used when loading --asr_model_online",
)
parser.add_argument(
    "--vad_model",
    type=str,
    default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    help="model from modelscope",
)
parser.add_argument(
    "--vad_model_revision",
    type=str,
    default="v2.0.4",
    help="model revision used when loading --vad_model",
)
parser.add_argument(
    "--punc_model",
    type=str,
    default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
    help="model from modelscope; pass an empty string to run without punctuation",
)
parser.add_argument(
    "--punc_model_revision",
    type=str,
    default="v2.0.4",
    help="model revision used when loading --punc_model",
)
parser.add_argument(
    "--ngpu",
    type=int,
    default=1,
    help="0 for cpu, 1 for gpu",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="cuda, cpu",
)
parser.add_argument(
    "--ncpu",
    type=int,
    default=4,
    help="cpu cores",
)
parser.add_argument(
    "--certfile",
    type=str,
    default="ssl_key/server.crt",
    help="certfile for ssl; pass an empty string to serve without ssl",
)
parser.add_argument(
    "--keyfile",
    type=str,
    default="ssl_key/server.key",
    help="keyfile for ssl",
)
args = parser.parse_args()
# Registry of the connections currently being served.
websocket_users = set()

print("model loading")
from funasr import AutoModel


def _load_model(model_name, model_revision):
    """Build an AutoModel with the device settings shared by all models."""
    return AutoModel(
        model=model_name,
        model_revision=model_revision,
        ngpu=args.ngpu,
        ncpu=args.ncpu,
        device=args.device,
        disable_pbar=True,
        disable_log=True,
    )


# Offline ASR model.
model_asr = _load_model(args.asr_model, args.asr_model_revision)

# Streaming ASR model.
model_asr_streaming = _load_model(args.asr_model_online, args.asr_model_online_revision)

# VAD model; no chunk_size is configured at load time.
model_vad = _load_model(args.vad_model, args.vad_model_revision)

# The punctuation model is optional: an empty --punc_model leaves it unset.
model_punc = (
    _load_model(args.punc_model, args.punc_model_revision)
    if args.punc_model != ""
    else None
)

print("model loaded! only support one client at the same time now!!!!")
async def ws_reset(websocket):
    """Drop a connection's cached model state and close the connection.

    The streaming-ASR and VAD state dicts get an empty ``cache`` and have
    ``is_final`` set to True; the punctuation state only has its ``cache``
    emptied. The socket is closed last.
    """
    print("ws reset now, total num is ", len(websocket_users))

    # Each dict gets its own fresh cache object.
    websocket.status_dict_asr_online.update(cache={}, is_final=True)
    websocket.status_dict_vad.update(cache={}, is_final=True)
    websocket.status_dict_punc.update(cache={})

    await websocket.close()
async def clear_websocket():
    """Reset and close every tracked connection, then empty the registry."""
    # Iterate over a snapshot: ws_reset awaits websocket.close(), and while
    # this coroutine is suspended connection handlers may add to or remove
    # from websocket_users. Iterating the live set would then fail with
    # "RuntimeError: Set changed size during iteration".
    for websocket in list(websocket_users):
        await ws_reset(websocket)
    websocket_users.clear()
def _apply_client_config(websocket, messagejson):
    """Copy the recognised fields of one JSON control message onto the connection."""
    if "is_speaking" in messagejson:
        websocket.is_speaking = messagejson["is_speaking"]
        websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
    if "chunk_interval" in messagejson:
        websocket.chunk_interval = messagejson["chunk_interval"]
    if "wav_name" in messagejson:
        websocket.wav_name = messagejson.get("wav_name")
    if "chunk_size" in messagejson:
        chunk_size = messagejson["chunk_size"]
        # Accept both a list and a comma-separated string.
        if isinstance(chunk_size, str):
            chunk_size = chunk_size.split(",")
        websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
    for key in ("encoder_chunk_look_back", "decoder_chunk_look_back"):
        if key in messagejson:
            websocket.status_dict_asr_online[key] = messagejson[key]
    if "hotword" in messagejson:
        websocket.status_dict_asr["hotword"] = messagejson["hotword"]
    if "mode" in messagejson:
        websocket.mode = messagejson["mode"]


async def ws_serve(websocket, path=None):
    """Serve one client connection: configure it, then run VAD and ASR on its audio.

    Text frames are JSON control messages handled by ``_apply_client_config``.
    Every other frame is treated as an audio chunk: it is buffered, passed to
    the streaming recogniser every ``chunk_interval`` chunks (modes "2pass"
    and "online"), and passed to VAD. When VAD reports an end point, or the
    client has set ``is_speaking`` to false, the buffered speech is sent to
    the offline recogniser (modes "2pass" and "offline") and the buffers are
    reset.

    Args:
        websocket: the connection object; per-connection state is stored on
            it as attributes.
        path: unused. Optional so the handler works whether or not the
            websockets library passes a second argument.
    """
    frames = []  # recent audio chunks, used to recover audio before the VAD start point
    frames_asr = []  # chunks of the current speech segment, for the offline pass
    frames_asr_online = []  # chunks not yet sent to the streaming recogniser
    websocket_users.add(websocket)
    websocket.status_dict_asr = {}
    websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
    websocket.status_dict_vad = {"cache": {}, "is_final": False}
    websocket.status_dict_punc = {"cache": {}}
    websocket.chunk_interval = 10
    websocket.vad_pre_idx = 0
    # Default until the client says otherwise, so audio that arrives before
    # an "is_speaking" message does not fail with AttributeError.
    websocket.is_speaking = True
    speech_start = False
    speech_end_i = -1
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    print("new user connected", flush=True)

    try:
        async for message in websocket:
            is_audio = not isinstance(message, str)
            if not is_audio:
                _apply_client_config(websocket, json.loads(message))

            # NOTE(review): this requires the client to have sent "chunk_size"
            # already; otherwise the KeyError ends the connection through the
            # generic handler below.
            websocket.status_dict_vad["chunk_size"] = int(
                websocket.status_dict_asr_online["chunk_size"][1] * 60 / websocket.chunk_interval
            )

            # A control message only needs further work if audio is pending.
            if not (frames_asr_online or frames_asr or is_audio):
                continue

            if is_audio:
                frames.append(message)
                # Assumes 32 bytes of audio per millisecond (16 kHz, 16-bit
                # mono PCM) -- TODO confirm against the client.
                duration_ms = len(message) // 32
                websocket.vad_pre_idx += duration_ms

                # Streaming ASR. The end point VAD found on the previous
                # chunk marks this chunk as the final one.
                frames_asr_online.append(message)
                websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
                if (
                    len(frames_asr_online) % websocket.chunk_interval == 0
                    or websocket.status_dict_asr_online["is_final"]
                ):
                    if websocket.mode in ("2pass", "online"):
                        audio_in = b"".join(frames_asr_online)
                        try:
                            await async_asr_online(websocket, audio_in)
                        except Exception as e:
                            print(f"error in asr streaming, {websocket.status_dict_asr_online}: {e}")
                    frames_asr_online = []
                if speech_start:
                    frames_asr.append(message)

                # Streaming VAD.
                try:
                    speech_start_i, speech_end_i = await async_vad(websocket, message)
                except Exception as e:
                    print("error in vad:", e)
                    # Treat a failed VAD call as "no boundary detected" rather
                    # than leaving the indices unbound or stale.
                    speech_start_i, speech_end_i = -1, -1
                if speech_start_i != -1:
                    speech_start = True
                    # Number of already-buffered chunks since the start point.
                    beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
                    # NOTE(review): beg_bias == 0 selects the whole of `frames`.
                    frames_asr = frames[-beg_bias:]

            # Offline ASR + punctuation at a VAD end point or end of speaking.
            if speech_end_i != -1 or not websocket.is_speaking:
                if websocket.mode in ("2pass", "offline"):
                    audio_in = b"".join(frames_asr)
                    try:
                        await async_asr(websocket, audio_in)
                    except Exception as e:
                        print("error in asr offline:", e)
                frames_asr = []
                speech_start = False
                frames_asr_online = []
                websocket.status_dict_asr_online["cache"] = {}
                if not websocket.is_speaking:
                    websocket.vad_pre_idx = 0
                    frames = []
                    websocket.status_dict_vad["cache"] = {}
                else:
                    # Keep only the most recent chunks as look-back context.
                    frames = frames[-20:]

    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users, flush=True)
        await ws_reset(websocket)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
    finally:
        # Always unregister, however the handler ends, so the registry does
        # not keep finished connections.
        websocket_users.discard(websocket)
async def async_vad(websocket, audio_in):
    """Run the VAD model on one audio chunk and report segment boundaries.

    Returns:
        A ``(speech_start, speech_end)`` pair. Each value is -1 unless the
        model returned exactly one segment and the corresponding entry of
        that segment is not -1.
    """
    vad_output = model_vad.generate(input=audio_in, **websocket.status_dict_vad)
    segments = vad_output[0]["value"]

    # Only a single detected segment is interpreted; none or several are
    # reported as "no boundary".
    if len(segments) != 1:
        return -1, -1

    begin = -1
    end = -1
    if segments[0][0] != -1:
        begin = segments[0][0]
    if segments[0][1] != -1:
        end = segments[0][1]
    return begin, end
async def async_asr(websocket, audio_in):
    """Recognise a buffered speech segment offline and send the text to the client.

    Nothing is done for empty audio. When a punctuation model is loaded and
    the recognised text is non-empty, the text is passed through it first.
    A message is only sent if the final text is non-empty.
    """
    if not len(audio_in):
        return

    result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]

    needs_punctuation = model_punc is not None and len(result["text"]) > 0
    if needs_punctuation:
        result = model_punc.generate(input=result["text"], **websocket.status_dict_punc)[0]

    if not len(result["text"]) > 0:
        return

    # In 2pass mode the offline result is labelled separately from the
    # streaming one.
    if "2pass" in websocket.mode:
        reported_mode = "2pass-offline"
    else:
        reported_mode = websocket.mode

    payload = {
        "mode": reported_mode,
        "text": result["text"],
        "wav_name": websocket.wav_name,
        "is_final": websocket.is_speaking,
    }
    await websocket.send(json.dumps(payload))
async def async_asr_online(websocket, audio_in):
    """Run the streaming recogniser on pending audio and send any partial text.

    Nothing is done for empty audio. In "2pass" mode the model is still run
    on a chunk marked ``is_final``, but its text is not sent (presumably
    because the offline pass reports that segment -- confirm).
    """
    if not len(audio_in):
        return

    result = model_asr_streaming.generate(input=audio_in, **websocket.status_dict_asr_online)[0]

    final_chunk = websocket.status_dict_asr_online.get("is_final", False)
    if websocket.mode == "2pass" and final_chunk:
        return

    if not len(result["text"]):
        return

    if "2pass" in websocket.mode:
        reported_mode = "2pass-online"
    else:
        reported_mode = websocket.mode

    payload = {
        "mode": reported_mode,
        "text": result["text"],
        "wav_name": websocket.wav_name,
        "is_final": websocket.is_speaking,
    }
    await websocket.send(json.dumps(payload))
async def _run_server():
    """Serve ws_serve on args.host:args.port until the process is stopped.

    TLS is enabled when --certfile is non-empty; the certificate chain is
    loaded from --certfile and --keyfile.
    """
    ssl_context = None
    if len(args.certfile) > 0:
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
        ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)

    # ssl=None is the plain-text default, so one call covers both cases.
    async with websockets.serve(
        ws_serve,
        args.host,
        args.port,
        subprotocols=["binary"],
        ping_interval=None,
        ssl=ssl_context,
    ):
        # Wait on a future that is never resolved to keep the server running.
        await asyncio.Future()


# asyncio.run creates and owns the event loop. This replaces
# asyncio.get_event_loop(), which is deprecated when no loop is running, and
# builds the server inside a running loop.
asyncio.run(_run_server())