# -*- coding: utf-8 -*-
"""FastAPI service for question matching and information extraction.

Endpoints:

* ``POST /updateDatabase``     - dump question items to JSON and schedule a
  background rebuild of the knowledge base.
* ``POST /matchQuestion``      - match a question against the most recently
  registered knowledge base.
* ``POST /extractInformation`` - run text correction and UIE extraction to
  pull a name, card number and ID number out of free text.
"""
import yaml
import sys
import os
import time
import uuid
import json
import shutil
import logging
from collections import deque
from pydantic import BaseModel
from fastapi import BackgroundTasks
from fastapi import FastAPI, HTTPException
from qa_Ask import QAService, match_query, store_data

app = FastAPI()

# logging.FileHandler does not create missing parent directories, so make
# sure the log directory exists before the handler is constructed.
os.makedirs('log', exist_ok=True)

# Log to both a file and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('log/app.log'),
        logging.StreamHandler(sys.stdout),  # console handler
    ]
)
logger = logging.getLogger(__name__)


class QuestionRequest(BaseModel):
    """Request body of ``/matchQuestion``."""

    question: str
    # Forwarded unchanged to match_query as its last argument.
    scoreThreshold: float


class QuestionResponse(BaseModel):
    """Response envelope of ``/matchQuestion``."""

    code: int
    msg: str
    data: list


class QuestionItem(BaseModel):
    """One entry of the ``/updateDatabase`` payload."""

    questionId: str
    questionList: list[str]


class InputText(BaseModel):
    """Request body of ``/extractInformation``."""

    inputText: str


class ExtractedInfo(BaseModel):
    """Response body of ``/extractInformation``; missing fields are ''."""

    name: str
    cardNumber: str
    idNumber: str


# Service configuration, read once at import time.
with open('config/config.yaml', 'r') as config_file:
    config_data = yaml.safe_load(config_file)

knowledge_base_file = config_data['knowledge_base_file']
api_url = config_data['api']['url']
path = config_data['output_file_path']
max_knowledge_bases = config_data['max_knowledge_bases']


def load_knowledge_bases():
    """Load the list of knowledge base names.

    Returns:
        The lines of ``knowledge_base_file`` as a list of strings, or an
        empty list when the file does not exist.
    """
    if not os.path.exists(knowledge_base_file):
        return []
    with open(knowledge_base_file, "r") as file:
        return file.read().splitlines()


def save_knowledge_bases(names):
    """Save the list of knowledge base names to ``knowledge_base_file``.

    Args:
        names: Iterable of name strings; written one per line.
    """
    with open(knowledge_base_file, "w") as file:
        file.write("\n".join(names))


def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Build a new knowledge base and register it as the newest one.

    Stores the data found at ``path`` through ``qa_service``, evicts the
    oldest knowledge base (name and on-disk folder) when the registry is
    full, appends ``kb_name``, persists the registry and finally deletes the
    input file at ``path``.

    Args:
        kb_name: Name of the knowledge base being created.
        qa_service: Service object handed to ``store_data``.
        path: Path of the JSON file to ingest; removed afterwards.
        max_knowledge_bases: Maximum number of knowledge bases to keep.
    """
    store_data(qa_service, path)

    if len(recent_knowledge_bases) == max_knowledge_bases:
        folder_to_delete = recent_knowledge_bases.popleft()
        try:
            shutil.rmtree(f"knowledge_base/{folder_to_delete}")
        except FileNotFoundError:
            # The name has already been popped; a missing folder must not
            # abort the task before the new knowledge base is registered.
            logger.warning(
                "Knowledge base folder already missing: %s", folder_to_delete
            )

    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)
    os.remove(path)

    logger.info(f"Knowledge base updated: {kb_name}\n"
                f"Please wait while the database is being updated···")


# Names of the retained knowledge bases, oldest first, newest last.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)

# Maps the Chinese numerals for 0-9 to ASCII digits; built once at import.
_CHINESE_DIGIT_TABLE = str.maketrans({
    '零': '0', '一': '1', '二': '2', '三': '3', '四': '4',
    '五': '5', '六': '6', '七': '7', '八': '8', '九': '9',
})


def text_to_number(text_id):
    """Replace the Chinese numerals for 0-9 in ``text_id`` with ASCII digits.

    Args:
        text_id: Input string.

    Returns:
        A copy of ``text_id`` with each of 零一二三四五六七八九 replaced by the
        corresponding digit; all other characters are left untouched.
    """
    return text_id.translate(_CHINESE_DIGIT_TABLE)


@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Receive question data, save it as JSON and schedule a background update.

    The items are written to ``output.json``; a knowledge base with a fresh
    UUID name is then built by ``update_kb`` as a background task.

    Returns:
        A dict with ``status`` ("success" or "error") and ``message``.
        Errors are logged and reported in the body rather than raised.
    """
    try:
        json_data = json.dumps(
            [item.dict() for item in question_items],
            ensure_ascii=False,
            indent=2,
        )
        path = "output.json"
        with open(path, "w", encoding="utf-8") as file:
            file.write(json_data)

        kb_name = str(uuid.uuid4())
        device = None
        qa_service = QAService(kb_name, device)

        background_tasks.add_task(
            update_kb, kb_name, qa_service, path, max_knowledge_bases
        )
        return {"status": "success",
                "message": "Please wait while the database is being updated···"}
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(
            "Error saving data to file or scheduling knowledge base update task: %s", e
        )
        return {"status": "error", "message": "update task error···"}


@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match a question against the newest knowledge base.

    Returns:
        ``QuestionResponse`` with ``code=200``, ``msg="success"`` and the
        ``match_query`` result on success; ``code=500``, ``msg="error"`` and
        empty ``data`` on any failure (failures are logged, not raised).
    """
    try:
        logger.info(f"match_question:Request: {request}")
        start_time = time.time()

        if not recent_knowledge_bases:
            # Previously this surfaced as an opaque IndexError from [-1].
            logger.error("Error matching question: no knowledge base available")
            return QuestionResponse(code=500, msg="error", data=[])

        query = request.question
        newest = recent_knowledge_bases[-1]
        top_k = 3
        device = None
        qa_service = QAService(newest, device)
        result = match_query(qa_service, query, top_k, request.scoreThreshold)
        response = QuestionResponse(code=200, msg="success", data=result)

        duration = time.time() - start_time
        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response
    except Exception as e:
        logger.exception("Error matching question: %s", e)
        # The failure path used to report msg="success" alongside code 500.
        return QuestionResponse(code=500, msg="error", data=[])


# Imported here, after logging is configured, exactly as in the original
# layout, so import-time side effects keep their original ordering.
from paddlenlp import Taskflow

corrector = Taskflow("text_correction")

# Extraction labels: [0:3] person name, [3:5] ID number, [5:8] card number.
schema = ["姓名", '嫌疑人', '涉案人员',
          "身份证号", "交易证件号",
          "卡号", "交易卡号", "银行卡号", ]

# schema[:3] so that '涉案人员' is actually extracted; the previous [:2] slice
# dropped it although extract_information looks that key up.
name = Taskflow('information_extraction', schema=schema[:3], model='uie-base')
identity = Taskflow('information_extraction', schema=schema[3:5], model='uie-base')
card = Taskflow('information_extraction', schema=schema[5:8], model='uie-base')


@app.post("/extractInformation")
async def extract_information(input_data: InputText):
    """Extract name, card number and ID number from the input text.

    The text is passed through the correction model, a '。' is appended,
    Chinese numerals are converted to digits, and the three extraction
    models are run on the result.

    Returns:
        ``ExtractedInfo``; each field is the first non-empty match among its
        candidate labels, or '' when nothing was found.

    Raises:
        HTTPException: 500 when any step fails (the error is logged).
    """
    try:
        input_text = input_data.inputText
        data = corrector(input_text)
        target_value = data[0]['target']
        # NOTE(review): the trailing '。' presumably gives the extractor a
        # sentence terminator - confirm.
        converted_id = text_to_number(target_value + '。')

        # Merge the top hit of every label; later models overwrite earlier
        # ones on a key collision (the schema slices are disjoint).
        result = {}
        for model in (name, identity, card):
            for item in model(converted_id):
                for key, value in item.items():
                    result[key.lower()] = value[0]['text']

        return ExtractedInfo(
            name=result.get('姓名', '') or result.get('嫌疑人', '')
            or result.get('涉案人员', ''),
            cardNumber=result.get('卡号', '') or result.get('交易卡号', '')
            or result.get('银行卡号', ''),
            idNumber=result.get('身份证号', '') or result.get('交易证件号', ''),
        )
    except Exception as e:
        logger.exception("Error extracting information: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)