You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

170 lines
5.0 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastapi import FastAPI, HTTPException, BackgroundTasks
from qa_Ask import QAService, match_query, store_data
from pydantic import BaseModel
from collections import deque
import requests
import os
import time
import uuid
import json
import shutil
import yaml
import logging
# ASGI application object; the route decorators below register endpoints on it.
app = FastAPI()
import sys
# Configure logging to both a file and the terminal.
# FileHandler does not create missing parent directories, so make sure the
# log directory exists before log/app.log is opened; otherwise the module
# fails to import on a fresh checkout.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('log/app.log'),
        logging.StreamHandler(sys.stdout),  # also echo records to the console
    ],
)
logger = logging.getLogger(__name__)
# Request body accepted by the /matchQuestion endpoint.
class QuestionRequest(BaseModel):
    question: str  # query text that is handed to match_query
# Response envelope returned by the /matchQuestion endpoint.
class QuestionResponse(BaseModel):
    code: int  # status code; the endpoint sets 200 on success
    msg: str  # status text; the endpoint sets "success"
    data: list  # whatever match_query returned for the question
# One element of the list posted to the /updateDatabase endpoint.
# NOTE(review): field meanings are not visible in this file; presumably
# questionCode identifies a question and questionList holds its phrasings —
# confirm against qa_Ask.store_data.
class QuestionItem(BaseModel):
    questionCode: str
    questionList: list[str]
# Load the service settings once, at import time.
# The encoding is pinned to UTF-8 so the YAML file is decoded the same way on
# every platform instead of depending on the locale's default encoding.
with open('config/config.yaml', 'r', encoding='utf-8') as config_file:
    config_data = yaml.safe_load(config_file)

knowledge_base_file = config_data['knowledge_base_file']  # registry file read/written by load/save_knowledge_bases
api_url = config_data['api']['url']  # endpoint queried by fetch_and_write_data at startup
path = config_data['output_file_path']  # file the startup fetch writes the API payload to
max_knowledge_bases = config_data['max_knowledge_bases']  # bound on the recent_knowledge_bases deque
def load_knowledge_bases():
    """Load the list of knowledge base names from the registry file.

    The file named by ``knowledge_base_file`` holds one name per line.

    Returns:
        list[str]: The names in file order, with surrounding whitespace
        removed and blank lines skipped. An empty list is returned when the
        file does not exist.
    """
    try:
        with open(knowledge_base_file, "r", encoding="utf-8") as file:
            lines = file.read().splitlines()
    except FileNotFoundError:
        # No registry yet (first start): there are no knowledge bases.
        return []
    # Blank lines must not become names: update_kb deletes
    # "knowledge_base/<name>", so an empty name would target the root folder.
    return [line.strip() for line in lines if line.strip()]
def save_knowledge_bases(names):
    """Write the knowledge base names to the registry file, one per line.

    Args:
        names: Iterable of name strings; the file is overwritten with them,
            joined by newlines (no trailing newline).
    """
    with open(knowledge_base_file, mode="w") as handle:
        contents = "\n".join(names)
        handle.write(contents)
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Build a new knowledge base and register it as the most recent one.

    Steps: store the data from ``path`` through ``qa_service``, evict the
    oldest registered knowledge base(s) when the registry is full, append
    ``kb_name``, persist the registry and delete the input file.

    Args:
        kb_name: Name under which the new knowledge base is registered.
        qa_service: Service object handed to ``store_data``.
        path: File holding the data to store; removed once the update is done.
        max_knowledge_bases: Maximum number of knowledge bases to keep.

    Returns:
        dict: ``{"status": "success", "message": ...}``.
    """
    store_data(qa_service, path)

    # Make room for the new entry. The emptiness and None checks keep a
    # zero or unset limit from raising on popleft or on the comparison.
    while (
        recent_knowledge_bases
        and max_knowledge_bases is not None
        and len(recent_knowledge_bases) >= max_knowledge_bases
    ):
        folder_to_delete = recent_knowledge_bases.popleft()
        if not folder_to_delete.strip():
            # An empty name would resolve to "knowledge_base/" itself and
            # wipe every knowledge base, so never pass it to rmtree.
            logger.warning("Skipping removal of knowledge base with empty name")
            continue
        try:
            shutil.rmtree(f"knowledge_base/{folder_to_delete}")
        except OSError as e:
            # Cleanup of an old folder is best-effort: the name has already
            # been popped, so aborting here would lose the new entry as well.
            logger.warning(f"Could not remove knowledge base folder {folder_to_delete}: {e}")

    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)
    os.remove(path)
    logger.info(f"Knowledge base updated: {kb_name}")
    return {"status": "success", "message": "数据库正在更新中"}
def fetch_and_write_data(api_url, path, timeout=30):
    """Fetch question data from the API and write it to a JSON file.

    Args:
        api_url: URL queried with an HTTP GET.
        path: Destination file; receives the response's ``data`` member as
            UTF-8 JSON.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely, which would hang
            the module-level startup call.

    Returns:
        bool: True when the data was fetched and written, False on any
        failure (the failure is logged, never raised).
    """
    try:
        response = requests.get(api_url, timeout=timeout)
        # Check the HTTP status before parsing: an error page is often not
        # JSON, and parsing it first would hide the status code from the log.
        if response.status_code != 200:
            logger.error(f"Failed to fetch data from API. Status code: {response.status_code}, Response data: {response.text}")
            return False

        response_data = response.json()
        if response_data["code"] != 200:
            logger.error(f"Failed to fetch data from API. Status code: {response.status_code}, Response data: {response_data}")
            return False

        question_items = response_data["data"]
        with open(path, "w", encoding="utf-8") as file:
            json.dump(question_items, file, ensure_ascii=False, indent=2)
        logger.info(f"Data fetched and written to file: {path}")
        return True
    except Exception as e:
        # Broad on purpose: this runs at startup and callers only look at the
        # boolean result, so network, decoding and file errors all map to False.
        logger.error(f"Error fetching data from API: {e}")
        return False
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Save the posted question data as JSON and schedule a background update.

    Args:
        question_items: Question entries from the request body.
        background_tasks: FastAPI task queue; ``update_kb`` is added to it.

    Returns:
        dict: ``{"status": "success", "message": ...}`` once the task is
        scheduled.

    Raises:
        HTTPException: 500 when writing the file or creating the QA service
            fails.
    """
    kb_name = str(uuid.uuid4())
    # One file per request: update_kb reads and then deletes this file in the
    # background, so a shared fixed name would let concurrent requests
    # overwrite or remove each other's data.
    path = f"output_{kb_name}.json"
    try:
        json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
        with open(path, "w", encoding="utf-8") as file:
            file.write(json_data)
        device = None
        qa_service = QAService(kb_name, device)
    except Exception as e:
        logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
        # No task will consume the file, so do not leave it behind.
        try:
            os.remove(path)
        except OSError:
            pass  # the file was never created or is already gone
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") from e

    background_tasks.add_task(
        update_kb, kb_name, qa_service, path, max_knowledge_bases
    )
    logger.info(f"Data saved to file: {path}, Knowledge base update task scheduled: {kb_name}")
    return {"status": "success", "message": "数据库正在更新中"}
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match a question against the most recently registered knowledge base.

    Args:
        request: Request body carrying the question text.

    Returns:
        QuestionResponse: ``code=200``, ``msg="success"`` and the result of
        ``match_query`` as ``data``.

    Raises:
        HTTPException: 500 when no knowledge base is registered or when
            matching fails.
    """
    # Fail with a clear message instead of an opaque
    # "deque index out of range" when nothing has been loaded yet.
    if not recent_knowledge_bases:
        logger.error("Error matching question: no knowledge base is available")
        raise HTTPException(status_code=500, detail="no knowledge base is available")

    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        start_time = time.perf_counter()
        query = request.question
        newest = recent_knowledge_bases[-1]  # entries are appended, so the last one is the newest
        top_k = 3
        score_threshold = 0.1
        device = None
        qa_service = QAService(newest, device)
        result = match_query(qa_service, query, top_k, score_threshold)
        response = QuestionResponse(code=200, msg="success", data=result)
        stop_time = time.perf_counter()
        logger.info(f"Matched question in {stop_time - start_time} seconds")
        return response
    except Exception as e:
        logger.error(f"Error matching question: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Registry of knowledge base names, oldest first (update_kb evicts from the
# left and appends on the right); bounded to max_knowledge_bases entries.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
# At import time, fetch the current data from the API and, if that succeeds,
# build a fresh knowledge base from it under a new random name.
if fetch_and_write_data(api_url, path):
    kb_name = str(uuid.uuid4())
    device = None  # forwarded to QAService; what None selects is not visible in this file
    qa_service = QAService(kb_name, device)
    update_kb(kb_name, qa_service, path, max_knowledge_bases)
# When executed as a script, serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)