You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
3.1 KiB
Python
97 lines
3.1 KiB
Python
1 year ago
|
from qa_amend import QAService, match_query, store_data
|
||
|
from fastapi import FastAPI, HTTPException
|
||
|
from pydantic import BaseModel
|
||
|
import json
|
||
|
import os
|
||
|
import logging
|
||
|
import time
|
||
|
import requests
|
||
|
|
||
|
# FastAPI application object; the route decorators in this module attach to it.
app = FastAPI()

# Arguments handed to QAService.
# NOTE(review): device=None presumably lets QAService choose its own default
# device -- confirm in qa_amend.
kb_name = "my_kb_test"
device = None

# Shared service instance passed to match_query() by the /matchQuestion route.
qa_service = QAService(kb_name, device)

# Root-logger configuration: INFO and above go to a fixed log file.
logging.basicConfig(
    filename=r'E:\Project\BGE\virtual_patient_qa\app.log',
    level=logging.INFO,
)
|
||
|
|
||
|
# Request body of POST /matchQuestion.
class QuestionRequest(BaseModel):
    # Text that the /matchQuestion handler forwards to match_query() as the query.
    question: str
|
||
|
|
||
|
# Response envelope returned by POST /matchQuestion.
class QuestionResponse(BaseModel):
    # Application-level status code; the handler sets 200 on success.
    code: int
    # Status message; the handler sets "success" on success.
    msg: str
    # Whatever match_query() returned for the question.
    data: list
|
||
|
|
||
|
# One element of the list posted to /updateDatabase.
class QuestionItem(BaseModel):
    # NOTE(review): presumably the identifier shared by the questions below --
    # confirm against the knowledge-base schema.
    questionCode: str
    # NOTE(review): presumably the question phrasings stored under that code --
    # confirm against the knowledge-base schema.
    questionList: list[str]
|
||
|
|
||
|
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem]):
    """Write the posted question items to a JSON file and load it via store_data.

    The items are serialized to ``output.json`` in the current working
    directory, the file name is passed to ``store_data``, and the file is
    deleted once ``store_data`` reports success.

    Args:
        question_items: Request body; each item is serialized with
            ``item.dict()``.

    Returns:
        ``{"status": "success", ...}`` when ``store_data`` returns a truthy
        value, otherwise ``{"status": "failure", ...}``. In the failure case
        ``output.json`` is left on disk.

    Raises:
        HTTPException: Status 500 wrapping any exception raised while
            serializing, writing, storing or deleting the file.
    """
    try:
        payload = [item.dict() for item in question_items]

        # The with-block closes the file on exit; no explicit close() is needed.
        with open("output.json", "w", encoding="utf-8") as file:
            json.dump(payload, file, ensure_ascii=False, indent=2)

        if not store_data("output.json"):
            return {"status": "failure", "message": "数据未更新"}

        # The file is only a hand-off to store_data; drop it once stored.
        os.remove("output.json")
        return {"status": "success", "message": "数据更新成功"}

    except Exception as e:
        # Log with traceback, as /matchQuestion logs its failures, so that a
        # 500 response can be diagnosed from the log file.
        logging.exception("Error in /updateDatabase: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Internal Server Error: {str(e)}"
        ) from e
|
||
|
|
||
|
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Look up knowledge-base matches for the submitted question.

    Args:
        request: Request body whose ``question`` field is used as the query.

    Returns:
        QuestionResponse with ``code=200``, ``msg="success"`` and ``data`` set
        to the value returned by ``match_query``.

    Raises:
        HTTPException: Status 500 carrying ``str(e)`` for any exception raised
            while matching or building the response.
    """
    try:
        query = request.question
        # NOTE(review): presumably the maximum number of matches and the
        # minimum score to keep -- confirm in qa_amend.match_query.
        top_k = 3
        score_threshold = 0.1

        result = match_query(qa_service, query, top_k, score_threshold)
        return QuestionResponse(code=200, msg="success", data=result)

    except Exception as e:
        # logging.exception records the traceback along with the message;
        # plain logging.error kept only the exception text.
        logging.exception("Error in /matchQuestion: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
|
||
|
def query_and_write_to_file(api_url, output_file_path, timeout=10):
    """Fetch question data from an HTTP API and load it through store_data.

    Issues a GET to ``api_url`` and parses the body as JSON. When both the
    HTTP status and the payload's ``code`` field are 200, the payload's
    ``data`` entry is written to ``output_file_path`` as JSON and that path is
    passed to ``store_data``. The file is removed after a successful store.

    Args:
        api_url: URL to query with GET.
        output_file_path: Path of the intermediate JSON file.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout an unresponsive server would block the caller forever.

    Returns:
        A ``{"status": "success", ...}`` dict when ``store_data`` returns a
        truthy value, a ``{"status": "failure", ...}`` dict when it does not,
        and ``False`` when the API call is rejected or any exception occurs.
        Exceptions are logged, never propagated.
    """
    try:
        response = requests.get(api_url, timeout=timeout)
        response_data = response.json()

        if response.status_code != 200 or response_data.get("code") != 200:
            # Fall back to the HTTP status when the payload has no "msg", so
            # reporting the failure cannot itself raise a KeyError.
            reason = response_data.get("msg", f"HTTP {response.status_code}")
            logging.error(f"接口调用失败:{reason}")
            return False

        question_items = response_data["data"]

        # The with-block closes the file on exit; no explicit close() is needed.
        with open(output_file_path, "w", encoding="utf-8") as file:
            json.dump(question_items, file, ensure_ascii=False, indent=2)

        logging.info(f"Data written to {output_file_path}")

        if not store_data(output_file_path):
            return {"status": "failure", "message": "数据初始化更新失败"}

        # The file is only a hand-off to store_data; drop it once stored.
        os.remove(output_file_path)
        return {"status": "success", "message": "数据初始化更新成功"}

    except Exception as e:
        # Best-effort sync: record the traceback and report failure to the
        # caller instead of raising.
        logging.exception("发生错误:%s", e)
        return False
|
||
|
|
||
|
# Import-time synchronisation: fetch the question data from the management
# service and load it before any request is served.
api_url = "http://192.168.10.73:8891/virtual-patient-manage/qaKnowledge/queryQaKnowledge"
# Intermediate JSON file, relative to the current working directory.
output_file_path = "output.json"

# The return value is not inspected here.
query_and_write_to_file(api_url=api_url, output_file_path=output_file_path)
|
||
|
|
||
|
if __name__ == "__main__":
    # Imported here so uvicorn is only required when this file is run as a script.
    import uvicorn

    # Serve the app on every network interface, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|