From a92df7fe60077a8f4adb04ca3315ea646d885128 Mon Sep 17 00:00:00 2001 From: fanpt <320622572@qq.com> Date: Mon, 4 Mar 2024 14:49:53 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0api=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fast_api.py | 96 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 fast_api.py diff --git a/fast_api.py b/fast_api.py new file mode 100644 index 0000000..ed331e3 --- /dev/null +++ b/fast_api.py @@ -0,0 +1,96 @@ +from qa_amend import QAService, match_query, store_data +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import json +import os +import logging +import time +import requests + +app = FastAPI() +kb_name = 'my_kb_test' +device = None +qa_service = QAService(kb_name, device) + +logging.basicConfig(filename=r'E:\Project\BGE\virtual_patient_qa\app.log', level=logging.INFO) + +class QuestionRequest(BaseModel): + question: str + +class QuestionResponse(BaseModel): + code: int + msg: str + data: list + +class QuestionItem(BaseModel): + questionCode: str + questionList: list[str] + +@app.post("/updateDatabase") +async def save_to_json(question_items: list[QuestionItem]): + try: + json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2) + + with open("output.json", "w", encoding="utf-8") as file: + file.write(json_data) + file.close() # 确保文件被关闭 + + if store_data("output.json"): + os.remove("output.json") # 删除文件 + return {"status": "success", "message": "数据更新成功"} + else: + return {"status": "failure", "message": "数据未更新"} + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") + +@app.post("/matchQuestion") +def match_question(request: QuestionRequest): + try: + query = request.question + top_k = 3 + score_threshold = 0.1 + + result = match_query(qa_service, query, top_k, score_threshold) + response = QuestionResponse(code=200, msg="success", data=result) + + return response + + except Exception as e: + logging.error(f"Error in /matchQuestion: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + +def query_and_write_to_file(api_url, output_file_path): + try: + response = requests.get(api_url) + response_data = response.json() + + if response.status_code == 200 and response_data["code"] == 200: + question_items = response_data["data"] + + with open(output_file_path, "w", encoding="utf-8") as file: + json.dump(question_items, file, ensure_ascii=False, indent=2) + file.close() # 确保文件被关闭 + + logging.info(f"Data written to {output_file_path}") + + if store_data(output_file_path): + os.remove(output_file_path) # 删除文件 + return {"status": "success", "message": "数据初始化更新成功"} + else: + return {"status": "failure", "message": "数据初始化更新失败"} + else: + logging.error(f"接口调用失败:{response_data['msg']}") + return False + except Exception as e: + logging.error(f"发生错误:{str(e)}") + return False + +api_url = "http://192.168.10.73:8891/virtual-patient-manage/qaKnowledge/queryQaKnowledge" +output_file_path = "output.json" +query_and_write_to_file(api_url, output_file_path) + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000)