|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
import yaml
|
|
|
import sys
|
|
|
import os
|
|
|
import time
|
|
|
import uuid
|
|
|
import json
|
|
|
import shutil
|
|
|
import logging
|
|
|
from collections import deque
|
|
|
from pydantic import BaseModel
|
|
|
from fastapi import BackgroundTasks
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from qa_Ask import QAService, match_query, store_data
|
|
|
|
|
|
app = FastAPI()

# Send log records to both a file and the console.
# logging.FileHandler does not create missing parent directories, so the
# log directory is created up front instead of failing at import time.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so non-ASCII request/response text is written
        # correctly regardless of the platform's default encoding.
        logging.FileHandler('log/app.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),  # console handler
    ],
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
# Request body of the /matchQuestion endpoint.
#   question       -- the query text to match.
#   scoreThreshold -- forwarded unchanged to match_query as its last argument.
class QuestionRequest(BaseModel):
    question: str
    scoreThreshold: float
|
|
|
|
|
|
|
|
|
# Response envelope returned by the /matchQuestion endpoint.
#   code -- numeric status (the endpoint uses 200 and 500).
#   msg  -- short status text.
#   data -- the match_query result, or an empty list on failure.
class QuestionResponse(BaseModel):
    code: int
    msg: str
    data: list
|
|
|
|
|
|
|
|
|
# One element of the list posted to the /updateDatabase endpoint.
#   questionId   -- identifier of the entry.
#   questionList -- the question strings belonging to that identifier.
class QuestionItem(BaseModel):
    questionId: str
    questionList: list[str]
|
|
|
|
|
|
|
|
|
# Request body of the /extractInformation endpoint.
#   inputText -- the raw text to correct and extract information from.
class InputText(BaseModel):
    inputText: str
|
|
|
|
|
|
|
|
|
# Result returned by the /extractInformation endpoint; a field is an empty
# string when nothing was extracted for it.
#   name       -- extracted person name.
#   cardNumber -- extracted card number.
#   idNumber   -- extracted identity document number.
class ExtractedInfo(BaseModel):
    name: str
    cardNumber: str
    idNumber: str
|
|
|
|
|
|
|
|
|
# Load the service configuration once at import time.
# The file is opened in binary mode so that PyYAML determines the encoding
# itself instead of depending on the platform's default text encoding.
with open('config/config.yaml', 'rb') as config_file:
    config_data = yaml.safe_load(config_file)

# File holding the knowledge base names, one per line.
knowledge_base_file = config_data['knowledge_base_file']
# NOTE(review): api_url is not referenced in the visible part of this file.
api_url = config_data['api']['url']
# NOTE(review): the /updateDatabase handler uses its own local ``path``;
# confirm whether this configured value is still needed.
path = config_data['output_file_path']
# Maximum number of knowledge bases kept before the oldest is removed.
max_knowledge_bases = config_data['max_knowledge_bases']
|
|
|
|
|
|
|
|
|
def load_knowledge_bases(file_path=None):
    """Load the list of knowledge base names.

    Args:
        file_path: File to read, one name per line. Defaults to the
            module-level ``knowledge_base_file``.

    Returns:
        The names in file order with surrounding whitespace stripped and
        blank lines removed, or an empty list if the file does not exist.
    """
    if file_path is None:
        file_path = knowledge_base_file
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            lines = file.read().splitlines()
    except FileNotFoundError:
        return []
    # Blank entries are dropped on purpose: an empty name would make the
    # eviction path "knowledge_base/<name>" point at the whole directory.
    stripped = (line.strip() for line in lines)
    return [entry for entry in stripped if entry]
|
|
|
|
|
|
|
|
|
def save_knowledge_bases(names, file_path=None):
    """Save the list of knowledge base names to a file.

    Args:
        names: Iterable of name strings; written one per line, in order.
        file_path: Destination file. Defaults to the module-level
            ``knowledge_base_file``.
    """
    if file_path is None:
        file_path = knowledge_base_file
    # Explicit UTF-8 so the file reads back identically on any platform.
    with open(file_path, "w", encoding="utf-8") as file:
        file.write("\n".join(names))
|
|
|
|
|
|
|
|
|
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Update the knowledge base and rotate out the oldest one if needed.

    Args:
        kb_name: Name of the knowledge base being added.
        qa_service: Service object handed to ``store_data``.
        path: Input file handed to ``store_data``; removed afterwards.
        max_knowledge_bases: Number of knowledge bases to keep.
    """
    # presumably builds the knowledge base from the file at ``path`` --
    # confirm against store_data's implementation.
    store_data(qa_service, path)

    # Evict the oldest knowledge base once the rotation is full. The
    # emptiness check avoids popping from an empty deque when the
    # configured maximum is 0.
    if recent_knowledge_bases and len(recent_knowledge_bases) >= max_knowledge_bases:
        folder_to_delete = recent_knowledge_bases.popleft()
        if folder_to_delete:
            try:
                shutil.rmtree(f"knowledge_base/{folder_to_delete}")
            except FileNotFoundError:
                # A folder that is already gone must not stop the new
                # knowledge base from being registered below.
                logger.warning(
                    "Knowledge base folder already missing: %s", folder_to_delete
                )
        else:
            # An empty name would resolve to the knowledge_base root itself.
            logger.warning("Skipped deleting a knowledge base with an empty name")

    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)

    os.remove(path)
    logger.info(f"Knowledge base updated: {kb_name}\n"
                f"Please wait while the database is being updated···")
|
|
|
|
|
|
|
|
|
# Rotation of known knowledge base names, oldest first, seeded from the
# persisted list and bounded by the configured maximum.
recent_knowledge_bases = deque(
    load_knowledge_bases(),
    maxlen=max_knowledge_bases,
)
|
|
|
|
|
|
|
|
|
# Maps each Chinese digit character to its ASCII digit. Built once at import
# instead of on every call.
_CHINESE_DIGIT_TABLE = str.maketrans({
    '零': '0', '〇': '0', '一': '1', '二': '2', '三': '3', '四': '4',
    '五': '5', '六': '6', '七': '7', '八': '8', '九': '9',
})


def text_to_number(text_id):
    """Replace Chinese digit characters in a string with ASCII digits.

    The replacement is character by character; positional numerals such as
    ten or hundred are not interpreted, and all other characters are kept.

    Args:
        text_id: Text that may contain Chinese digit characters.

    Returns:
        A copy of ``text_id`` with each Chinese digit replaced.
    """
    return text_id.translate(_CHINESE_DIGIT_TABLE)
|
|
|
|
|
|
|
|
|
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Receive question data, save it as JSON and schedule the update task.

    Args:
        question_items: The question entries posted by the client.
        background_tasks: FastAPI hook used to run ``update_kb`` after the
            response has been sent.

    Returns:
        A dict with ``status`` and ``message`` keys; ``status`` is
        ``"error"`` when saving or scheduling failed.
    """
    try:
        kb_name = str(uuid.uuid4())

        device = None
        qa_service = QAService(kb_name, device)

        json_data = json.dumps(
            [item.dict() for item in question_items],
            ensure_ascii=False,
            indent=2,
        )

        # One file per request: with a shared file name a second request
        # could overwrite the data before the first background task reads
        # it, and the first task would then delete the second one's input.
        path = f"output_{kb_name}.json"
        with open(path, "w", encoding="utf-8") as file:
            file.write(json_data)

        background_tasks.add_task(
            update_kb, kb_name, qa_service, path, max_knowledge_bases
        )

        return {"status": "success", "message": "Please wait while the database is being updated···"}

    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error saving data to file or scheduling knowledge base update task: {e}")
        return {"status": "error", "message": "update task error···"}
|
|
|
|
|
|
|
|
|
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match a question against the most recently added knowledge base.

    Args:
        request: The question text and the score threshold to apply.

    Returns:
        A ``QuestionResponse`` with code 200 and the ``match_query`` result
        on success, or code 500 and an empty data list on failure.
    """
    try:
        logger.info(f"match_question:Request: {request}")
        # Monotonic clock: unaffected by system clock adjustments.
        start_time = time.perf_counter()
        query = request.question

        # Report a missing knowledge base explicitly instead of letting the
        # lookup below fail with an opaque IndexError.
        if not recent_knowledge_bases:
            logger.error("match_question:No knowledge base available")
            return QuestionResponse(code=500, msg="no knowledge base available", data=[])

        newest = recent_knowledge_bases[-1]

        top_k = 3

        device = None
        qa_service = QAService(newest, device)

        result = match_query(qa_service, query, top_k, request.scoreThreshold)

        response = QuestionResponse(code=200, msg="success", data=result)
        duration = time.perf_counter() - start_time

        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response

    except Exception as e:
        logger.exception(f"Error matching question: {e}")
        # The failure response must not claim success.
        return QuestionResponse(code=500, msg="error", data=[])
|
|
|
|
|
|
|
|
|
from paddlenlp import Taskflow

# Text correction pipeline applied to the input before extraction.
corrector = Taskflow("text_correction")

# Extraction labels: indices 0-2 are person-name labels, 3-4 identity
# document labels, 5-7 card number labels.
schema = ["姓名", '嫌疑人', '涉案人员', "身份证号", "交易证件号", "卡号", "交易卡号", "银行卡号", ]

# One extractor per label group. The name group must cover indices 0-2:
# a [:2] slice leaves out '涉案人员', which the /extractInformation
# endpoint looks up in the results.
name = Taskflow('information_extraction', schema=schema[:3], model='uie-base')
identity = Taskflow('information_extraction', schema=schema[3:5], model='uie-base')
card = Taskflow('information_extraction', schema=schema[5:8], model='uie-base')
|
|
|
|
|
|
|
|
|
@app.post("/extractInformation")
async def extract_information(input_data: InputText):
    """Extract a name, card number and identity number from free text.

    The text is run through the corrector, Chinese digit characters are
    converted to ASCII digits, and the three extraction models are applied
    to the converted text.

    Args:
        input_data: Request body carrying the raw text.

    Returns:
        An ``ExtractedInfo``; a field is an empty string when none of its
        labels produced a result.

    Raises:
        HTTPException: With status 500 if any step fails.
    """
    try:
        input_text = input_data.inputText

        corrected = corrector(input_text)
        target_value = corrected[0]['target']
        converted_id = text_to_number(target_value + '。')

        # Flatten the outputs of all models into one label -> text mapping,
        # keeping only the first hit reported for each label.
        result = {}
        for model in (name, identity, card):
            for item in model(converted_id):
                for key, value in item.items():
                    result[key.lower()] = value[0]['text']

        # For each field the first label with a non-empty hit wins.
        return ExtractedInfo(
            name=result.get('姓名', '') or result.get('嫌疑人', '') or result.get('涉案人员', ''),
            cardNumber=result.get('卡号', '') or result.get('交易卡号', '') or result.get('银行卡号', ''),
            idNumber=result.get('身份证号', '') or result.get('交易证件号', ''),
        )

    except Exception as e:
        logger.exception(f"Error extracting information: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the app with uvicorn when the file is executed as a script,
    # listening on all interfaces at port 8000.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|
|
|
|