You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

417 lines
16 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import asyncio
import codecs
import gzip
import hashlib
import json
import random
import re
import string
import subprocess
import threading
import time
import traceback
import urllib.parse
from contextlib import contextmanager
from unittest.mock import patch

import httpx
import requests
import websocket
from py_mini_racer import MiniRacer

from protobuf.douyin import *
from message_processor import *
from helper import resource_path, PromptQueue
from settings import live_talking_host, Backend
@contextmanager
def patched_popen_encoding(encoding='utf-8'):
    """Force every ``subprocess.Popen`` created inside the context to use *encoding*.

    ``Popen.__init__`` is swapped (via ``unittest.mock.patch.object``) for a
    wrapper that overrides the ``encoding`` keyword argument, and the original
    initializer is restored when the ``with`` block exits.

    :param encoding: text encoding injected into every ``Popen`` call.
    """
    saved_init = subprocess.Popen.__init__

    def _init_with_encoding(self, *args, **kwargs):
        # Any caller-supplied encoding is overridden.
        saved_init(self, *args, **{**kwargs, 'encoding': encoding})

    with patch.object(subprocess.Popen, '__init__', _init_with_encoding):
        yield
def generateSignature(wss, script_file='sign.js'):
    """Compute the ``signature`` query parameter for a Douyin webcast websocket URL.

    A fixed list of query parameters is read from *wss* and serialized as
    ``name=value`` pairs joined by commas; a parameter missing from the URL
    gets an empty value. The MD5 hex digest of that string is passed to the
    ``get_sign`` function defined in *script_file*, which is evaluated in a
    MiniRacer context.

    :param wss: full websocket URL whose query string is signed.
    :param script_file: path of the UTF-8 JavaScript file defining ``get_sign``.
    :return: whatever ``get_sign`` returns, or ``None`` if that call raises
             (the error is logged).

    Note: an older ``sign_v0.js`` variant was run through execjs, which spawns
    a subprocess and could hit GBK decoding errors; ``patched_popen_encoding``
    was the workaround for that path and is not needed with MiniRacer.
    """
    sign_keys = ("live_id,aid,version_code,webcast_sdk_version,"
                 "room_id,sub_room_id,sub_channel_id,did_rule,"
                 "user_unique_id,device_platform,device_type,ac,"
                 "identity").split(',')

    query_values = {}
    for pair in urllib.parse.urlparse(wss).query.split('&'):
        # partition() keeps any '=' inside the value intact; later duplicates win.
        key, _, value = pair.partition('=')
        query_values[key] = value

    param = ','.join(f"{key}={query_values.get(key, '')}" for key in sign_keys)
    md5_param = hashlib.md5(param.encode()).hexdigest()

    with open(script_file, 'r', encoding='utf-8') as f:
        script = f.read()
    ctx = MiniRacer()
    ctx.eval(script)
    try:
        return ctx.call("get_sign", md5_param)
    except Exception as e:
        logger.error(e)
        return None
def generateMsToken(length=107):
    """
    Generate a random value for the ``msToken`` cookie field.

    The token is *length* characters drawn uniformly, with replacement, from
    ASCII letters, digits, ``=`` and ``_``.

    :param length: number of characters (107 by default).
    :return: the token string; empty when ``length <= 0``.
    """
    alphabet = string.ascii_letters + string.digits + '=_'
    # One C-level join instead of repeated string concatenation.
    return ''.join(random.choices(alphabet, k=length))
class DouyinLiveWebFetcher:
    """Connect to a Douyin live room's web websocket and dispatch its messages.

    Each incoming message whose method name appears in ``_message_handlers`` is
    passed to the matching handler as ``handler(payload, self.queue)``.
    """

    def __init__(self, live_id, ws_open_event, queue: PromptQueue):
        """
        Create a live-room message fetcher.

        :param live_id: the live id from the room's web URL; for
            https://live.douyin.com/261378947940 it is ``261378947940``.
        :param ws_open_event: object whose ``set()`` is called once the
            websocket is open (presumably a ``threading.Event`` - confirm with caller).
        :param queue: prompt queue forwarded to every message handler.
        """
        self.__ttwid = None
        self.__room_id = None
        self.live_id = live_id
        self.live_url = "https://live.douyin.com/"
        self.user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) " \
                          "Chrome/120.0.0.0 Safari/537.36"
        self.queue = queue
        self.ws_open_event = ws_open_event
        # Created by _connectWebSocket(); None until start() is called.
        self.ws = None

    def start(self):
        """Connect and run the websocket loop (``run_forever``) in the calling thread."""
        self._connectWebSocket()

    def stop(self):
        """Close the websocket; a no-op if it was never opened."""
        if self.ws is not None:
            self.ws.close()

    @property
    def ttwid(self):
        """
        The ``ttwid`` cookie value, fetched once from the live site's home page.

        :return: the cached ttwid, or None if the request failed (the error is logged).
        """
        if self.__ttwid:
            return self.__ttwid
        headers = {
            "User-Agent": self.user_agent,
        }
        try:
            response = requests.get(self.live_url, headers=headers, timeout=10)
            response.raise_for_status()
        except Exception as err:
            logger.error(f"【X】Request the live url error: {err}")
        else:
            self.__ttwid = response.cookies.get('ttwid')
        return self.__ttwid

    @property
    def room_id(self):
        """
        The real room id, scraped from the live room page and cached.

        The request occasionally fails; accessing the property again retries it.

        :return: the room id string, or None if the request failed or no
                 ``roomId`` was found in the page (both cases are logged).
        """
        if self.__room_id:
            return self.__room_id
        url = self.live_url + self.live_id
        # NOTE(review): '&' is not a cookie separator ('; ' is); kept as-is
        # because the server's acceptance of this exact header is unverified.
        headers = {
            "User-Agent": self.user_agent,
            "cookie": f"ttwid={self.ttwid}&msToken={generateMsToken()}; __ac_nonce=0123407cc00a9e438deb4",
        }
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
        except Exception as err:
            logger.error(f"【X】Request the live room url error: {err}")
            return self.__room_id
        # The page embeds JSON with escaped quotes: roomId\":\"<digits>\"
        match = re.search(r'roomId\\":\\"(\d+)\\"', response.text)
        if match is None:
            logger.error("【X】No match found for roomId")
            return self.__room_id
        self.__room_id = match.group(1)
        return self.__room_id

    def get_room_status(self):
        """
        Query the room-enter API and log whether the room is live.

        The API's ``room_status`` is 0 while live and 2 after the live has
        ended; any non-zero value is reported as ended.
        """
        url = ('https://live.douyin.com/webcast/room/web/enter/?aid=6383'
               '&app_name=douyin_web&live_id=1&device_platform=web&language=zh-CN&enter_from=web_live'
               '&cookie_enabled=true&screen_width=1536&screen_height=864&browser_language=zh-CN&browser_platform=Win32'
               '&browser_name=Edge&browser_version=133.0.0.0'
               f'&web_rid={self.live_id}'
               f'&room_id_str={self.room_id}'
               '&enter_source=&is_need_double_stream=false&insert_task_id=&live_reason='
               '&msToken=&a_bogus=')
        resp = requests.get(url, headers={
            'User-Agent': self.user_agent,
            'Cookie': f'ttwid={self.ttwid};'
        }, timeout=10)
        data = resp.json().get('data')
        if not data:
            return
        room_status = data.get('room_status')
        user = data.get('user') or {}
        user_id = user.get('id_str')
        nickname = user.get('nickname')
        logger.info(f"【{nickname}】[{user_id}]直播间:{['正在直播', '已结束'][bool(room_status)]}.")

    def _connectWebSocket(self):
        """
        Open the signed websocket connection and run its event loop until it closes.
        """
        wss = ("wss://webcast100-ws-web-lq.douyin.com/webcast/im/push/v2/?app_name=douyin_web"
               "&version_code=180800&webcast_sdk_version=1.0.14-beta.0"
               "&update_version_code=1.0.14-beta.0&compress=gzip&device_platform=web&cookie_enabled=true"
               "&screen_width=1536&screen_height=864&browser_language=zh-CN&browser_platform=Win32"
               "&browser_name=Mozilla"
               "&browser_version=5.0%20(Windows%20NT%2010.0;%20Win64;%20x64)%20AppleWebKit/537.36%20(KHTML,"
               "%20like%20Gecko)%20Chrome/126.0.0.0%20Safari/537.36"
               "&browser_online=true&tz_name=Asia/Shanghai"
               "&cursor=d-1_u-1_fh-7392091211001140287_t-1721106114633_r-1"
               f"&internal_ext=internal_src:dim|wss_push_room_id:{self.room_id}|wss_push_did:7319483754668557238"
               f"|first_req_ms:1721106114541|fetch_time:1721106114633|seq:1|wss_info:0-1721106114633-0-0|"
               f"wrds_v:7392094459690748497"
               f"&host=https://live.douyin.com&aid=6383&live_id=1&did_rule=3&endpoint=live_pc&support_wrds=1"
               f"&user_unique_id=7319483754668557238&im_path=/webcast/im/fetch/&identity=audience"
               f"&need_persist_msg_count=15&insert_task_id=&live_reason=&room_id={self.room_id}&heartbeatDuration=0")
        script_file = resource_path('sign.js')
        signature = generateSignature(wss, script_file)
        wss += f"&signature={signature}"
        headers = {
            "cookie": f"ttwid={self.ttwid}",
            'user-agent': self.user_agent,
        }
        self.ws = websocket.WebSocketApp(wss,
                                         header=headers,
                                         on_open=self._wsOnOpen,
                                         on_message=self._wsOnMessage,
                                         on_error=self._wsOnError,
                                         on_close=self._wsOnClose)
        try:
            self.ws.run_forever()
        except Exception:
            self.stop()
            raise

    def _sendHeartbeat(self):
        """
        Send a heartbeat frame every 5 seconds until a send fails.
        """
        while True:
            try:
                heartbeat = PushFrame(payload_type='hb').SerializeToString()
                self.ws.send(heartbeat, websocket.ABNF.OPCODE_PING)
            except Exception as e:
                logger.error(f"【X】心跳包检测错误: {e}")
                break
            time.sleep(5)

    def _wsOnOpen(self, ws):
        """
        Connection established: signal the open event and start the heartbeat thread.
        """
        logger.info("【√】WebSocket连接成功.")
        self.ws_open_event.set()
        threading.Thread(target=self._sendHeartbeat).start()

    def _message_handlers(self):
        """Map message method names to ``handler(payload, queue)`` callables."""
        return {
            'WebcastChatMessage': parse_chat_msg,  # chat messages
            'WebcastGiftMessage': parse_gif_msg,  # gifts
            'WebcastLikeMessage': parse_like_msg,  # likes
            'WebcastMemberMessage': parse_member_msg,  # viewer entered the room
            'WebcastSocialMessage': parse_social_msg,  # follows
            # 'WebcastRoomUserSeqMessage': parse_room_user_seq_msg,  # room statistics
            # 'WebcastFansclubMessage': parse_fansclub_msg,  # fan-club messages
            'WebcastControlMessage': self.parse_control_msg,  # room status
            # 'WebcastEmojiChatMessage': parse_emoji_chat_msg,  # emoji chat
            # 'WebcastRoomStatsMessage': parse_room_stats_msg,  # room stats
            # 'WebcastRoomMessage': parse_room_msg,  # room info
            # 'WebcastRoomRankMessage': parse_rank_msg,  # room ranking
            # 'WebcastRoomStreamAdaptationMessage': parse_room_stream_adaptation_msg,  # stream config
        }

    def _wsOnMessage(self, ws, message):
        """
        Decode one websocket frame, acknowledge it if required, and dispatch its messages.

        :param ws: the websocket app instance
        :param message: raw frame bytes
        """
        package = PushFrame().parse(message)
        response = Response().parse(gzip.decompress(package.payload))
        # Acknowledge so the server keeps pushing data on this connection.
        if response.need_ack:
            ack = PushFrame(log_id=package.log_id,
                            payload_type='ack',
                            payload=response.internal_ext.encode('utf-8')
                            ).SerializeToString()
            ws.send(ack, websocket.ABNF.OPCODE_BINARY)
        handlers = self._message_handlers()
        for msg in response.messages_list:
            handler = handlers.get(msg.method)
            if handler is None:
                continue
            try:
                handler(msg.payload, self.queue)
            except Exception:
                # One bad message must not stop processing of the rest.
                logger.error(traceback.format_exc())

    def parse_control_msg(self, payload, queue=None):
        '''
        Handle a room-status (control) message; status 3 means the live has ended, so stop.

        :param payload: serialized ControlMessage
        :param queue: unused; accepted so this method fits the
            ``handler(payload, queue)`` calling convention of the dispatcher.
        '''
        message = ControlMessage().parse(payload)
        if message.status == 3:
            logger.info("直播间已结束")
            self.stop()

    def _wsOnError(self, ws, error):
        """Log a websocket error."""
        logger.error(f"WebSocket error: {error}")

    def _wsOnClose(self, ws, *args):
        """Log the room status and the closure when the websocket closes."""
        try:
            self.get_room_status()
        except Exception as err:
            logger.error(f"【X】Query room status error: {err}")
        logger.info("WebSocket connection closed.")
class DouyinLiveWebReply:
    """Turn queued prompts into LLM replies and speak them via the live-talking service.

    When the prompt queue is empty, a random configured system message is
    spoken instead.
    """

    def __init__(self, queue: PromptQueue):
        self.queue = queue
        self.system_text_list = []
        self.session_id = 0
        self.backend_token = ''
        self.live_chat_config = LiveChatConfig()
        # ASCII and full-width (Chinese) punctuation at which streamed LLM
        # output is cut into speakable chunks.
        self.punctuation = ",.!?;:,。!?:;"

    def _post_llm(self, payload, **kwargs):
        """POST *payload* to the backend chat endpoint with the configured bearer token."""
        return requests.post(f'{Backend.backend_host}{Backend.ollama_uri}', json=payload,
                             headers={'Authorization': f'Bearer {self.live_chat_config.backend_token}'},
                             **kwargs)

    def _llm(self, prompt, stream=False):
        """
        Ask the backend LLM to answer *prompt*.

        :param prompt: user prompt; "/no_think" is appended to it.
        :param stream: if True, return a generator of reply chunks, each cut at a
            punctuation mark once it is at least 10 characters long (the tail is
            yielded last); otherwise return the whole reply string.
        """
        # NOTE(review): "stream" is always False server-side; the *stream*
        # argument only controls client-side chunking. Confirm this is intended.
        payload = {
            "model": "qwen3:30b-a3b",
            "messages": [
                {
                    "role": "user",
                    "content": f"{prompt}/no_think"
                }
            ],
            "options": {
                "temperature": 0.5
            },
            "stream": False,
            "filterThink": True
        }

        def _gen():
            # The with-block releases the streamed connection even when the
            # consumer stops iterating early.
            with self._post_llm(payload, stream=True) as response:
                buffer = ''
                for line in response.iter_lines():
                    if not line:
                        continue
                    text = line.decode()
                    logger.info(f'llm output -> {text}')
                    # The first 5 characters are skipped - presumably a "data:" prefix; confirm with backend.
                    content = json.loads(text[5:])['message']['content']
                    for char in content:
                        buffer += char
                        # Short fragments keep buffering until they are long enough.
                        if char in self.punctuation and len(buffer.strip()) >= 10:
                            yield buffer.strip()
                            buffer = ''
                if buffer.strip():
                    yield buffer.strip()

        if stream:
            return _gen()
        raw = self._post_llm(payload, timeout=10).content.decode()[5:]
        return json.loads(raw)['message']['content']

    async def post_to_human(self, text: str):
        """Send *text* to the live-talking service as an ``echo`` for this session."""
        async with httpx.AsyncClient() as client:
            await client.post(
                f'{live_talking_host}/human',
                json={
                    "type": "echo",
                    "sessionid": self.session_id,
                    "text": text
                },
                timeout=5
            )

    def _is_speaking(self):
        """Ask the live-talking service whether this session is currently speaking."""
        return requests.post(f'{live_talking_host}/is_speaking', json={'sessionid': self.session_id},
                             timeout=5).json()['data']

    def _random_system_message(self):
        """Pick one configured system message uniformly at random."""
        return random.choice(self.live_chat_config.system_messages)

    def _speak_fallback(self):
        """Best-effort: speak a system message after an error without letting a second failure escape."""
        try:
            asyncio.run(self.post_to_human(self._random_system_message()))
        except Exception:
            # The live-talking service may itself be the failing component;
            # log and let the main loop retry instead of dying.
            logger.error(traceback.format_exc())

    def __call__(self):
        """
        Run the reply loop forever.

        Prompts from the user-interaction queue take priority; when the queue
        is empty, a system message is spoken instead.
        """
        while True:
            try:
                if self._is_speaking():
                    time.sleep(0.1)
                    continue
                prompt_data = self.queue.get(False)
                if prompt_data is None:
                    # Queue empty: speak a system message.
                    reply_message = self._random_system_message()
                    asyncio.run(self.post_to_human(reply_message))
                    logger.info(f'输出系统文案: {reply_message}')
                    time.sleep(1)
                    continue
                # live_chat: the viewer's chat message, if any
                prompt, live_chat = prompt_data
                if live_chat is not None:
                    llm_output = self._llm(self.live_chat_config.product_related_prompt.format(content=live_chat))
                    logger.info(f'判断弹幕是否和商品有关: {llm_output}')
                    if llm_output != '':
                        continue
                for reply_message in self._llm(prompt, True):
                    asyncio.run(self.post_to_human(reply_message))
                    logger.info(f'输出回复: {reply_message}')
                # is_speaking may still report False right after posting; wait before polling again.
                time.sleep(0.5)
            except Exception:
                # On any error, back off, then speak a system message.
                logger.error(traceback.format_exc())
                time.sleep(5)
                self._speak_fallback()