diff --git a/helper.py b/helper.py index 4deed2a..d575155 100644 --- a/helper.py +++ b/helper.py @@ -1,11 +1,20 @@ import os import sqlite3 import sys +from enum import Enum from multiprocessing import Queue from settings import sqlite_file from queue import Empty +class MessageType(Enum): + ENTER_LIVE_ROOM = 1 + FOLLOW = 2 + LIKE = 3 + GIFT = 4 + CHAT = 5 + + def resource_path(relative_path, base_dir='.'): # PyInstaller打包后会把文件放到临时文件夹 _MEIPASS try: @@ -81,6 +90,10 @@ class LiveChatConfig: def backend_token(self): return self._query_config('backend_token') + @property + def refine_system_message_prompt(self): + return self._query_config('refine_system_message_prompt') + @property def system_messages(self) -> list: results = [] @@ -92,6 +105,16 @@ class LiveChatConfig: cursor.close() return results + @property + def prohibited_words(self) -> dict: + results = {} + cursor = self.conn.cursor() + cursor.execute('select word, substitutes from prohibited_word') + rows = cursor.fetchall() + for word, substitutes in rows: + results[word] = substitutes + return results + class PromptQueue: def __init__(self, maxsize=0): diff --git a/liveMan.py b/liveMan.py index 386050d..2a6c6a1 100644 --- a/liveMan.py +++ b/liveMan.py @@ -20,7 +20,7 @@ from py_mini_racer import MiniRacer from protobuf.douyin import * from message_processor import * -from helper import resource_path, PromptQueue +from helper import resource_path, PromptQueue, MessageType from settings import live_talking_host, Backend @@ -358,6 +358,13 @@ class DouyinLiveWebReply: response = json.loads(response) return response['message']['content'] + def reply_message_postprocess(self, text): + prohibited_words = self.live_chat_config.prohibited_words + for prohibited_word, substitutes in prohibited_words.items(): + if prohibited_word in text: + text = text.replace(prohibited_word, substitutes if substitutes is not None else '') + return text + async def post_to_human(self, text: str): async with httpx.AsyncClient() as client: await client.post( @@ -385,15 +392,31 @@ class DouyinLiveWebReply: prompt_data = self.queue.get(False) if prompt_data is not None: # live_chat: 弹幕 - prompt, live_chat = prompt_data - # logger.info(f'处理提示词: {prompt}') + message_type, prompt, live_chat = prompt_data + if message_type == MessageType.ENTER_LIVE_ROOM.value and \ + random.random() >= self.live_chat_config.enter_live_room_prob / 100: + continue + elif message_type == MessageType.FOLLOW.value and \ + random.random() >= self.live_chat_config.follow_prob / 100: + continue + elif message_type == MessageType.LIKE.value and \ + random.random() >= self.live_chat_config.like_prob / 100: + continue + elif message_type == MessageType.GIFT.value and \ + random.random() >= self.live_chat_config.gift_prob / 100: + continue + elif message_type == MessageType.CHAT.value and \ + random.random() >= self.live_chat_config.chat_prob / 100: + continue if live_chat is not None: + logger.info(f'弹幕: {live_chat}') llm_output = self._llm(self.live_chat_config.product_related_prompt.format(content=live_chat)) - logger.info(f'判断弹幕是否和商品有关: {llm_output}') - if llm_output != '是': + logger.info(f'判断弹幕是否违反中国大陆法律和政策: {llm_output}') + if llm_output != '否': continue reply_messages = self._llm(prompt, True) for reply_message in reply_messages: + reply_message = self.reply_message_postprocess(reply_message) asyncio.run(self.post_to_human(reply_message)) logger.info(f'输出回复: {reply_message}') # is_speaking此时是False,需要等一段时间再查询 @@ -403,8 +426,12 @@ class DouyinLiveWebReply: # time.sleep(1) system_messages = self.live_chat_config.system_messages reply_message = system_messages[random.randint(0, len(system_messages) - 1)] - asyncio.run(self.post_to_human(reply_message)) - logger.info(f'输出系统文案: {reply_message}') + llm_prompt = self.live_chat_config.refine_system_message_prompt.format(content=reply_message) + reply_messages = self._llm(llm_prompt, True) + for reply_message in reply_messages: + reply_message = self.reply_message_postprocess(reply_message) + asyncio.run(self.post_to_human(reply_message)) + logger.info(f'输出系统文案: {reply_message}') time.sleep(1) except Exception: @@ -413,4 +440,8 @@ class DouyinLiveWebReply: time.sleep(5) system_messages = self.live_chat_config.system_messages reply_message = system_messages[random.randint(0, len(system_messages) - 1)] - asyncio.run(self.post_to_human(reply_message)) + llm_prompt = self.live_chat_config.refine_system_message_prompt.format(content=reply_message) + reply_messages = self._llm(llm_prompt, True) + for reply_message in reply_messages: + reply_message = self.reply_message_postprocess(reply_message) + asyncio.run(self.post_to_human(reply_message)) diff --git a/main.py b/main.py index 80f0209..233febe 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ if __name__ == '__main__': freeze_support() args = parse_args() - queue = PromptQueue(1000) + queue = PromptQueue(10) ws_open_event = Event() fetch_process = Process(target=fetch_user_chat_content, args=(args.live_id, ws_open_event, queue)) diff --git a/message_processor.py b/message_processor.py index 1ed1559..26d3e3a 100644 --- a/message_processor.py +++ b/message_processor.py @@ -1,56 +1,56 @@ from protobuf.douyin import * import random from loguru import logger -from helper import LiveChatConfig +from helper import LiveChatConfig, MessageType live_chat_config = LiveChatConfig() def parse_chat_msg(payload, queue): """聊天消息""" - if not random.random() < live_chat_config.chat_prob / 100: - return + # if not random.random() < live_chat_config.chat_prob / 100: + # return message = ChatMessage().parse(payload) user_name = message.user.nick_name user_id = message.user.id content = message.content prompt = live_chat_config.chat_prompt.format(content=content) - queue.put((prompt, content)) + queue.put((MessageType.CHAT.value, prompt, content)) # logger.info(f"【聊天msg】[{user_id}]{user_name}: {content}") # logger.info(f"队列数量: {queue.qsize()}") def parse_gif_msg(payload, queue): """礼物消息""" - if not random.random() < live_chat_config.gift_prob / 100: - return + # if not random.random() < live_chat_config.gift_prob / 100: + # return message = GiftMessage().parse(payload) user_name = message.user.nick_name gift_name = message.gift.name gift_count = message.combo_count prompt = live_chat_config.gift_prompt.format(user_name=user_name, gift_count=gift_count, gift_name=gift_name) - queue.put((prompt, None)) + queue.put((MessageType.GIFT.value, prompt, None)) # logger.info(f"【礼物msg】{user_name} 送出了 {gift_name}x{gift_count}") # logger.info(f"队列数量: {queue.qsize()}") def parse_like_msg(payload, queue): '''点赞消息''' - if not random.random() < live_chat_config.like_prob / 100: - return + # if not random.random() < live_chat_config.like_prob / 100: + # return message = LikeMessage().parse(payload) user_name = message.user.nick_name count = message.count prompt = live_chat_config.like_prompt.format(user_name=user_name, count=count) - queue.put((prompt, None)) + queue.put((MessageType.LIKE.value, prompt, None)) # logger.info(f"【点赞msg】{user_name} 点了{count}个赞") # logger.info(f"队列数量: {queue.qsize()}") def parse_member_msg(payload, queue): '''进入直播间消息''' - if not random.random() < live_chat_config.enter_live_room_prob / 100: - return + # if not random.random() < live_chat_config.enter_live_room_prob / 100: + # return message = MemberMessage().parse(payload) user_name = message.user.nick_name user_id = message.user.id @@ -58,20 +58,20 @@ def parse_member_msg(payload, queue): if gender in (0, 1): gender = ["女", "男"][gender] prompt = live_chat_config.enter_live_room_prompt.format(user_name=user_name) - queue.put((prompt, None)) + queue.put((MessageType.ENTER_LIVE_ROOM.value, prompt, None)) # logger.info(f"【进场msg】[{user_id}][{gender}]{user_name} 进入了直播间") # logger.info(f"队列数量: {queue.qsize()}") def parse_social_msg(payload, queue): '''关注消息''' - if not random.random() < live_chat_config.follow_prob / 100: - return + # if not random.random() < live_chat_config.follow_prob / 100: + # return message = SocialMessage().parse(payload) user_name = message.user.nick_name user_id = message.user.id prompt = live_chat_config.follow_prompt.format(user_name=user_name) - queue.put((prompt, None)) + queue.put((MessageType.FOLLOW.value, prompt, None)) # logger.info(f"【关注msg】[{user_id}]{user_name} 关注了主播") # logger.info(f"队列数量: {queue.qsize()}")