You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

312 lines
9.8 KiB
Python

import json
import os
import sqlite3
import sys
from enum import Enum
from multiprocessing import Queue
from settings import sqlite_file
from queue import Empty
import random
import json
class MessageType(Enum):
    """Kinds of live-room events, each identified by a small integer code."""

    ENTER_LIVE_ROOM = 1  # a viewer entered the live room
    FOLLOW = 2           # a viewer followed
    LIKE = 3             # a viewer liked
    GIFT = 4             # a viewer sent a gift
    CHAT = 5             # a viewer posted a chat message
def resource_path(relative_path, base_dir='.'):
    """Return the path of a bundled resource.

    Args:
        relative_path: Path of the resource relative to the base directory.
        base_dir: Directory used as the base when the interpreter does not
            expose ``sys._MEIPASS``; it is converted to an absolute path.

    Returns:
        ``relative_path`` joined onto ``sys._MEIPASS`` when that attribute
        exists, otherwise joined onto the absolute form of ``base_dir``.
    """
    # A PyInstaller bundle unpacks its files into a temporary folder and
    # publishes that folder as sys._MEIPASS.
    if hasattr(sys, '_MEIPASS'):
        root = sys._MEIPASS
    else:
        root = os.path.abspath(base_dir)
    return os.path.join(root, relative_path)
class LiveChatConfig:
    """Accessor for the live-chat settings and message rows kept in SQLite.

    The class reads the ``config``, ``live_config``, ``system_message``,
    ``message`` and ``prohibited_word`` tables through one shared connection
    stored in ``self.conn``. Every value is read from the database on each
    access; nothing is cached.

    Lookups of a single ``config`` or ``live_config`` key return ``None``
    when the key has no row.
    """

    # Status labels accepted by update_chat_enable_status and the integer code
    # written to live_config for each one. The labels are runtime strings
    # (not started / starting / started / start failed) and must stay as-is.
    _CHAT_ENABLE_STATUS_CODES = {'未启动': 0, '启动中': 1, '已启动': 2, '启动失败': 3}

    def __init__(self):
        """Open the database and switch it to write-ahead logging."""
        # The path is built as '..' + sqlite_file, so (assuming sqlite_file is
        # itself relative) it resolves against the current working directory.
        # The PyInstaller-aware resource_path() helper is deliberately not
        # used here.
        self.conn = sqlite3.connect(os.path.join('..', sqlite_file))
        self.conn.execute("PRAGMA journal_mode=WAL;")

    # ------------------------------------------------------------------
    # Low-level helpers: each one guarantees the cursor is closed even when
    # the statement raises.
    # ------------------------------------------------------------------
    def _fetch_one(self, sql, params=()):
        """Run ``sql`` and return its first row, or ``None`` if there is none."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql, params)
            return cursor.fetchone()
        finally:
            cursor.close()

    def _fetch_all(self, sql, params=()):
        """Run ``sql`` and return all of its rows as a list."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql, params)
            return cursor.fetchall()
        finally:
            cursor.close()

    def _execute_write(self, sql, params=()):
        """Run a data-modifying ``sql`` statement and commit it."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql, params)
            self.conn.commit()
        finally:
            cursor.close()

    def _query_config(self, key):
        """Return ``config.value`` for ``key``, or ``None`` if the key is absent."""
        row = self._fetch_one('select value from config where key=?', (key,))
        return row[0] if row is not None else None

    def _query_config_int(self, key):
        """Return ``config.value`` for ``key`` as an int, or ``None`` if absent."""
        value = self._query_config(key)
        return int(value) if value is not None else None

    def _query_live_config(self, key):
        """Return ``live_config.value`` for ``key``, or ``None`` if the key is absent."""
        row = self._fetch_one('select value from live_config where key = ?', (key,))
        return row[0] if row is not None else None

    # ------------------------------------------------------------------
    # Reply probabilities (config table, converted to int)
    # ------------------------------------------------------------------
    @property
    def enter_live_room_prob(self):
        """Integer value of ``reply_prob_enter_live_room``, or ``None``."""
        return self._query_config_int('reply_prob_enter_live_room')

    @property
    def follow_prob(self):
        """Integer value of ``reply_prob_follow``, or ``None``."""
        return self._query_config_int('reply_prob_follow')

    @property
    def like_prob(self):
        """Integer value of ``reply_prob_like``, or ``None``."""
        return self._query_config_int('reply_prob_like')

    @property
    def gift_prob(self):
        """Integer value of ``reply_prob_gift``, or ``None``."""
        return self._query_config_int('reply_prob_gift')

    @property
    def chat_prob(self):
        """Integer value of ``reply_prob_chat``, or ``None``."""
        return self._query_config_int('reply_prob_chat')

    # ------------------------------------------------------------------
    # Prompts and other raw values (config table)
    # ------------------------------------------------------------------
    @property
    def enter_live_room_prompt(self):
        """Value of ``enter_live_room_prompt``, or ``None``."""
        return self._query_config('enter_live_room_prompt')

    @property
    def follow_prompt(self):
        """Value of ``follow_prompt``, or ``None``."""
        return self._query_config('follow_prompt')

    @property
    def like_prompt(self):
        """Value of ``like_prompt``, or ``None``."""
        return self._query_config('like_prompt')

    @property
    def gift_prompt(self):
        """Value of ``gift_prompt``, or ``None``."""
        return self._query_config('gift_prompt')

    @property
    def chat_prompt(self):
        """Value of ``chat_prompt``, or ``None``."""
        return self._query_config('chat_prompt')

    @property
    def product_related_prompt(self):
        """Value of ``product_related_prompt``, or ``None``."""
        return self._query_config('product_related_prompt')

    @property
    def backend_token(self):
        """Value of ``backend_token``, or ``None``."""
        return self._query_config('backend_token')

    @property
    def refine_system_message_prompt(self):
        """Value of ``refine_system_message_prompt``, or ``None``."""
        return self._query_config('refine_system_message_prompt')

    # ------------------------------------------------------------------
    # Live session values (live_config table)
    # ------------------------------------------------------------------
    @property
    def live_id(self):
        """Value of the ``live_id`` key, or ``None`` if it has no row."""
        return self._query_live_config('live_id')

    @property
    def livetalking_address(self):
        """Value of the ``livetalking_address`` key, or ``None`` if it has no row."""
        return self._query_live_config('livetalking_address')

    @property
    def ollama_address(self):
        """Value of the ``ollama_address`` key, or ``None`` if it has no row."""
        return self._query_live_config('ollama_address')

    @property
    def product_info(self):
        """Return ``(product_name, product_specification, product_description)``.

        Each element is the stored value, or ``None`` when its key has no row.
        """
        return (
            self._query_live_config('product_name'),
            self._query_live_config('product_specification'),
            self._query_live_config('product_description'),
        )

    @property
    def livetalking_sessionid(self):
        """Integer value of ``livetalking_sessionid``, or ``None`` if unset."""
        value = self._query_live_config('livetalking_sessionid')
        return int(value) if value is not None else None

    # ------------------------------------------------------------------
    # System messages and prohibited words
    # ------------------------------------------------------------------
    @property
    def system_messages(self) -> str:
        """Return all system messages as a JSON object string keyed by type.

        Rows are read in ``id`` order; when several rows share a type, the
        row with the highest id wins.
        """
        rows = self._fetch_all('select message, type from system_message order by id')
        by_type = {_type: message for message, _type in rows}
        return json.dumps(by_type, ensure_ascii=False, indent=4)

    @property
    def prohibited_words(self) -> dict:
        """Return a ``{word: substitutes}`` dict built from ``prohibited_word``."""
        rows = self._fetch_all('select word, substitutes from prohibited_word')
        return {word: substitutes for word, substitutes in rows}

    # ------------------------------------------------------------------
    # Message rows
    # ------------------------------------------------------------------
    def messages(self, batch_number):
        """Return how many ``message`` rows have the given ``batch_number``."""
        row = self._fetch_one(
            'select count(0) from message where batch_number = ?', (batch_number,)
        )
        return row[0]

    @property
    def precedence_reply_message(self):
        """Return ``(message, id)`` of the oldest batch-0 row with status 2.

        Returns ``None`` when no such row exists.
        """
        return self._fetch_one(
            'select message, id from message where status = 2 and batch_number = 0 order by id limit 1'
        )

    @property
    def next_reply_message(self):
        """Return ``(message, id)`` of the oldest batch-0 row with status 0.

        Returns ``None`` when no such row exists.
        """
        return self._fetch_one(
            'select message, id from message where status = 0 and batch_number = 0 order by id limit 1'
        )

    def update_chat_enable_status(self, status):
        """Store the integer code of ``status`` under ``chat_enable_status``.

        Args:
            status: One of the labels in ``_CHAT_ENABLE_STATUS_CODES``.

        Raises:
            KeyError: If ``status`` is not a known label.
        """
        code = self._CHAT_ENABLE_STATUS_CODES[status]
        self._execute_write(
            "update live_config set value = ? where key = 'chat_enable_status'", (code,)
        )

    def flush_precedence_reply_message(self):
        """Re-status every message relative to the oldest status-2 row.

        Rows whose id is greater than that row's id get status 0; rows whose
        id is less than or equal to it (including the row itself) get status 1.
        When no status-2 row exists, ``MIN(id)`` is NULL, neither comparison
        matches, and every row keeps its status. The statement is not
        restricted to a batch number.
        """
        self._execute_write(
            "UPDATE message SET status = CASE "
            "WHEN id > (SELECT MIN(id) FROM message WHERE status = 2) THEN 0 "
            "WHEN id <= (SELECT MIN(id) FROM message WHERE status = 2) THEN 1 "
            "ELSE status END;"
        )

    def update_next_reply_status(self, status, _id):
        """Set ``status`` on the message row whose id is ``_id``."""
        self._execute_write("update message set status = ? where id = ?", (status, _id))

    def insert_message(self, message, _type, batch_number):
        """Insert one row into the ``message`` table.

        Only ``message``, ``type`` and ``batch_number`` are supplied; any
        other column takes whatever default the table defines.
        """
        self._execute_write(
            "insert into message (message, type, batch_number) values (?, ?, ?)",
            (message, _type, batch_number),
        )

    def flush_message(self):
        """Delete every batch-0 message row whose status is 1."""
        self._execute_write('delete from message where batch_number = 0 and status = 1')

    def system_messages_random_mix_dict(self, product_id=None, ensure_mixed=True):
        """Randomly stitch one system message per order level into a JSON dict.

        For every distinct ``script_type_order_num`` (ascending) one row is
        chosen at random and stored under its ``type``; a later level that
        picks the same type overwrites the earlier entry.

        Args:
            product_id: When not ``None``, only rows with this ``product_id``
                are considered.
            ensure_mixed: When true, rows whose ``script_group_id`` has not
                been picked at an earlier level are preferred; if every
                candidate's group was already used, all candidates are
                eligible again.

        Returns:
            A JSON object string mapping type to message (``"{}"`` when no
            row matches).
        """
        sql = (
            "select script_type_order_num, script_group_id, type, message "
            "from system_message"
        )
        params = ()
        if product_id is not None:
            sql += " where product_id = ?"
            params = (product_id,)
        sql += " order by script_type_order_num asc, script_group_id asc, id asc"

        # Bucket the rows by order level, keeping the query's ordering inside
        # each bucket so random.choice sees the same candidate sequence.
        by_order = {}
        for order_num, group_id, _type, msg in self._fetch_all(sql, params):
            by_order.setdefault(order_num, []).append((group_id, _type, msg))

        picked = {}
        used_groups = set()
        for order_num in sorted(by_order):
            candidates = by_order[order_num]
            pool = candidates
            if ensure_mixed:
                unused = [c for c in candidates if c[0] not in used_groups]
                if unused:
                    pool = unused
            group_id, _type, msg = random.choice(pool)
            picked[_type] = msg
            used_groups.add(group_id)
        return json.dumps(picked, ensure_ascii=False, indent=4)
class PromptQueue:
    """Wrapper around ``multiprocessing.Queue`` that drops its oldest item when full.

    Attributes:
        queue: The underlying ``multiprocessing.Queue``.
        maxsize: The capacity passed to the underlying queue.
    """

    def __init__(self, maxsize=0):
        """Create the underlying queue with the given ``maxsize``."""
        self.queue = Queue(maxsize)
        self.maxsize = maxsize

    def put(self, item):
        """Enqueue ``item``, discarding the oldest queued entries to make room.

        The insert is attempted without blocking. While the queue reports
        itself full, one entry is removed and the insert is retried. A removal
        can find nothing to take (the data has not reached the pipe yet, or a
        consumer took it first); that case is simply retried, so this method
        never blocks indefinitely on a full queue.
        """
        # Local import: the module header only brings in Empty.
        from queue import Full

        while True:
            try:
                self.queue.put_nowait(item)
                return
            except Full:
                try:
                    # Drop the oldest entry; the short timeout avoids a hot
                    # spin while previously queued data is still in flight.
                    self.queue.get(True, 0.01)
                except Empty:
                    pass

    def get(self, block=True, timeout=None):
        """Return the next item, or ``None`` if none became available.

        Args:
            block: Whether to wait for an item.
            timeout: Maximum number of seconds to wait when blocking;
                ``None`` waits without limit.

        Returns:
            The dequeued item, or ``None`` when the underlying queue raised
            ``Empty``. A queued ``None`` is indistinguishable from that case.
        """
        try:
            return self.queue.get(block, timeout)
        except Empty:
            return None

    def qsize(self):
        """Return the underlying queue's ``qsize()``."""
        return self.queue.qsize()

    def empty(self):
        """Return the underlying queue's ``empty()``."""
        return self.queue.empty()

    def full(self):
        """Return the underlying queue's ``full()``."""
        return self.queue.full()