敏感词替换和过滤

main
Zhangzhichao@123 1 month ago
parent e15b9b95fd
commit 71fa859cd4

@ -1,11 +1,20 @@
import os
import sqlite3
import sys
from enum import Enum
from multiprocessing import Queue
from settings import sqlite_file
from queue import Empty
class MessageType(Enum):
ENTER_LIVE_ROOM = 1
FOLLOW = 2
LIKE = 3
GIFT = 4
CHAT = 5
def resource_path(relative_path, base_dir='.'):
# PyInstaller打包后会把文件放到临时文件夹 _MEIPASS
try:
@ -81,6 +90,10 @@ class LiveChatConfig:
def backend_token(self):
return self._query_config('backend_token')
@property
def refine_system_message_prompt(self):
return self._query_config('refine_system_message_prompt')
@property
def system_messages(self) -> list:
results = []
@ -92,6 +105,16 @@ class LiveChatConfig:
cursor.close()
return results
@property
def prohibited_words(self) -> dict:
results = {}
cursor = self.conn.cursor()
cursor.execute('select word, substitutes from prohibited_word')
rows = cursor.fetchall()
for word, substitutes in rows:
results[word] = substitutes
return results
class PromptQueue:
def __init__(self, maxsize=0):

@ -20,7 +20,7 @@ from py_mini_racer import MiniRacer
from protobuf.douyin import *
from message_processor import *
from helper import resource_path, PromptQueue
from helper import resource_path, PromptQueue, MessageType
from settings import live_talking_host, Backend
@ -358,6 +358,13 @@ class DouyinLiveWebReply:
response = json.loads(response)
return response['message']['content']
def reply_message_postprocess(self, text):
prohibited_words = self.live_chat_config.prohibited_words
for prohibited_word, substitutes in prohibited_words.items():
if prohibited_word in text:
text = text.replace(prohibited_word, substitutes if substitutes is not None else '')
return text
async def post_to_human(self, text: str):
async with httpx.AsyncClient() as client:
await client.post(
@ -385,15 +392,31 @@ class DouyinLiveWebReply:
prompt_data = self.queue.get(False)
if prompt_data is not None:
# live_chat: 弹幕
prompt, live_chat = prompt_data
# logger.info(f'处理提示词: {prompt}')
message_type, prompt, live_chat = prompt_data
if message_type == MessageType.ENTER_LIVE_ROOM.value and \
random.random() >= self.live_chat_config.enter_live_room_prob / 100:
continue
elif message_type == MessageType.FOLLOW.value and \
random.random() >= self.live_chat_config.follow_prob / 100:
continue
elif message_type == MessageType.LIKE.value and \
random.random() >= self.live_chat_config.like_prob / 100:
continue
elif message_type == MessageType.GIFT.value and \
random.random() >= self.live_chat_config.gift_prob / 100:
continue
elif message_type == MessageType.CHAT.value and \
random.random() >= self.live_chat_config.chat_prob / 100:
continue
if live_chat is not None:
logger.info(f'弹幕: {live_chat}')
llm_output = self._llm(self.live_chat_config.product_related_prompt.format(content=live_chat))
logger.info(f'判断弹幕是否和商品有关: {llm_output}')
if llm_output != '':
logger.info(f'判断弹幕是否违反中国大陆法律和政策: {llm_output}')
if llm_output != '':
continue
reply_messages = self._llm(prompt, True)
for reply_message in reply_messages:
reply_message = self.reply_message_postprocess(reply_message)
asyncio.run(self.post_to_human(reply_message))
logger.info(f'输出回复: {reply_message}')
# is_speaking此时是False需要等一段时间再查询
@ -403,6 +426,10 @@ class DouyinLiveWebReply:
# time.sleep(1)
system_messages = self.live_chat_config.system_messages
reply_message = system_messages[random.randint(0, len(system_messages) - 1)]
llm_prompt = self.live_chat_config.refine_system_message_prompt.format(content=reply_message)
reply_messages = self._llm(llm_prompt, True)
for reply_message in reply_messages:
reply_message = self.reply_message_postprocess(reply_message)
asyncio.run(self.post_to_human(reply_message))
logger.info(f'输出系统文案: {reply_message}')
time.sleep(1)
@ -413,4 +440,8 @@ class DouyinLiveWebReply:
time.sleep(5)
system_messages = self.live_chat_config.system_messages
reply_message = system_messages[random.randint(0, len(system_messages) - 1)]
llm_prompt = self.live_chat_config.refine_system_message_prompt.format(content=reply_message)
reply_messages = self._llm(llm_prompt, True)
for reply_message in reply_messages:
reply_message = self.reply_message_postprocess(reply_message)
asyncio.run(self.post_to_human(reply_message))

@ -25,7 +25,7 @@ if __name__ == '__main__':
freeze_support()
args = parse_args()
queue = PromptQueue(1000)
queue = PromptQueue(10)
ws_open_event = Event()
fetch_process = Process(target=fetch_user_chat_content, args=(args.live_id, ws_open_event, queue))

@ -1,56 +1,56 @@
from protobuf.douyin import *
import random
from loguru import logger
from helper import LiveChatConfig
from helper import LiveChatConfig, MessageType
live_chat_config = LiveChatConfig()
def parse_chat_msg(payload, queue):
"""聊天消息"""
if not random.random() < live_chat_config.chat_prob / 100:
return
# if not random.random() < live_chat_config.chat_prob / 100:
# return
message = ChatMessage().parse(payload)
user_name = message.user.nick_name
user_id = message.user.id
content = message.content
prompt = live_chat_config.chat_prompt.format(content=content)
queue.put((prompt, content))
queue.put((MessageType.CHAT.value, prompt, content))
# logger.info(f"【聊天msg】[{user_id}]{user_name}: {content}")
# logger.info(f"队列数量: {queue.qsize()}")
def parse_gif_msg(payload, queue):
"""礼物消息"""
if not random.random() < live_chat_config.gift_prob / 100:
return
# if not random.random() < live_chat_config.gift_prob / 100:
# return
message = GiftMessage().parse(payload)
user_name = message.user.nick_name
gift_name = message.gift.name
gift_count = message.combo_count
prompt = live_chat_config.gift_prompt.format(user_name=user_name, gift_count=gift_count, gift_name=gift_name)
queue.put((prompt, None))
queue.put((MessageType.GIFT.value, prompt, None))
# logger.info(f"【礼物msg】{user_name} 送出了 {gift_name}x{gift_count}")
# logger.info(f"队列数量: {queue.qsize()}")
def parse_like_msg(payload, queue):
'''点赞消息'''
if not random.random() < live_chat_config.like_prob / 100:
return
# if not random.random() < live_chat_config.like_prob / 100:
# return
message = LikeMessage().parse(payload)
user_name = message.user.nick_name
count = message.count
prompt = live_chat_config.like_prompt.format(user_name=user_name, count=count)
queue.put((prompt, None))
queue.put((MessageType.LIKE.value, prompt, None))
# logger.info(f"【点赞msg】{user_name} 点了{count}个赞")
# logger.info(f"队列数量: {queue.qsize()}")
def parse_member_msg(payload, queue):
'''进入直播间消息'''
if not random.random() < live_chat_config.enter_live_room_prob / 100:
return
# if not random.random() < live_chat_config.enter_live_room_prob / 100:
# return
message = MemberMessage().parse(payload)
user_name = message.user.nick_name
user_id = message.user.id
@ -58,20 +58,20 @@ def parse_member_msg(payload, queue):
if gender in (0, 1):
gender = ["", ""][gender]
prompt = live_chat_config.enter_live_room_prompt.format(user_name=user_name)
queue.put((prompt, None))
queue.put((MessageType.ENTER_LIVE_ROOM.value, prompt, None))
# logger.info(f"【进场msg】[{user_id}][{gender}]{user_name} 进入了直播间")
# logger.info(f"队列数量: {queue.qsize()}")
def parse_social_msg(payload, queue):
'''关注消息'''
if not random.random() < live_chat_config.follow_prob / 100:
return
# if not random.random() < live_chat_config.follow_prob / 100:
# return
message = SocialMessage().parse(payload)
user_name = message.user.nick_name
user_id = message.user.id
prompt = live_chat_config.follow_prompt.format(user_name=user_name)
queue.put((prompt, None))
queue.put((MessageType.FOLLOW.value, prompt, None))
# logger.info(f"【关注msg】[{user_id}]{user_name} 关注了主播")
# logger.info(f"队列数量: {queue.qsize()}")

Loading…
Cancel
Save