You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

187 lines
5.7 KiB
Python

from fastapi import FastAPI, HTTPException, BackgroundTasks
from qa_Ask import QAService, match_query, store_data
from pydantic import BaseModel
from collections import deque
import requests
import os
import time
import uuid
import json
import shutil
import yaml
import logging
import sys

app = FastAPI()

# Log to both a file and the terminal.
# ``logging.FileHandler`` does not create missing directories, so make sure
# ``log/`` exists first; otherwise importing this module fails with
# FileNotFoundError on a fresh checkout.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit utf-8: logged requests/results may contain non-ASCII text,
        # and the platform default encoding is not guaranteed to handle it.
        logging.FileHandler('log/app.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),  # console handler
    ],
)
logger = logging.getLogger(__name__)
class QuestionRequest(BaseModel):
    """Request body for ``POST /matchQuestion``."""

    question: str  # text to match against the newest knowledge base
    scoreThreshold: float  # passed to ``match_query`` as its score threshold
class QuestionResponse(BaseModel):
    """Response body of ``POST /matchQuestion``."""

    code: int  # in-band status code: 200 on success, 500 on failure
    msg: str  # short status message
    data: list  # result of ``match_query``; element format is defined in qa_Ask (not visible here)
class QuestionItem(BaseModel):
    """One entry of the ``POST /updateDatabase`` payload.

    The list of items is serialized to JSON and fed to ``store_data`` to build
    a new knowledge base.
    """

    questionId: str  # identifier of the entry
    # presumably the phrasings that should match questionId — confirm against qa_Ask.store_data
    questionList: list[str]
class InputText(BaseModel):
    """Request body for ``POST /extractInformation/``."""

    inputText: str  # free text to correct and extract name / card number / ID number from
class ExtractedInfo(BaseModel):
    """Response body of ``POST /extractInformation/``; fields not found are ``''``."""

    name: str  # value extracted for the schema key 姓名 (name)
    cardNumber: str  # value extracted for the schema key 卡号 (card number)
    # NOTE: "idNnumber" (double n) is the field name clients receive in the JSON
    # response; renaming it would change the API, so the spelling is kept.
    idNnumber: str  # value extracted for the schema key 身份证号 (ID-card number)
# Read the service configuration.
# Explicit utf-8 so the YAML is decoded the same way regardless of the
# platform's locale encoding.
with open('config/config.yaml', 'r', encoding='utf-8') as config_file:
    config_data = yaml.safe_load(config_file)
# Text file listing knowledge-base names, one per line (see load_knowledge_bases).
knowledge_base_file = config_data['knowledge_base_file']
# NOTE(review): api_url and path are not read by any code visible in this
# chunk (save_to_json builds its own JSON path) — confirm they are still needed.
api_url = config_data['api']['url']
path = config_data['output_file_path']
# Maximum number of knowledge bases kept on disk; older ones are evicted by update_kb.
max_knowledge_bases = config_data['max_knowledge_bases']
def load_knowledge_bases():
    """Load the list of knowledge-base names from ``knowledge_base_file``.

    The file holds one name per line, oldest first: eviction takes names from
    the front and the newest name is appended at the end. Blank and
    whitespace-only lines are skipped. An empty name is never a real
    knowledge base, and evicting it would turn the folder path
    ``knowledge_base/<name>`` into the ``knowledge_base/`` root itself.

    Returns:
        list[str]: the stored names, or ``[]`` when the file does not exist.
    """
    if not os.path.exists(knowledge_base_file):
        return []
    with open(knowledge_base_file, "r", encoding="utf-8") as file:
        return [name for line in file if (name := line.strip())]
def save_knowledge_bases(names):
    """Persist *names* to ``knowledge_base_file``, one name per line.

    The text is written to a sibling ``.tmp`` file that then replaces the
    target via ``os.replace``. A crash mid-write therefore leaves the
    previous list intact instead of a truncated or empty file, which would
    make ``load_knowledge_bases`` forget existing knowledge bases on the
    next start.

    Args:
        names: iterable of knowledge-base name strings, oldest first.
    """
    tmp_path = f"{knowledge_base_file}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        file.write("\n".join(names))
    os.replace(tmp_path, knowledge_base_file)
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Build a knowledge base from *path* and register it as the newest one.

    Scheduled as a background task by ``/updateDatabase``. Steps:

    1. ``store_data(qa_service, path)`` builds the new knowledge base.
    2. The oldest registered knowledge bases are evicted, and their
       ``knowledge_base/<name>`` folders deleted, so that at most
       *max_knowledge_bases* remain once the new one is added.
    3. *kb_name* is appended to ``recent_knowledge_bases``. The list is saved
       with ``save_knowledge_bases``, and the input file *path* is removed.

    Args:
        kb_name: name of the new knowledge base (also its folder name under
            ``knowledge_base/``).
        qa_service: service object bound to *kb_name*, passed to ``store_data``.
        path: JSON file holding the question items; deleted on success.
        max_knowledge_bases: maximum number of knowledge bases to keep.

    NOTE(review): the module-level deque is mutated without a lock. Sync
    background tasks run in a thread pool, so overlapping updates may
    interleave — confirm whether updates can run concurrently.
    """
    store_data(qa_service, path)
    # Loop with ``>=`` rather than a single ``==`` check so a history longer
    # than the limit is trimmed too. The emptiness check stops the loop when
    # the limit is 0 or negative.
    while recent_knowledge_bases and len(recent_knowledge_bases) >= max_knowledge_bases:
        folder_to_delete = recent_knowledge_bases.popleft()
        if not folder_to_delete:
            # An empty name would resolve to the knowledge_base/ root itself.
            continue
        try:
            shutil.rmtree(f"knowledge_base/{folder_to_delete}")
        except FileNotFoundError:
            # Already gone: carry on so the new knowledge base is still
            # registered and the deque and the saved file stay consistent.
            logger.warning(f"Knowledge base folder not found, skipping delete: {folder_to_delete}")
    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)
    os.remove(path)
    logger.info(f"Knowledge base updated: {kb_name}\n"
                f"Please wait while the database is being updated···")
# Names of the knowledge bases kept on disk, oldest first. The last element is
# the one /matchQuestion queries. Loaded from knowledge_base_file at import time.
# ``maxlen`` keeps only the newest ``max_knowledge_bases`` names from the file.
# NOTE(review): names dropped here by maxlen never go through update_kb's
# eviction, so their knowledge_base/<name> folders are not deleted.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
# Chinese numeral characters and the ASCII digit each one stands for.
_CHINESE_DIGIT_TABLE = str.maketrans({
    "零": "0", "〇": "0", "一": "1", "二": "2", "三": "3",
    "四": "4", "五": "5", "六": "6", "七": "7", "八": "8", "九": "9",
})


def text_to_number(text_id):
    """Replace Chinese numeral characters in *text_id* with ASCII digits.

    Each of 零一二三四五六七八九 (and the numeral 〇) is mapped to the digit
    it names; every other character is left unchanged. The mapping is purely
    per-character:

    * positional words such as 十 or 百 are not interpreted;
    * numeral characters inside ordinary words or names are converted as well.

    Args:
        text_id: text that may contain spelled-out digits (e.g. card or ID numbers).

    Returns:
        str: *text_id* with every mapped numeral replaced by its digit.
    """
    # One C-level pass; equivalent to chained replace() because every key is a
    # single character and no replacement digit is itself a key.
    return text_id.translate(_CHINESE_DIGIT_TABLE)
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Save the posted question items and schedule a knowledge-base rebuild.

    The items are written as JSON to a file named after a freshly generated
    knowledge-base id. ``update_kb`` is then queued as a background task to
    build the knowledge base from that file; ``update_kb`` removes the file
    afterwards. The response is sent before the rebuild finishes.

    Returns:
        dict: ``{"status": "success", ...}`` once the task is scheduled. If
        writing the file or creating the service fails, the result is
        ``{"status": "error", ...}`` (still HTTP 200).
    """
    kb_name = str(uuid.uuid4())
    # One JSON file per request. With a single shared "output.json", a second
    # request could overwrite the first request's data before its background
    # task read it. The first task's cleanup would then delete the file the
    # second task still needed.
    json_path = f"output_{kb_name}.json"
    try:
        # NOTE(review): ``.dict()`` is the pydantic v1 API (deprecated alias of
        # ``model_dump()`` in v2) — confirm the installed pydantic version.
        json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
        with open(json_path, "w", encoding="utf-8") as file:
            file.write(json_data)
        device = None
        qa_service = QAService(kb_name, device)
        background_tasks.add_task(
            update_kb, kb_name, qa_service, json_path, max_knowledge_bases
        )
        return {"status": "success", "message": "Please wait while the database is being updated···"}
    except Exception as e:
        logger.exception(f"Error saving data to file or scheduling knowledge base update task: {e}")
        # No task was scheduled, so nothing else will delete the file
        # (best-effort: it may never have been created).
        try:
            os.remove(json_path)
        except OSError:
            pass
        return {"status": "error", "message": "update task error···"}
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match ``request.question`` against the newest knowledge base.

    The last entry of ``recent_knowledge_bases`` is queried through
    ``match_query`` with ``top_k=3`` and ``request.scoreThreshold``. Errors
    are reported in-band: the HTTP status stays 200 and the body carries
    ``code=500`` with an empty ``data`` list.

    Returns:
        QuestionResponse: ``code=200, msg="success"`` with the match result.
        On failure, ``code=500`` with ``msg="no knowledge base available"``
        or ``msg="error"``.
    """
    try:
        logger.info(f"match_question:Request: {request}")
        start_time = time.perf_counter()  # monotonic clock for durations
        query = request.question
        try:
            newest = recent_knowledge_bases[-1]
        except IndexError:
            # Nothing has been built via /updateDatabase yet.
            logger.error("Error matching question: no knowledge base available")
            return QuestionResponse(code=500, msg="no knowledge base available", data=[])
        top_k = 3
        device = None
        # NOTE(review): a QAService is constructed for every request; if its
        # constructor loads models or indexes, caching per knowledge-base name
        # would avoid that cost — confirm against qa_Ask.
        qa_service = QAService(newest, device)
        result = match_query(qa_service, query, top_k, request.scoreThreshold)
        response = QuestionResponse(code=200, msg="success", data=result)
        duration = time.perf_counter() - start_time
        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response
    except Exception as e:
        logger.exception(f"Error matching question: {e}")
        # The message must not claim success when code signals failure.
        return QuestionResponse(code=500, msg="error", data=[])
# PaddleNLP pipelines, built on first use and then reused. Constructing a
# Taskflow loads model weights, which previously happened on every request.
# The endpoint below is ``async def`` and runs on the event-loop thread, so
# this cache is only touched from that one thread.
_taskflow_cache = {}


def _get_taskflow(task, **kwargs):
    """Return a cached ``paddlenlp.Taskflow(task, **kwargs)``, creating it on first use."""
    key = (task, repr(sorted(kwargs.items())))
    flow = _taskflow_cache.get(key)
    if flow is None:
        from paddlenlp import Taskflow  # heavy import, deferred until first needed
        flow = Taskflow(task, **kwargs)
        _taskflow_cache[key] = flow
    return flow


@app.post("/extractInformation/")
async def extract_information(input_data: InputText):
    """Extract name, card number and ID number from free text.

    Pipeline:

    1. PaddleNLP text correction.
    2. Chinese numerals are turned into digits by ``text_to_number``.
    3. UIE information extraction with the schema 姓名 / 卡号 / 身份证号
       (name / card number / ID-card number).

    For each key, the text of the first candidate is used; keys that are not
    found are returned as empty strings.

    Raises:
        HTTPException: status 500 if any step fails (the cause is logged).
    """
    try:
        corrector = _get_taskflow("text_correction")
        data = corrector(input_data.inputText)
        target_value = data[0]['target']
        converted_id = text_to_number(target_value)
        schema = ["姓名", "卡号", "身份证号"]
        ie = _get_taskflow('information_extraction', schema=schema, model='uie-base')
        extracted_info = ie(converted_id)
        result = {}
        for item in extracted_info:
            for key, value in item.items():
                if value:  # an empty candidate list has no [0]
                    result[key.lower()] = value[0]['text']
        return ExtractedInfo(
            name=result.get('姓名', ''),
            cardNumber=result.get('卡号', ''),
            idNnumber=result.get('身份证号', ''),
        )
    except Exception as e:
        logger.exception(f"Error extracting information: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    import uvicorn as _uvicorn

    _uvicorn.run(app=app, port=8000, host="0.0.0.0")