# coding=gbk
"""FastAPI service exposing embedding, knowledge-base update and question-matching endpoints.

On import the module loads ``config/config.yaml``, restores the list of recent
knowledge bases, and tries to bootstrap a fresh knowledge base from the
configured API.
"""
import json
import logging
import os
import shutil
import sys
import time
import uuid
from collections import deque
from typing import List

import requests
import yaml
from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel

from qa_Ask import QAService, match_query, store_data
from faiss_kb_service import EmbeddingsFunAdapter, FaissKBService

app = FastAPI()

# logging.FileHandler does not create missing parent directories, so make sure
# the log directory exists before the handler opens its file.
os.makedirs('log', exist_ok=True)

# Send log records to both the log file and the terminal.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('log/app.log'),
        logging.StreamHandler(sys.stdout),  # console handler
    ],
)
logger = logging.getLogger(__name__)


class QuestionRequest(BaseModel):
    """Request body of ``/matchQuestion``."""

    question: str
    scoreThreshold: float


class EmbeddingResponse(BaseModel):
    """Response body of ``/embeddings/``."""

    embeddings: List[float]


class QuestionResponse(BaseModel):
    """Response envelope of ``/matchQuestion``."""

    code: int
    msg: str
    data: list


class QuestionItem(BaseModel):
    """One entry of the ``/updateDatabase`` request body."""

    questionId: str
    questionList: list[str]


class EmbeddingRequest(BaseModel):
    """Request body of ``/embeddings/``."""

    text: str


with open('config/config.yaml', 'r') as config_file:
    config_data = yaml.safe_load(config_file)

knowledge_base_file = config_data['knowledge_base_file']
api_url = config_data['api']['url']
path = config_data['output_file_path']
max_knowledge_bases = config_data['max_knowledge_bases']
# Optional setting: seconds to wait for the bootstrap API before giving up.
# Without a timeout, requests would block forever on an unresponsive server.
api_timeout = config_data['api'].get('timeout', 60)


def load_knowledge_bases():
    """Return the knowledge-base names stored in ``knowledge_base_file``.

    Returns:
        One name per line of the file, or an empty list when the file does
        not exist.
    """
    if not os.path.exists(knowledge_base_file):
        return []
    with open(knowledge_base_file, "r") as file:
        return file.read().splitlines()


def save_knowledge_bases(names):
    """Write the knowledge-base names to ``knowledge_base_file``, one per line.

    Args:
        names: Iterable of knowledge-base name strings.
    """
    with open(knowledge_base_file, "w") as file:
        file.write("\n".join(names))


def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Register a new knowledge base and evict the oldest one when full.

    Args:
        kb_name: Name of the new knowledge base.
        qa_service: QAService instance handed to ``store_data``.
        path: JSON file handed to ``store_data``; removed afterwards.
        max_knowledge_bases: Number of knowledge bases to keep.
    """
    # presumably builds the knowledge base from the JSON file - see qa_Ask.store_data
    store_data(qa_service, path)

    if len(recent_knowledge_bases) == max_knowledge_bases:
        folder_to_delete = recent_knowledge_bases.popleft()
        try:
            shutil.rmtree(f"knowledge_base/{folder_to_delete}")
        except FileNotFoundError:
            # The folder is already gone; do not let that stop the new
            # knowledge base from being registered below.
            logger.warning(
                "Knowledge base folder already missing: %s", folder_to_delete
            )

    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)
    os.remove(path)
    logger.info(f"Knowledge base updated: {kb_name}\n"
                f"Please wait while the database is being updated···")


def fetch_and_write_data(api_url, path, timeout=api_timeout):
    """Fetch question data from the API and write it to ``path`` as JSON.

    Args:
        api_url: URL queried with an HTTP GET.
        path: Destination file for the ``data`` field of the response.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        True when the data was fetched and written, False otherwise. Errors
        are logged rather than raised.
    """
    try:
        response = requests.get(api_url, timeout=timeout)
        response_data = response.json()
        if response.status_code == 200 and response_data["code"] == 200:
            question_items = response_data["data"]
            with open(path, "w", encoding="utf-8") as file:
                json.dump(question_items, file, ensure_ascii=False, indent=2)
            return True
        logger.error(f"Failed to fetch data from API. Status code: {response.status_code}, Response data: {response_data}")
        return False
    except Exception as e:
        logger.error(f"Error fetching data from API: {e}")
        return False


bge_large_zh_v1_5_config = config_data.get('bge_large_zh_v1_5', {})
embed_model_path = bge_large_zh_v1_5_config.get('embed_model_path', 'default_path_if_not_provided')

recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)


def _build_faiss_service():
    """Create the FaissKBService for the newest knowledge base.

    Returns:
        The service, or None when no knowledge base is registered yet (for
        example on a first start, before the bootstrap below has run).
    """
    if not recent_knowledge_bases:
        logger.warning("No knowledge base registered yet; embedding service not initialised.")
        return None
    return FaissKBService(
        kb_name=recent_knowledge_bases[-1],
        embed_model_path=embed_model_path,
        device=None,
    )


faiss_service = _build_faiss_service()


@app.post("/embeddings/", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    """Return the embedding vector of ``request.text``.

    Raises:
        HTTPException: 503 when no knowledge base is available yet, 500 when
            computing the embedding fails.
    """
    if faiss_service is None:
        raise HTTPException(status_code=503, detail="No knowledge base is available yet")

    embed_func = EmbeddingsFunAdapter(faiss_service.embed_model_path, faiss_service.device)
    try:
        embeddings = embed_func.embed_query(request.text)
        return EmbeddingResponse(embeddings=embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Save the posted question data as JSON and schedule a knowledge-base update.

    Returns:
        A dict with ``status`` and ``message`` keys. Failures are logged and
        reported in the body instead of being raised as an HTTP error.
    """
    try:
        json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
        # NOTE(review): every request writes the same file, which update_kb
        # later removes; overlapping requests may interfere - confirm.
        path = "output.json"
        with open(path, "w", encoding="utf-8") as file:
            file.write(json_data)

        kb_name = str(uuid.uuid4())
        device = None
        qa_service = QAService(kb_name, device)
        background_tasks.add_task(
            update_kb, kb_name, qa_service, path, max_knowledge_bases
        )
        return {"status": "success",
                "message": "Please wait while the database is being updated···"}
    except Exception as e:
        logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
        return {"status": "error", "message": "update task error···"}


@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match ``request.question`` against the newest knowledge base.

    Returns:
        QuestionResponse with code 200 and the ``match_query`` result on
        success, or code 500 and an empty list when anything fails
        (including when no knowledge base is registered).
    """
    try:
        logger.info(f"match_question:Request: {request}")
        start_time = time.time()

        query = request.question
        newest = recent_knowledge_bases[-1]
        top_k = 3
        device = None
        qa_service = QAService(newest, device)
        result = match_query(qa_service, query, top_k, request.scoreThreshold)
        response = QuestionResponse(code=200, msg="success", data=result)

        duration = time.time() - start_time
        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response
    except Exception as e:
        logger.error(f"Error matching question: {e}")
        # The failure path must not claim success alongside code 500.
        return QuestionResponse(code=500, msg="error", data=[])


# Bootstrap: build a fresh knowledge base from the API data at startup.
if fetch_and_write_data(api_url, path):
    kb_name = str(uuid.uuid4())
    device = None
    qa_service = QAService(kb_name, device)
    update_kb(kb_name, qa_service, path, max_knowledge_bases)

# On a first start there was no knowledge base to attach to earlier; retry now
# that the bootstrap may have registered one.
if faiss_service is None:
    faiss_service = _build_faiss_service()

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)