import contextlib
import functools
import json
import logging
import os
import shutil
import sys
import threading
import time
import uuid
from collections import deque

import requests  # NOTE(review): not used in this module
import yaml
from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel

from qa_Ask import QAService, match_query, store_data

app = FastAPI()

# Log to both a file and the terminal. The directory is created first because
# logging.FileHandler raises at import time when 'log/' does not exist.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('log/app.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),  # console handler
    ],
)
logger = logging.getLogger(__name__)


class QuestionRequest(BaseModel):
    """Body of POST /matchQuestion."""

    question: str
    scoreThreshold: float  # forwarded to match_query


class QuestionResponse(BaseModel):
    """Envelope returned by POST /matchQuestion."""

    code: int
    msg: str
    data: list


class QuestionItem(BaseModel):
    """One entry of the POST /updateDatabase payload."""

    questionId: str
    questionList: list[str]


class InputText(BaseModel):
    """Body of POST /extractInformation."""

    inputText: str


class ExtractedInfo(BaseModel):
    """Fields returned by POST /extractInformation."""

    name: str
    cardNumber: str
    # NOTE: 'idNnumber' (sic) is a field of the JSON response; renaming it
    # would change the response schema seen by clients.
    idNnumber: str


# Read the configuration file.
with open('config/config.yaml', 'r', encoding='utf-8') as config_file:
    config_data = yaml.safe_load(config_file)

knowledge_base_file = config_data['knowledge_base_file']
api_url = config_data['api']['url']  # NOTE(review): not used in this module
path = config_data['output_file_path']
max_knowledge_bases = config_data['max_knowledge_bases']

# Serializes mutation of recent_knowledge_bases and its backing file.
# update_kb is a sync background task, which Starlette runs in a thread pool,
# so two updates can otherwise interleave.
_kb_lock = threading.Lock()


def load_knowledge_bases():
    """Return the persisted knowledge-base names, oldest first.

    Blank lines are skipped: an empty name would make eviction delete
    "knowledge_base/" itself, i.e. every knowledge base.
    """
    if not os.path.exists(knowledge_base_file):
        return []
    with open(knowledge_base_file, "r", encoding="utf-8") as file:
        return [line.strip() for line in file if line.strip()]


def save_knowledge_bases(names):
    """Write the knowledge-base names, one per line, to knowledge_base_file."""
    with open(knowledge_base_file, "w", encoding="utf-8") as file:
        file.write("\n".join(names))


def _delete_kb_folder(kb_name):
    """Delete an evicted knowledge base's folder; a missing folder is only logged."""
    if not kb_name:
        return  # never turn an empty name into rmtree("knowledge_base/")
    try:
        shutil.rmtree(os.path.join("knowledge_base", kb_name))
    except FileNotFoundError:
        logger.warning("Knowledge base folder already missing: %s", kb_name)


def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Build knowledge base *kb_name* from the JSON file at *path* and make it the newest.

    Runs as a background task. After store_data succeeds, *kb_name* is
    appended to recent_knowledge_bases; the oldest names beyond
    *max_knowledge_bases* are evicted, the list is persisted, and the evicted
    folders are deleted. The input file at *path* is removed whether or not
    the update succeeds. Failures are logged, since a background task has no
    caller to report to.
    """
    try:
        store_data(qa_service, path)
        with _kb_lock:
            previous = list(recent_knowledge_bases)
            # Append before evicting so /matchQuestion never observes an empty
            # deque. The deque's own maxlen may already drop the oldest entry.
            recent_knowledge_bases.append(kb_name)
            while len(recent_knowledge_bases) > max(max_knowledge_bases, 1):
                recent_knowledge_bases.popleft()
            retained = set(recent_knowledge_bases)
            save_knowledge_bases(recent_knowledge_bases)
        for evicted in previous:
            if evicted not in retained:
                _delete_kb_folder(evicted)
        logger.info("Knowledge base updated: %s", kb_name)
    except Exception:
        logger.exception("Failed to update knowledge base: %s", kb_name)
    finally:
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)


# Oldest name at the left, newest at the right. If the persisted file lists
# more than max_knowledge_bases names, maxlen keeps only the newest ones.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)

# Chinese numeral characters -> ASCII digits. '〇' is a common written form of
# zero alongside '零'.
_CHINESE_DIGITS = str.maketrans("零〇一二三四五六七八九", "00123456789")


def text_to_number(text_id):
    """Replace Chinese numeral characters in *text_id* with ASCII digits."""
    return text_id.translate(_CHINESE_DIGITS)


def _request_data_path(kb_name):
    """Return a per-request JSON path derived from the configured output_file_path.

    Each request gets its own file, so concurrent updates cannot overwrite or
    delete each other's input before update_kb has read it.
    """
    root, ext = os.path.splitext(path or "output.json")
    directory = os.path.dirname(root)
    if directory:
        os.makedirs(directory, exist_ok=True)
    return f"{root}_{kb_name}{ext or '.json'}"


@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Save the posted questions as JSON and schedule a knowledge-base rebuild.

    Returns immediately; update_kb runs as a background task after the
    response is sent.
    """
    data_path = None
    try:
        json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
        kb_name = str(uuid.uuid4())
        data_path = _request_data_path(kb_name)
        with open(data_path, "w", encoding="utf-8") as file:
            file.write(json_data)
        qa_service = QAService(kb_name, None)  # device=None
        background_tasks.add_task(update_kb, kb_name, qa_service, data_path, max_knowledge_bases)
        return {"status": "success", "message": "Please wait while the database is being updated···"}
    except Exception:
        logger.exception("Error saving data to file or scheduling knowledge base update task")
        # No task will consume the file, so do not leave it behind.
        if data_path is not None:
            with contextlib.suppress(OSError):
                os.remove(data_path)
        return {"status": "error", "message": "update task error···"}


@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match *request.question* against the newest knowledge base (top 3 results).

    Errors are reported in the response envelope (code=500), not as HTTP errors.
    """
    logger.info("match_question:Request: %s", request)
    if not recent_knowledge_bases:
        logger.error("Error matching question: no knowledge base available")
        return QuestionResponse(code=500, msg="no knowledge base available", data=[])
    try:
        start_time = time.perf_counter()
        newest = recent_knowledge_bases[-1]
        qa_service = QAService(newest, None)  # device=None
        result = match_query(qa_service, request.question, 3, request.scoreThreshold)
        response = QuestionResponse(code=200, msg="success", data=result)
        duration = time.perf_counter() - start_time
        logger.info("match_question:Matched question in %s seconds. \nResponse: %s", duration, result)
        return response
    except Exception:
        logger.exception("Error matching question")
        return QuestionResponse(code=500, msg="error", data=[])


# Extraction schema labels, grouped by the output field they fill. Within a
# group, the first label with a value wins.
_NAME_KEYS = ("姓名", "嫌疑人", "涉案人员")  # name / suspect / person involved
_ID_NUMBER_KEYS = ("身份证号", "交易证件号")  # ID-card no. / transaction ID-document no.
_CARD_NUMBER_KEYS = ("卡号", "交易卡号", "银行卡号")  # card no. / transaction card no. / bank card no.
_EXTRACTION_SCHEMA = [*_NAME_KEYS, *_ID_NUMBER_KEYS, *_CARD_NUMBER_KEYS]

# Building a Taskflow loads a model (presumably expensive), so each pipeline is
# built once and reused. The lock serializes construction and inference; the
# original async endpoint ran these calls on the event loop, one at a time.
_taskflow_lock = threading.Lock()


@functools.lru_cache(maxsize=1)
def _get_text_corrector():
    """Return the shared text-correction Taskflow, created on first use."""
    from paddlenlp import Taskflow  # heavy dependency, imported lazily

    return Taskflow("text_correction")


@functools.lru_cache(maxsize=1)
def _get_information_extractor():
    """Return the shared UIE information-extraction Taskflow, created on first use."""
    from paddlenlp import Taskflow  # heavy dependency, imported lazily

    return Taskflow('information_extraction', schema=list(_EXTRACTION_SCHEMA), model='uie-base')


def _first_nonempty(values, keys):
    """Return the first truthy ``values[key]`` for *keys* in order, else ''."""
    return next((values[key] for key in keys if values.get(key)), '')


@app.post("/extractInformation")
def extract_information(input_data: InputText):
    """Extract a name, card number and ID number from free text.

    The text is spell-corrected, Chinese numerals are converted to digits,
    and UIE extracts the schema labels. Declared as a plain ``def`` so FastAPI
    runs the blocking model calls in its thread pool instead of stalling the
    event loop.

    Raises:
        HTTPException: 500 if correction or extraction fails.
    """
    input_text = input_data.inputText
    if not input_text.strip():
        return ExtractedInfo(name='', cardNumber='', idNnumber='')
    try:
        with _taskflow_lock:
            corrected = _get_text_corrector()(input_text)
            target_value = corrected[0]['target'] if corrected else input_text
            extracted_info = _get_information_extractor()(text_to_number(target_value))
        result = {}
        for item in extracted_info:
            for key, value in item.items():
                if value:  # guard against an empty span list for a label
                    result[key.lower()] = value[0]['text']
        return ExtractedInfo(
            name=_first_nonempty(result, _NAME_KEYS),
            cardNumber=_first_nonempty(result, _CARD_NUMBER_KEYS),
            idNnumber=_first_nonempty(result, _ID_NUMBER_KEYS),
        )
    except Exception as e:
        logger.exception("Error extracting information")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)