diff --git a/helper.py b/helper.py index 645e3ba..64f722b 100644 --- a/helper.py +++ b/helper.py @@ -1,3 +1,4 @@ +import json import os import sqlite3 import sys @@ -121,16 +122,33 @@ class LiveChatConfig: return ollama_address @property - def system_messages(self) -> list: - results = [] + def system_messages(self) -> str: + results = {} cursor = self.conn.cursor() - cursor.execute('select message from system_message order by id') + cursor.execute('select message, type from system_message order by id') rows = cursor.fetchall() - for message in rows: - results.append(message[0]) + for row in rows: + message, _type = row + results[_type] = message cursor.close() + results = json.dumps(results, ensure_ascii=False, indent=4) return results + def messages(self, batch_number): + cursor = self.conn.cursor() + cursor.execute('select count(0) from message where batch_number = ?', (batch_number,)) + count = cursor.fetchone()[0] + cursor.close() + return count + + @property + def next_reply_message(self): + cursor = self.conn.cursor() + cursor.execute('select message, id from message where status = 0 and batch_number = 0 order by id limit 1') + count = cursor.fetchone() + cursor.close() + return count + @property def prohibited_words(self) -> dict: results = {} @@ -169,6 +187,31 @@ class LiveChatConfig: self.conn.commit() cursor.close() + def update_next_reply_status(self, status, _id): + cursor = self.conn.cursor() + cursor.execute("update message set status = ? where id = ?", (status, _id)) + self.conn.commit() + cursor.close() + + # insert message表 + def insert_message(self, message, _type, batch_number): + cursor = self.conn.cursor() + cursor.execute("insert into message (message, type, batch_number) values (?, ?, ?)", + (message, _type, batch_number)) + self.conn.commit() + cursor.close() + + def flush_message(self): + cursor = self.conn.cursor() + # 1.删除所有batch_number=0且status=1的数据 + cursor.execute('delete from message where batch_number = 0 and status = 1') + # 2.将batch_number=1的数据的更新为0 + cursor.execute('update message set batch_number = 0 where batch_number = 1') + # 3.生成新的备用系统文案,batch_number=1 + + self.conn.commit() + cursor.close() + class PromptQueue: def __init__(self, maxsize=0): diff --git a/liveMan.py b/liveMan.py index 45be89b..0fd7723 100644 --- a/liveMan.py +++ b/liveMan.py @@ -317,6 +317,7 @@ class DouyinLiveWebReply: self.live_chat_config = LiveChatConfig() self.punctuation = ",.!;:,。!?:;" self.system_message_index = 0 + self.response_queue = PromptQueue(10) def _llm(self, prompt, stream=False): payload = { @@ -385,6 +386,41 @@ class DouyinLiveWebReply: timeout=5 ) + def post_to_human_sync(self, text: str): + """ + 同步调用post_to_human + :param text: 要发送的文本内容 + """ + response = requests.post( + f'{self.live_chat_config.livetalking_address}/human', + json={ + "type": "echo", + "sessionid": self.live_chat_config.livetalking_sessionid, + "text": text + }, + timeout=5 + ) + if response.status_code != 200: + logger.error(f'Failed to post to human: {response.text}') + + def generate_messages(self, batch_number=0): + message_count = self.live_chat_config.messages(batch_number) + if message_count == 0: + logger.info(f'生成系统文案,batch_number: {batch_number}') + # 结合原始样例话术,拼接提示词,调用Ollama,生成洗稿后的话术 + system_messages = self.live_chat_config.system_messages + llm_prompt = self.live_chat_config.refine_system_message_prompt.format( + content=system_messages) + reply_messages = self._llm(llm_prompt, False) + # 处理reply_messages,先转换为json对象,将key和value分别对应type和message存入sqlite message表中,并统一给batch_number赋值为0 + # 正则匹配处理reply_messages,只保留大括号及其范围内的字符串 + reply_messages = re.findall(r'\{.*?\}', reply_messages, re.DOTALL)[0] + reply_messages = json.loads(reply_messages) + # 遍历reply_messages对象,insert message + for _type, message in reply_messages.items(): + self.live_chat_config.insert_message(message, _type, batch_number) + logger.info(f'入库备用系统文案:{_type} | {message}') + def __call__(self): """ 优先从用户交互队列中取提示词,如果没有用户交互的数据,则输出系统提示词 @@ -392,6 +428,8 @@ class DouyinLiveWebReply: live_chat_config.update_chat_enable_status('已启动') logger.info(f'livetalking address -> {self.live_chat_config.livetalking_address}') logger.info(f'ollama_address -> {self.live_chat_config.ollama_address}') + # 加一个计数器,统计is_speaking连续为False的次数,如果超过10次,才算真正的未在说话 + is_not_speaking_count = 0 while True: try: is_speaking = requests.post(f'{self.live_chat_config.livetalking_address}/is_speaking', @@ -399,76 +437,92 @@ class DouyinLiveWebReply: timeout=5).json()['data'] if is_speaking: time.sleep(0.1) - continue + prompt_data = self.queue.get(False) + if prompt_data is not None: + product_name, product_specification, product_description = self.live_chat_config.product_info + # live_chat: 弹幕 + message_type, prompt, live_chat = prompt_data + if message_type == MessageType.ENTER_LIVE_ROOM.value: + if random.random() >= self.live_chat_config.enter_live_room_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + elif message_type == MessageType.FOLLOW.value: + if random.random() >= self.live_chat_config.follow_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + elif message_type == MessageType.LIKE.value: + if random.random() >= self.live_chat_config.like_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + elif message_type == MessageType.GIFT.value: + if random.random() >= self.live_chat_config.gift_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + elif message_type == MessageType.CHAT.value: + if random.random() >= self.live_chat_config.chat_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + if live_chat is not None: + logger.info(f'弹幕: {live_chat}') + llm_output = self._llm( + self.live_chat_config.product_related_prompt.format(content=live_chat)) + logger.info(f'判断弹幕是否违反中国大陆法律和政策: {llm_output}') + if llm_output != '否': + continue + reply_message = self._llm(prompt, False) + self.response_queue.put(reply_message) + # is_speaking此时是False,需要等一段时间再查询 + time.sleep(0.5) + else: + # 用户交互队列为空,输出系统文案和备用系统文案 + if not self.live_chat_config.next_reply_message: + logger.info('备用系统文案已用完,重新生成备用系统文案') + self.live_chat_config.flush_message() + self.generate_messages(1) - prompt_data = self.queue.get(False) - if prompt_data is not None: - product_name, product_specification, product_description = self.live_chat_config.product_info - # live_chat: 弹幕 - message_type, prompt, live_chat = prompt_data - if message_type == MessageType.ENTER_LIVE_ROOM.value: - if random.random() >= self.live_chat_config.enter_live_room_prob / 100: - continue - else: - prompt = prompt.format(product_name=product_name, - product_specification=product_specification, - product_description=product_description) - elif message_type == MessageType.FOLLOW.value: - if random.random() >= self.live_chat_config.follow_prob / 100: - continue - else: - prompt = prompt.format(product_name=product_name, - product_specification=product_specification, - product_description=product_description) - elif message_type == MessageType.LIKE.value: - if random.random() >= self.live_chat_config.like_prob / 100: - continue - else: - prompt = prompt.format(product_name=product_name, - product_specification=product_specification, - product_description=product_description) - elif message_type == MessageType.GIFT.value: - if random.random() >= self.live_chat_config.gift_prob / 100: - continue - else: - prompt = prompt.format(product_name=product_name, - product_specification=product_specification, - product_description=product_description) - elif message_type == MessageType.CHAT.value: - if random.random() >= self.live_chat_config.chat_prob / 100: - continue - else: - prompt = prompt.format(product_name=product_name, - product_specification=product_specification, - product_description=product_description) - if live_chat is not None: - logger.info(f'弹幕: {live_chat}') - llm_output = self._llm(self.live_chat_config.product_related_prompt.format(content=live_chat)) - logger.info(f'判断弹幕是否违反中国大陆法律和政策: {llm_output}') - if llm_output != '否': - continue - reply_messages = self._llm(prompt, True) - for reply_message in reply_messages: - reply_message = self.reply_message_postprocess(reply_message) - asyncio.run(self.post_to_human(reply_message)) - logger.info(f'输出回复: {reply_message}') - # is_speaking此时是False,需要等一段时间再查询 - time.sleep(0.5) + continue else: - # 用户交互队列为空,输出系统文案 - # time.sleep(1) - system_messages = self.live_chat_config.system_messages - reply_message = system_messages[self.system_message_index] - self.system_message_index += 1 - if self.system_message_index >= len(system_messages): - self.system_message_index = 0 - llm_prompt = self.live_chat_config.refine_system_message_prompt.format(content=reply_message) - reply_messages = self._llm(llm_prompt, True) - for reply_message in reply_messages: - reply_message = self.reply_message_postprocess(reply_message) - asyncio.run(self.post_to_human(reply_message)) - logger.info(f'输出系统文案: {reply_message}') - time.sleep(1) + time.sleep(0.1) + is_not_speaking_count += 1 + if is_not_speaking_count == 20: + logger.info('连续20次请求Livetalking未在说话,开始回复') + # 调用Livetalking说话 + # 判断response_queue是否为空,如果不为空,则取出回复内容并调用livetalking,否则从数据库中取出备用系统文案 + reply_message = '' + if not self.response_queue.empty(): + reply_message = self.response_queue.get() + reply_message = self.reply_message_postprocess(reply_message) + else: + reply_message_data = self.live_chat_config.next_reply_message + if not reply_message_data: + logger.info('备用系统文案已用完,重新生成备用系统文案') + self.generate_messages(0) + self.generate_messages(1) + continue + reply_message, _id = reply_message_data + # 说完之后把状态改为1 + logger.info(f'更新备用系统文案id:{_id}状态为: 1') + self.live_chat_config.update_next_reply_status(1, _id) + + # asyncio.run(self.post_to_human(reply_message)) + logger.info(f'开始播放: {reply_message}') + self.post_to_human_sync(reply_message) + is_not_speaking_count = 0 except Exception: # 发生异常,输出系统文案 diff --git a/test.py b/test.py new file mode 100644 index 0000000..5e9d4aa --- /dev/null +++ b/test.py @@ -0,0 +1,43 @@ +import re +import sqlite3 +from settings import sqlite_file +import os +import json + +def messages(): + conn = sqlite3.connect('D:/code/live_chat.db') + cursor = conn.cursor() + cursor.execute( + 'select message, type, id from message where status = 0 and batch_number = 0 order by id limit 1') + count = cursor.fetchone() + cursor.close() + return count + +print(messages()) + +# +# reply_messages = """ +# ```json +# { +# "开场欢迎": "家人们晚上好呀!今天小妹特意准备了咱们男同胞最爱的纯粮老白干,这波福利真的超值 +# !您们先别急着走,听我唠唠,今天这个价格只有今晚能抢到,明天就恢复原价啦!", +# "酒的背景与历史": "咱们中国喝酒的讲究可太有讲究啦!老白干这可是有上千年历史的宝藏酒,传承了 +# 数百年的古法工艺,每一代酿酒师傅都用心打磨。想想看,几百年前的古人也是这样品酒的,现在咱们还能尝 +# 到同样的味道,是不是特别有缘分?这酒产自北方,有'千年传承酒中瑰宝'的称号,固态发酵工艺让酒香更醇 +# 厚,入口绵长,回味悠长。", +# "口感细致描绘": "重点说说这酒的口感!喝酒最怕什么?怕辣嗓子、怕上头、怕第二天头疼对吧?老白 +# 干的秘诀就是'入口柔顺,落喉清爽'。第一口就像老朋友轻拍肩膀那么舒服,咽下去后不仅不刺激喉咙,反而 +# 在嘴里留下淡淡粮食香。酒劲儿恰到好处,温润不烈,喝过的都说这才是'纯粮好酒'!平时小酌解压又提神, +# 第二天还能精神满满。", +# "饮用场景渲染": "这样的好酒放在家里,随时来一杯都超享受!晚上聚餐时倒上小半杯,配点下酒菜, +# 暖洋洋的特别惬意。朋友聚会来两杯,不辣喉不烧头,喝着才自在。一个人独饮也超舒服,配点花生米,一杯 +# 下肚整个人都放松了。过年送礼更显体面,自己喝着放心,送人也倍儿有面子。", +# "工艺与匠心": "为什么这酒喝着这么舒服?因为坚持用纯粮固态发酵,每一步都严格把关,用的是玉米 +# 、小麦、高粱这些天然原料,发酵周期长达数月。不像有些酒掺点酒精加点香精就敢卖!喝酒要喝得安心,老 +# 白干就靠这份用心,才让几代酒友都爱不释手。" +# } +# ``` +# """ +# reply_messages = re.findall(r'\{.*?\}', reply_messages, re.DOTALL) +# print(reply_messages[0]) +#