敏感词替换和过滤

main
Zhangzhichao@123 1 month ago
parent e15b9b95fd
commit 71fa859cd4

@ -1,11 +1,20 @@
import os import os
import sqlite3 import sqlite3
import sys import sys
from enum import Enum
from multiprocessing import Queue from multiprocessing import Queue
from settings import sqlite_file from settings import sqlite_file
from queue import Empty from queue import Empty
class MessageType(Enum):
ENTER_LIVE_ROOM = 1
FOLLOW = 2
LIKE = 3
GIFT = 4
CHAT = 5
def resource_path(relative_path, base_dir='.'): def resource_path(relative_path, base_dir='.'):
# PyInstaller打包后会把文件放到临时文件夹 _MEIPASS # PyInstaller打包后会把文件放到临时文件夹 _MEIPASS
try: try:
@ -81,6 +90,10 @@ class LiveChatConfig:
def backend_token(self): def backend_token(self):
return self._query_config('backend_token') return self._query_config('backend_token')
@property
def refine_system_message_prompt(self):
return self._query_config('refine_system_message_prompt')
@property @property
def system_messages(self) -> list: def system_messages(self) -> list:
results = [] results = []
@ -92,6 +105,16 @@ class LiveChatConfig:
cursor.close() cursor.close()
return results return results
@property
def prohibited_words(self) -> dict:
results = {}
cursor = self.conn.cursor()
cursor.execute('select word, substitutes from prohibited_word')
rows = cursor.fetchall()
for word, substitutes in rows:
results[word] = substitutes
return results
class PromptQueue: class PromptQueue:
def __init__(self, maxsize=0): def __init__(self, maxsize=0):

@ -20,7 +20,7 @@ from py_mini_racer import MiniRacer
from protobuf.douyin import * from protobuf.douyin import *
from message_processor import * from message_processor import *
from helper import resource_path, PromptQueue from helper import resource_path, PromptQueue, MessageType
from settings import live_talking_host, Backend from settings import live_talking_host, Backend
@ -358,6 +358,13 @@ class DouyinLiveWebReply:
response = json.loads(response) response = json.loads(response)
return response['message']['content'] return response['message']['content']
def reply_message_postprocess(self, text):
prohibited_words = self.live_chat_config.prohibited_words
for prohibited_word, substitutes in prohibited_words.items():
if prohibited_word in text:
text = text.replace(prohibited_word, substitutes if substitutes is not None else '')
return text
async def post_to_human(self, text: str): async def post_to_human(self, text: str):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
await client.post( await client.post(
@ -385,15 +392,31 @@ class DouyinLiveWebReply:
prompt_data = self.queue.get(False) prompt_data = self.queue.get(False)
if prompt_data is not None: if prompt_data is not None:
# live_chat: 弹幕 # live_chat: 弹幕
prompt, live_chat = prompt_data message_type, prompt, live_chat = prompt_data
# logger.info(f'处理提示词: {prompt}') if message_type == MessageType.ENTER_LIVE_ROOM.value and \
random.random() >= self.live_chat_config.enter_live_room_prob / 100:
continue
elif message_type == MessageType.FOLLOW.value and \
random.random() >= self.live_chat_config.follow_prob / 100:
continue
elif message_type == MessageType.LIKE.value and \
random.random() >= self.live_chat_config.like_prob / 100:
continue
elif message_type == MessageType.GIFT.value and \
random.random() >= self.live_chat_config.gift_prob / 100:
continue
elif message_type == MessageType.CHAT.value and \
random.random() >= self.live_chat_config.chat_prob / 100:
continue
if live_chat is not None: if live_chat is not None:
logger.info(f'弹幕: {live_chat}')
llm_output = self._llm(self.live_chat_config.product_related_prompt.format(content=live_chat)) llm_output = self._llm(self.live_chat_config.product_related_prompt.format(content=live_chat))
logger.info(f'判断弹幕是否和商品有关: {llm_output}') logger.info(f'判断弹幕是否违反中国大陆法律和政策: {llm_output}')
if llm_output != '': if llm_output != '':
continue continue
reply_messages = self._llm(prompt, True) reply_messages = self._llm(prompt, True)
for reply_message in reply_messages: for reply_message in reply_messages:
reply_message = self.reply_message_postprocess(reply_message)
asyncio.run(self.post_to_human(reply_message)) asyncio.run(self.post_to_human(reply_message))
logger.info(f'输出回复: {reply_message}') logger.info(f'输出回复: {reply_message}')
# is_speaking此时是False需要等一段时间再查询 # is_speaking此时是False需要等一段时间再查询
@ -403,6 +426,10 @@ class DouyinLiveWebReply:
# time.sleep(1) # time.sleep(1)
system_messages = self.live_chat_config.system_messages system_messages = self.live_chat_config.system_messages
reply_message = system_messages[random.randint(0, len(system_messages) - 1)] reply_message = system_messages[random.randint(0, len(system_messages) - 1)]
llm_prompt = self.live_chat_config.refine_system_message_prompt.format(content=reply_message)
reply_messages = self._llm(llm_prompt, True)
for reply_message in reply_messages:
reply_message = self.reply_message_postprocess(reply_message)
asyncio.run(self.post_to_human(reply_message)) asyncio.run(self.post_to_human(reply_message))
logger.info(f'输出系统文案: {reply_message}') logger.info(f'输出系统文案: {reply_message}')
time.sleep(1) time.sleep(1)
@ -413,4 +440,8 @@ class DouyinLiveWebReply:
time.sleep(5) time.sleep(5)
system_messages = self.live_chat_config.system_messages system_messages = self.live_chat_config.system_messages
reply_message = system_messages[random.randint(0, len(system_messages) - 1)] reply_message = system_messages[random.randint(0, len(system_messages) - 1)]
llm_prompt = self.live_chat_config.refine_system_message_prompt.format(content=reply_message)
reply_messages = self._llm(llm_prompt, True)
for reply_message in reply_messages:
reply_message = self.reply_message_postprocess(reply_message)
asyncio.run(self.post_to_human(reply_message)) asyncio.run(self.post_to_human(reply_message))

@ -25,7 +25,7 @@ if __name__ == '__main__':
freeze_support() freeze_support()
args = parse_args() args = parse_args()
queue = PromptQueue(1000) queue = PromptQueue(10)
ws_open_event = Event() ws_open_event = Event()
fetch_process = Process(target=fetch_user_chat_content, args=(args.live_id, ws_open_event, queue)) fetch_process = Process(target=fetch_user_chat_content, args=(args.live_id, ws_open_event, queue))

@ -1,56 +1,56 @@
from protobuf.douyin import * from protobuf.douyin import *
import random import random
from loguru import logger from loguru import logger
from helper import LiveChatConfig from helper import LiveChatConfig, MessageType
live_chat_config = LiveChatConfig() live_chat_config = LiveChatConfig()
def parse_chat_msg(payload, queue): def parse_chat_msg(payload, queue):
"""聊天消息""" """聊天消息"""
if not random.random() < live_chat_config.chat_prob / 100: # if not random.random() < live_chat_config.chat_prob / 100:
return # return
message = ChatMessage().parse(payload) message = ChatMessage().parse(payload)
user_name = message.user.nick_name user_name = message.user.nick_name
user_id = message.user.id user_id = message.user.id
content = message.content content = message.content
prompt = live_chat_config.chat_prompt.format(content=content) prompt = live_chat_config.chat_prompt.format(content=content)
queue.put((prompt, content)) queue.put((MessageType.CHAT.value, prompt, content))
# logger.info(f"【聊天msg】[{user_id}]{user_name}: {content}") # logger.info(f"【聊天msg】[{user_id}]{user_name}: {content}")
# logger.info(f"队列数量: {queue.qsize()}") # logger.info(f"队列数量: {queue.qsize()}")
def parse_gif_msg(payload, queue): def parse_gif_msg(payload, queue):
"""礼物消息""" """礼物消息"""
if not random.random() < live_chat_config.gift_prob / 100: # if not random.random() < live_chat_config.gift_prob / 100:
return # return
message = GiftMessage().parse(payload) message = GiftMessage().parse(payload)
user_name = message.user.nick_name user_name = message.user.nick_name
gift_name = message.gift.name gift_name = message.gift.name
gift_count = message.combo_count gift_count = message.combo_count
prompt = live_chat_config.gift_prompt.format(user_name=user_name, gift_count=gift_count, gift_name=gift_name) prompt = live_chat_config.gift_prompt.format(user_name=user_name, gift_count=gift_count, gift_name=gift_name)
queue.put((prompt, None)) queue.put((MessageType.GIFT.value, prompt, None))
# logger.info(f"【礼物msg】{user_name} 送出了 {gift_name}x{gift_count}") # logger.info(f"【礼物msg】{user_name} 送出了 {gift_name}x{gift_count}")
# logger.info(f"队列数量: {queue.qsize()}") # logger.info(f"队列数量: {queue.qsize()}")
def parse_like_msg(payload, queue): def parse_like_msg(payload, queue):
'''点赞消息''' '''点赞消息'''
if not random.random() < live_chat_config.like_prob / 100: # if not random.random() < live_chat_config.like_prob / 100:
return # return
message = LikeMessage().parse(payload) message = LikeMessage().parse(payload)
user_name = message.user.nick_name user_name = message.user.nick_name
count = message.count count = message.count
prompt = live_chat_config.like_prompt.format(user_name=user_name, count=count) prompt = live_chat_config.like_prompt.format(user_name=user_name, count=count)
queue.put((prompt, None)) queue.put((MessageType.LIKE.value, prompt, None))
# logger.info(f"【点赞msg】{user_name} 点了{count}个赞") # logger.info(f"【点赞msg】{user_name} 点了{count}个赞")
# logger.info(f"队列数量: {queue.qsize()}") # logger.info(f"队列数量: {queue.qsize()}")
def parse_member_msg(payload, queue): def parse_member_msg(payload, queue):
'''进入直播间消息''' '''进入直播间消息'''
if not random.random() < live_chat_config.enter_live_room_prob / 100: # if not random.random() < live_chat_config.enter_live_room_prob / 100:
return # return
message = MemberMessage().parse(payload) message = MemberMessage().parse(payload)
user_name = message.user.nick_name user_name = message.user.nick_name
user_id = message.user.id user_id = message.user.id
@ -58,20 +58,20 @@ def parse_member_msg(payload, queue):
if gender in (0, 1): if gender in (0, 1):
gender = ["", ""][gender] gender = ["", ""][gender]
prompt = live_chat_config.enter_live_room_prompt.format(user_name=user_name) prompt = live_chat_config.enter_live_room_prompt.format(user_name=user_name)
queue.put((prompt, None)) queue.put((MessageType.ENTER_LIVE_ROOM.value, prompt, None))
# logger.info(f"【进场msg】[{user_id}][{gender}]{user_name} 进入了直播间") # logger.info(f"【进场msg】[{user_id}][{gender}]{user_name} 进入了直播间")
# logger.info(f"队列数量: {queue.qsize()}") # logger.info(f"队列数量: {queue.qsize()}")
def parse_social_msg(payload, queue): def parse_social_msg(payload, queue):
'''关注消息''' '''关注消息'''
if not random.random() < live_chat_config.follow_prob / 100: # if not random.random() < live_chat_config.follow_prob / 100:
return # return
message = SocialMessage().parse(payload) message = SocialMessage().parse(payload)
user_name = message.user.nick_name user_name = message.user.nick_name
user_id = message.user.id user_id = message.user.id
prompt = live_chat_config.follow_prompt.format(user_name=user_name) prompt = live_chat_config.follow_prompt.format(user_name=user_name)
queue.put((prompt, None)) queue.put((MessageType.FOLLOW.value, prompt, None))
# logger.info(f"【关注msg】[{user_id}]{user_name} 关注了主播") # logger.info(f"【关注msg】[{user_id}]{user_name} 关注了主播")
# logger.info(f"队列数量: {queue.qsize()}") # logger.info(f"队列数量: {queue.qsize()}")

Loading…
Cancel
Save