diff --git a/helper.py b/helper.py index 4112b77..761ed96 100644 --- a/helper.py +++ b/helper.py @@ -142,6 +142,26 @@ class LiveChatConfig: cursor.close() return results + @property + def product_info(self): + cursor = self.conn.cursor() + cursor.execute("select value from live_config where key = 'product_name'") + product_name = cursor.fetchone()[0] + cursor.execute("select value from live_config where key = 'product_specification'") + product_specification = cursor.fetchone()[0] + cursor.execute("select value from live_config where key = 'product_description'") + product_description = cursor.fetchone()[0] + cursor.close() + return product_name, product_specification, product_description + + @property + def livetalking_sessionid(self): + cursor = self.conn.cursor() + cursor.execute("select value from live_config where key = 'livetalking_sessionid'") + livetalking_sessionid = cursor.fetchone()[0] + cursor.close() + return int(livetalking_sessionid) + def update_chat_enable_status(self): cursor = self.conn.cursor() cursor.execute("update live_config set value = 1 where key = 'chat_enable_status'") diff --git a/liveMan.py b/liveMan.py index eb26d06..2d4b2ea 100644 --- a/liveMan.py +++ b/liveMan.py @@ -306,7 +306,6 @@ class DouyinLiveWebReply: def __init__(self, queue: PromptQueue): self.queue = queue self.system_text_list = [] - self.session_id = 0 self.backend_token = '' self.live_chat_config = LiveChatConfig() self.punctuation = ",.!;:,。!?:;" @@ -328,9 +327,10 @@ class DouyinLiveWebReply: } def _gen(): - response = requests.post(f'{self.live_chat_config.ollama_address}/live-digital-avatar-manage/ollama/generate', json=payload, - headers={'Authorization': f'Bearer {self.live_chat_config.backend_token}'}, - stream=True) + response = requests.post( + f'{self.live_chat_config.ollama_address}/live-digital-avatar-manage/ollama/generate', json=payload, + headers={'Authorization': f'Bearer {self.live_chat_config.backend_token}'}, + stream=True) buffer = '' for line in response.iter_lines(): if not line: @@ -351,9 +351,10 @@ class DouyinLiveWebReply: if stream: return _gen() else: - response = requests.post(f'{self.live_chat_config.ollama_address}/live-digital-avatar-manage/ollama/generate', json=payload, - headers={'Authorization': f'Bearer {self.live_chat_config.backend_token}'}, - timeout=10).content.decode()[5:] + response = requests.post( + f'{self.live_chat_config.ollama_address}/live-digital-avatar-manage/ollama/generate', json=payload, + headers={'Authorization': f'Bearer {self.live_chat_config.backend_token}'}, + timeout=10).content.decode()[5:] response = json.loads(response) return response['message']['content'] @@ -370,7 +371,7 @@ class DouyinLiveWebReply: f'{self.live_chat_config.livetalking_address}/human', json={ "type": "echo", - "sessionid": self.session_id, + "sessionid": self.live_chat_config.livetalking_sessionid, "text": text }, timeout=5 @@ -381,9 +382,12 @@ class DouyinLiveWebReply: 优先从用户交互队列中取提示词,如果没有用户交互的数据,则输出系统提示词 """ live_chat_config.update_chat_enable_status() + logger.info(f'livetalking address -> {self.live_chat_config.livetalking_address}') + logger.info(f'ollama_address -> {self.live_chat_config.ollama_address}') while True: try: - is_speaking = requests.post(f'{self.live_chat_config.livetalking_address}/is_speaking', json={'sessionid': self.session_id}, + is_speaking = requests.post(f'{self.live_chat_config.livetalking_address}/is_speaking', + json={'sessionid': self.live_chat_config.livetalking_sessionid}, timeout=5).json()['data'] if is_speaking: time.sleep(0.1) @@ -391,23 +395,44 @@ class DouyinLiveWebReply: prompt_data = self.queue.get(False) if prompt_data is not None: + product_name, product_specification, product_description = self.live_chat_config.product_info # live_chat: 弹幕 message_type, prompt, live_chat = prompt_data - if message_type == MessageType.ENTER_LIVE_ROOM.value and \ - random.random() >= self.live_chat_config.enter_live_room_prob / 100: - continue - elif message_type == MessageType.FOLLOW.value and \ - random.random() >= self.live_chat_config.follow_prob / 100: - continue - elif message_type == MessageType.LIKE.value and \ - random.random() >= self.live_chat_config.like_prob / 100: - continue - elif message_type == MessageType.GIFT.value and \ - random.random() >= self.live_chat_config.gift_prob / 100: - continue - elif message_type == MessageType.CHAT.value and \ - random.random() >= self.live_chat_config.chat_prob / 100: - continue + if message_type == MessageType.ENTER_LIVE_ROOM.value: + if random.random() >= self.live_chat_config.enter_live_room_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + elif message_type == MessageType.FOLLOW.value: + if random.random() >= self.live_chat_config.follow_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + elif message_type == MessageType.LIKE.value: + if random.random() >= self.live_chat_config.like_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + elif message_type == MessageType.GIFT.value: + if random.random() >= self.live_chat_config.gift_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) + elif message_type == MessageType.CHAT.value: + if random.random() >= self.live_chat_config.chat_prob / 100: + continue + else: + prompt = prompt.format(product_name=product_name, + product_specification=product_specification, + product_description=product_description) if live_chat is not None: logger.info(f'弹幕: {live_chat}') llm_output = self._llm(self.live_chat_config.product_related_prompt.format(content=live_chat)) diff --git a/message_processor.py b/message_processor.py index 26d3e3a..45208a0 100644 --- a/message_processor.py +++ b/message_processor.py @@ -14,7 +14,7 @@ def parse_chat_msg(payload, queue): user_name = message.user.nick_name user_id = message.user.id content = message.content - prompt = live_chat_config.chat_prompt.format(content=content) + prompt = live_chat_config.chat_prompt.format(content=content, product_name='{product_name}', product_specification='{product_specification}', product_description='{product_description}') queue.put((MessageType.CHAT.value, prompt, content)) # logger.info(f"【聊天msg】[{user_id}]{user_name}: {content}") # logger.info(f"队列数量: {queue.qsize()}") @@ -28,7 +28,7 @@ def parse_gif_msg(payload, queue): user_name = message.user.nick_name gift_name = message.gift.name gift_count = message.combo_count - prompt = live_chat_config.gift_prompt.format(user_name=user_name, gift_count=gift_count, gift_name=gift_name) + prompt = live_chat_config.gift_prompt.format(user_name=user_name, gift_count=gift_count, gift_name=gift_name, product_name='{product_name}', product_specification='{product_specification}', product_description='{product_description}') queue.put((MessageType.GIFT.value, prompt, None)) # logger.info(f"【礼物msg】{user_name} 送出了 {gift_name}x{gift_count}") # logger.info(f"队列数量: {queue.qsize()}") @@ -41,7 +41,7 @@ def parse_like_msg(payload, queue): message = LikeMessage().parse(payload) user_name = message.user.nick_name count = message.count - prompt = live_chat_config.like_prompt.format(user_name=user_name, count=count) + prompt = live_chat_config.like_prompt.format(user_name=user_name, count=count, product_name='{product_name}', product_specification='{product_specification}', product_description='{product_description}') queue.put((MessageType.LIKE.value, prompt, None)) # logger.info(f"【点赞msg】{user_name} 点了{count}个赞") # logger.info(f"队列数量: {queue.qsize()}") @@ -57,7 +57,7 @@ def parse_member_msg(payload, queue): gender = message.user.gender if gender in (0, 1): gender = ["女", "男"][gender] - prompt = live_chat_config.enter_live_room_prompt.format(user_name=user_name) + prompt = live_chat_config.enter_live_room_prompt.format(user_name=user_name, product_name='{product_name}', product_specification='{product_specification}', product_description='{product_description}') queue.put((MessageType.ENTER_LIVE_ROOM.value, prompt, None)) # logger.info(f"【进场msg】[{user_id}][{gender}]{user_name} 进入了直播间") # logger.info(f"队列数量: {queue.qsize()}") @@ -70,7 +70,7 @@ def parse_social_msg(payload, queue): message = SocialMessage().parse(payload) user_name = message.user.nick_name user_id = message.user.id - prompt = live_chat_config.follow_prompt.format(user_name=user_name) + prompt = live_chat_config.follow_prompt.format(user_name=user_name, product_name='{product_name}', product_specification='{product_specification}', product_description='{product_description}') queue.put((MessageType.FOLLOW.value, prompt, None)) # logger.info(f"【关注msg】[{user_id}]{user_name} 关注了主播") # logger.info(f"队列数量: {queue.qsize()}")