diff --git a/fast_api.py b/fast_api.py index 070dda2..10003fc 100644 --- a/fast_api.py +++ b/fast_api.py @@ -1,3 +1,4 @@ +# coding=gbk from fastapi import FastAPI, HTTPException, BackgroundTasks from qa_Ask import QAService, match_query, store_data from pydantic import BaseModel @@ -9,27 +10,31 @@ import uuid import json import shutil import yaml +import sys import logging - +from typing import List +from faiss_kb_service import FaissKBService +from faiss_kb_service import EmbeddingsFunAdapter app = FastAPI() -import sys - -# 配置日志记录到文件和终端 +# ־¼ļն logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('log/app.log'), - logging.StreamHandler(sys.stdout) # 这里添加控制台处理程序 + logging.StreamHandler(sys.stdout) # ӿ̨ ] ) logger = logging.getLogger(__name__) class QuestionRequest(BaseModel): question: str + scoreThreshold: float +class EmbeddingResponse(BaseModel): + embeddings: List[float] class QuestionResponse(BaseModel): @@ -39,9 +44,11 @@ class QuestionResponse(BaseModel): class QuestionItem(BaseModel): - questionCode: str + questionId: str questionList: list[str] +class EmbeddingRequest(BaseModel): + text: str with open('config/config.yaml', 'r') as config_file: config_data = yaml.safe_load(config_file) @@ -53,7 +60,7 @@ max_knowledge_bases = config_data['max_knowledge_bases'] def load_knowledge_bases(): - """加载知识库名称列表""" + """֪ʶб""" if os.path.exists(knowledge_base_file): with open(knowledge_base_file, "r") as file: return file.read().splitlines() @@ -62,13 +69,13 @@ def load_knowledge_bases(): def save_knowledge_bases(names): - """保存知识库名称列表到文件""" + """֪ʶбļ""" with open(knowledge_base_file, "w") as file: file.write("\n".join(names)) def update_kb(kb_name, qa_service, path, max_knowledge_bases): - """更新知识库""" + """֪ʶ""" store_data(qa_service, path) if len(recent_knowledge_bases) == max_knowledge_bases: @@ -80,11 +87,11 @@ def update_kb(kb_name, qa_service, path, max_knowledge_bases): os.remove(path) logger.info(f"Knowledge base updated: {kb_name}\n" - f"Please wait while the database is being updated···") + f"Please wait while the database is being updated") def fetch_and_write_data(api_url, path): - """从API获取数据并写入文件""" + """APIȡݲдļ""" try: response = requests.get(api_url) response_data = response.json() @@ -102,11 +109,26 @@ def fetch_and_write_data(api_url, path): except Exception as e: logger.error(f"Error fetching data from API: {e}") return False +bge_large_zh_v1_5_config = config_data.get('bge_large_zh_v1_5', {}) +embed_model_path = bge_large_zh_v1_5_config.get('embed_model_path', 'default_path_if_not_provided') +recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases) +faiss_service = FaissKBService(kb_name= recent_knowledge_bases[-1], embed_model_path=embed_model_path, device=None) + +@app.post("/embeddings/", response_model=EmbeddingResponse) +async def get_embeddings(request: EmbeddingRequest): + """ʹFaissKBServiceʵȡǶ""" + embed_func = EmbeddingsFunAdapter(faiss_service.embed_model_path, faiss_service.device) + try: + embeddings = embed_func.embed_query(request.text) + embeddings_list = embeddings + return EmbeddingResponse(embeddings=embeddings_list) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @app.post("/updateDatabase") async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks): - """接收问题数据并异步保存为JSON文件,触发后台更新任务""" + """ݲ첽ΪJSONļ̨""" try: json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2) path = "output.json" @@ -123,16 +145,17 @@ async def save_to_json(question_items: list[QuestionItem], background_tasks: Bac update_kb, kb_name, qa_service, path, max_knowledge_bases ) - return {"status": "success", "message": "Please wait while the database is being updated···"} + return {"status": "success", "message": "Please wait while the database is being updated"} except Exception as e: logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}") # raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") - return {"status": "error", "message": "update task error···"} + return {"status": "error", "message": "update task error"} + @app.post("/matchQuestion") def match_question(request: QuestionRequest): - """匹配问题的端点""" + """ƥĶ˵""" try: logger.info(f"match_question:Request: {request}") start_time = time.time() @@ -141,12 +164,11 @@ def match_question(request: QuestionRequest): newest = recent_knowledge_bases[-1] top_k = 3 - score_threshold = 0.1 device = None qa_service = QAService(newest, device) - result = match_query(qa_service, query, top_k, score_threshold) + result = match_query(qa_service, query, top_k, request.scoreThreshold) response = QuestionResponse(code=200, msg="success", data=result) stop_time = time.time() diff --git a/qa_Ask.py b/qa_Ask.py index 9cf3479..b89bb10 100644 --- a/qa_Ask.py +++ b/qa_Ask.py @@ -51,7 +51,7 @@ def load_testing_data(file_path): with open(file_path, encoding='utf-8') as f: data = json.load(f) for item in data: - question_code = item['questionCode'] + question_code = item['questionId'] question_list.extend(item['questionList']) id_list.extend([create_question_id(question_code, j, q) for j, q in enumerate(item['questionList'])])