# (Code-hosting page header residue, not part of the program: topic-selection hint, "210 lines", "6.3 KiB", "Python".)

# -*- coding: utf-8 -*-
import yaml
import sys
import os
import time
import uuid
import json
import shutil
import logging
from collections import deque
from pydantic import BaseModel
from fastapi import BackgroundTasks
from fastapi import FastAPI, HTTPException
from qa_Ask import QAService, match_query, store_data
# Application instance; the route handlers in this file register on it via @app.post.
app = FastAPI()
# Configure logging to both a file and the terminal.
# logging.FileHandler opens its file immediately and does not create missing
# parent directories, so make sure 'log/' exists before building the handler.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('log/app.log'),
        logging.StreamHandler(sys.stdout),  # console handler
    ],
)
# Module-level logger used by the functions and endpoints in this file.
logger = logging.getLogger(__name__)
class QuestionRequest(BaseModel):
    """Request body of the /matchQuestion endpoint."""

    # Question text that is matched against the knowledge base.
    question: str
    # Score threshold forwarded unchanged to match_query.
    scoreThreshold: float
class QuestionResponse(BaseModel):
    """Response body returned by the /matchQuestion endpoint."""

    # Status code carried in the body (the endpoint uses 200 and 500).
    code: int
    # Short status message.
    msg: str
    # Match results; an empty list on failure.
    data: list
class QuestionItem(BaseModel):
    """One element of the list posted to the /updateDatabase endpoint."""

    # Identifier of the question entry.
    questionId: str
    # Question strings belonging to this entry.
    questionList: list[str]
class InputText(BaseModel):
    """Request body of the /extractInformation endpoint."""

    # Raw text from which information is extracted.
    inputText: str
class ExtractedInfo(BaseModel):
    """Response body of the /extractInformation endpoint.

    Each field is an empty string when nothing was extracted for it.
    """

    # Extracted person name.
    name: str
    # Extracted card number.
    cardNumber: str
    # Extracted identity-document number.
    idNumber: str
# Read the service configuration once, at import time.
with open('config/config.yaml') as config_file:
    config_data = yaml.safe_load(config_file)

# Settings pulled out of the parsed configuration; a missing key raises
# KeyError here, i.e. while the module is being imported.
knowledge_base_file = config_data['knowledge_base_file']  # file holding the knowledge-base name list
api_url = config_data['api']['url']
path = config_data['output_file_path']
max_knowledge_bases = config_data['max_knowledge_bases']  # cap on retained knowledge bases
def load_knowledge_bases():
    """Load the list of knowledge-base names.

    Returns:
        The lines of the file named by ``knowledge_base_file``, one name per
        line, or an empty list when that file does not exist.
    """
    # Guard clause: no file yet means no knowledge bases.
    if not os.path.exists(knowledge_base_file):
        return []
    with open(knowledge_base_file, "r") as kb_file:
        content = kb_file.read()
    return content.splitlines()
def save_knowledge_bases(names):
    """Save the list of knowledge-base names to file.

    Args:
        names: Iterable of name strings; written one per line to the file
            named by ``knowledge_base_file``, replacing its previous content.
    """
    with open(knowledge_base_file, "w") as kb_file:
        joined = "\n".join(names)
        kb_file.write(joined)
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Update the knowledge base.

    Stores the data found at ``path`` through ``qa_service``, rotates the
    module-level ``recent_knowledge_bases`` list, persists that list and
    finally deletes the input file.

    Args:
        kb_name: Name of the knowledge base being registered.
        qa_service: Service object handed to ``store_data``.
        path: Path of the data file to store; removed afterwards.
        max_knowledge_bases: Number of knowledge bases to retain.
    """
    store_data(qa_service, path)

    # When the list is full, drop the oldest entry and delete its folder.
    at_capacity = len(recent_knowledge_bases) == max_knowledge_bases
    if at_capacity:
        oldest = recent_knowledge_bases.popleft()
        shutil.rmtree(f"knowledge_base/{oldest}")

    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)

    # The input file is no longer needed once its content has been stored.
    os.remove(path)

    message = (f"Knowledge base updated: {kb_name}\n"
               f"Please wait while the database is being updated···")
    logger.info(message)
# Names of the retained knowledge bases, oldest first (update_kb appends new
# names on the right; match_question reads the last entry as the newest).
# Seeded from the persisted name file and bounded by max_knowledge_bases.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
# NOTE(review): in the source the ten mapping keys were empty strings (the
# characters appear to have been lost), which collapsed the dict to a single
# '' key and made str.maketrans raise ValueError on every call. They are
# restored here as the standard Chinese numerals 零-九 — confirm against the
# original file.
_CHINESE_DIGIT_TABLE = str.maketrans("零一二三四五六七八九", "0123456789")


def text_to_number(text_id):
    """Replace Chinese numeral characters with ASCII digits.

    Args:
        text_id: Text that may contain the numerals 零 to 九.

    Returns:
        A copy of ``text_id`` in which each such numeral is replaced by the
        corresponding digit '0' to '9'; all other characters are unchanged.
    """
    return text_id.translate(_CHINESE_DIGIT_TABLE)
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Save the posted question data as JSON and schedule a background update.

    Args:
        question_items: Question entries to store in a new knowledge base.
        background_tasks: FastAPI task queue on which ``update_kb`` is
            scheduled.

    Returns:
        A dict with ``status`` and ``message`` keys. Failures are logged and
        reported in the body with ``status`` set to "error"; no HTTP error
        is raised.
    """
    try:
        kb_name = str(uuid.uuid4())
        json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
        # Use a per-request file name. update_kb deletes this file after
        # storing it, so a single shared name would let overlapping requests
        # overwrite or delete each other's payload before it is processed.
        path = f"output_{kb_name}.json"
        with open(path, "w", encoding="utf-8") as file:
            file.write(json_data)
        device = None
        qa_service = QAService(kb_name, device)
        background_tasks.add_task(
            update_kb, kb_name, qa_service, path, max_knowledge_bases
        )
        return {"status": "success", "message": "Please wait while the database is being updated···"}
    except Exception as e:
        logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
        return {"status": "error", "message": "update task error···"}
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match a question against the newest knowledge base.

    Args:
        request: Carries the question text and the score threshold.

    Returns:
        A ``QuestionResponse`` with ``code=200`` and the match results, or
        ``code=500`` with an empty ``data`` list when anything fails
        (including the case where no knowledge base is registered yet).
    """
    try:
        logger.info(f"match_question:Request: {request}")
        start_time = time.time()
        query = request.question
        # The newest knowledge base is the last entry; on an empty list this
        # raises IndexError, which is reported through the handler below.
        newest = recent_knowledge_bases[-1]
        top_k = 3
        device = None
        qa_service = QAService(newest, device)
        result = match_query(qa_service, query, top_k, request.scoreThreshold)
        response = QuestionResponse(code=200, msg="success", data=result)
        stop_time = time.time()
        duration = stop_time - start_time
        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response
    except Exception as e:
        logger.error(f"Error matching question: {e}")
        # The failure response must not claim success.
        return QuestionResponse(code=500, msg="error", data=[])
from paddlenlp import Taskflow
# Text-correction pipeline applied to the input before extraction.
corrector = Taskflow("text_correction")

# Extraction labels, grouped by target field:
#   indices 0-2: person name, 3-4: identity document number, 5-7: card number.
schema = ["姓名", '嫌疑人', '涉案人员', "身份证号", "交易证件号", "卡号", "交易卡号", "银行卡号"]

# One extractor per target field. The name extractor takes all three name
# labels: slicing with [:2] left out '涉案人员', which extract_information
# looks up as a fallback for the name.
name = Taskflow('information_extraction', schema=schema[:3], model='uie-base')
identity = Taskflow('information_extraction', schema=schema[3:5], model='uie-base')
card = Taskflow('information_extraction', schema=schema[5:8], model='uie-base')
@app.post("/extractInformation")
async def extract_information(input_data: InputText):
    """Extract name, card number and ID number from the posted text.

    The text is passed through the corrector, Chinese numerals are converted
    by ``text_to_number``, and the three extractors are run on the result.

    Args:
        input_data: Request body carrying the raw text.

    Returns:
        An ``ExtractedInfo``; fields with no extraction are empty strings.

    Raises:
        HTTPException: Status 500 when any step fails.
    """
    try:
        corrected = corrector(input_data.inputText)
        normalized_text = text_to_number(corrected[0]['target'] + '')

        # Run every extractor first, then merge their outputs.
        extractors = {"name": name, "identity": identity, "card": card}
        outputs = {label: model(normalized_text) for label, model in extractors.items()}

        # Keep the first extracted text for each label.
        merged = {}
        for entries in outputs.values():
            for entry in entries:
                for label, spans in entry.items():
                    merged[label.lower()] = spans[0]['text']

        # For each field, take the first label that has a non-empty value.
        person = merged.get('姓名', '') or merged.get('嫌疑人', '') or merged.get('涉案人员', '')
        card_no = merged.get('卡号', '') or merged.get('交易卡号', '') or merged.get('银行卡号', '')
        id_no = merged.get('身份证号', '') or merged.get('交易证件号', '')

        return ExtractedInfo(name=person, cardNumber=card_no, idNumber=id_no)
    except Exception as e:
        logger.error(f"Error extracting information: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
if __name__ == "__main__":
    # Run as a script: serve the app with uvicorn on all interfaces, port 8001.
    # uvicorn is imported here so that importing this module does not need it.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)