diff --git a/fast_api.py b/fast_api.py index 9eb0b3b..070dda2 100644 --- a/fast_api.py +++ b/fast_api.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class QuestionRequest(BaseModel): question: str - ScoreThreshold: float + class QuestionResponse(BaseModel): @@ -141,13 +141,14 @@ def match_question(request: QuestionRequest): newest = recent_knowledge_bases[-1] top_k = 3 + score_threshold = 0.1 device = None qa_service = QAService(newest, device) - result = match_query(qa_service, query, top_k, request.ScoreThreshold) + result = match_query(qa_service, query, top_k, score_threshold) - response = QuestionResponse(code=200, msg="success", data=[result]) + response = QuestionResponse(code=200, msg="success", data=result) stop_time = time.time() duration = stop_time - start_time