You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

197 lines
6.1 KiB
Python

# coding=gbk
from fastapi import FastAPI, HTTPException, BackgroundTasks
from qa_Ask import QAService, match_query, store_data
from pydantic import BaseModel
from collections import deque
import requests
import os
import time
import uuid
import json
import shutil
import yaml
import sys
import logging
from typing import List
from faiss_kb_service import FaissKBService
from faiss_kb_service import EmbeddingsFunAdapter
# Application object; the @app.post decorators in this module register their endpoints on it.
app = FastAPI()
# Configure logging to write both to a log file and to the terminal.
# FileHandler opens its file as soon as it is constructed and does not create
# missing parent directories, so the directory must exist before the handler
# is built; otherwise importing this module fails with FileNotFoundError.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('log/app.log'),
        logging.StreamHandler(sys.stdout),  # also send every record to the console
    ],
)
logger = logging.getLogger(__name__)
# Request body of the POST /matchQuestion endpoint.
class QuestionRequest(BaseModel):
    # Text of the question; used as the query passed to match_query.
    question: str
    # Forwarded unchanged to match_query as its last argument.
    # NOTE(review): presumably a minimum similarity score for a hit — confirm in qa_Ask.
    scoreThreshold: float
# Response body of the POST /embeddings/ endpoint.
class EmbeddingResponse(BaseModel):
    # Value returned by embed_query for the submitted text, declared as a flat list of floats.
    embeddings: List[float]
# Response envelope returned by the POST /matchQuestion endpoint.
class QuestionResponse(BaseModel):
    # 200 when matching succeeded, 500 when the handler caught an exception.
    code: int
    # Status text that accompanies the code.
    msg: str
    # Result of match_query on success; an empty list on failure.
    data: list
# One element of the JSON array accepted by the POST /updateDatabase endpoint.
class QuestionItem(BaseModel):
    # Identifier of the entry.
    questionId: str
    # Question strings stored under that identifier.
    # NOTE(review): presumably alternative phrasings of one question — confirm with the data producer.
    questionList: list[str]
# Request body of the POST /embeddings/ endpoint.
class EmbeddingRequest(BaseModel):
    # Text handed to embed_query.
    text: str
# Load the service configuration once at import time. The names assigned below
# are module-level settings read by the functions and endpoints in this file.
with open('config/config.yaml', 'r') as config_file:
    config_data = yaml.safe_load(config_file)
# File that stores the known knowledge-base names, one per line.
knowledge_base_file = config_data['knowledge_base_file']
# URL queried by fetch_and_write_data for the question data.
api_url = config_data['api']['url']
# Where the fetched question data is written as JSON before it is stored.
path = config_data['output_file_path']
# Number of knowledge bases to keep; also used as the maxlen of recent_knowledge_bases.
max_knowledge_bases = config_data['max_knowledge_bases']
def load_knowledge_bases():
    """Load the list of knowledge-base names.

    Reads ``knowledge_base_file``, which holds one name per line.

    Returns:
        list[str]: The names in file order, or an empty list when the file
        does not exist. Blank lines are skipped.
    """
    try:
        with open(knowledge_base_file, "r") as file:
            lines = file.read().splitlines()
    except FileNotFoundError:
        # Nothing has been recorded yet.
        return []
    # A blank entry would later be formatted into the path "knowledge_base/"
    # and handed to shutil.rmtree by update_kb, wiping every knowledge base,
    # so such lines are never returned as names.
    return [name for name in lines if name.strip()]
def save_knowledge_bases(names):
    """Save the list of knowledge-base names to the file.

    Overwrites ``knowledge_base_file`` with the given names, one per line and
    without a trailing newline.

    Args:
        names: Iterable of name strings, written in iteration order.
    """
    with open(knowledge_base_file, mode="w") as handle:
        serialized = "\n".join(names)
        handle.write(serialized)
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Update the knowledge base.

    Stores the data from ``path`` through ``store_data``, evicts the oldest
    entry when the recent list is full, records ``kb_name`` as the newest
    entry, persists the list and removes the consumed data file.

    Args:
        kb_name: Name of the knowledge base to record as the newest one.
        qa_service: Service object handed to ``store_data``.
        path: Data file consumed by ``store_data``; deleted afterwards.
        max_knowledge_bases: Number of knowledge bases to keep.
    """
    store_data(qa_service, path)
    # The emptiness check keeps popleft() from raising on an empty deque
    # (possible when max_knowledge_bases is 0).
    if recent_knowledge_bases and len(recent_knowledge_bases) == max_knowledge_bases:
        folder_to_delete = recent_knowledge_bases.popleft()
        if folder_to_delete.strip():
            try:
                shutil.rmtree(f"knowledge_base/{folder_to_delete}")
            except FileNotFoundError:
                # The directory is already gone. That must not abort the
                # update, or the new knowledge base would never be recorded.
                logger.warning(f"Knowledge base folder already missing: {folder_to_delete}")
        else:
            # A blank name would resolve to "knowledge_base/" itself and
            # delete every knowledge base, so it is only dropped from the list.
            logger.warning("Skipped deleting a blank knowledge base name")
    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)
    os.remove(path)
    logger.info(f"Knowledge base updated: {kb_name}\n"
                f"Please wait while the database is being updated<65><64><EFBFBD><EFBFBD><EFBFBD><EFBFBD>")
def fetch_and_write_data(api_url, path, timeout=30):
    """Fetch data from the API and write it to a file.

    Sends a GET request to ``api_url``. When both the HTTP status and the
    ``code`` field of the JSON body are 200, the body's ``data`` field is
    written to ``path`` as UTF-8 JSON.

    Args:
        api_url: URL to query.
        path: Destination file for the fetched data.
        timeout: Seconds to wait for the server to connect or send data
            before giving up. Without a timeout a stalled server would block
            the caller indefinitely.

    Returns:
        bool: True when the data was fetched and written, False on any
        failure. Failures are logged, never raised.
    """
    try:
        response = requests.get(api_url, timeout=timeout)
        response_data = response.json()
        if response.status_code != 200 or response_data["code"] != 200:
            logger.error(f"Failed to fetch data from API. Status code: {response.status_code}, Response data: {response_data}")
            return False
        question_items = response_data["data"]
        with open(path, "w", encoding="utf-8") as file:
            json.dump(question_items, file, ensure_ascii=False, indent=2)
        return True
    except Exception as e:
        # Best-effort by design: network, decoding and file errors all end up
        # here and are reported to the caller as False.
        logger.error(f"Error fetching data from API: {e}")
        return False
# Configuration section 'bge_large_zh_v1_5'; an empty dict when the config has no such section.
bge_large_zh_v1_5_config = config_data.get('bge_large_zh_v1_5', {})
# Falls back to a placeholder string when the section gives no 'embed_model_path'.
embed_model_path = bge_large_zh_v1_5_config.get('embed_model_path', 'default_path_if_not_provided')
# Known knowledge-base names, oldest first; the deque is capped at max_knowledge_bases entries.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
# Service created for the newest knowledge base; the /embeddings/ endpoint reads its
# embed_model_path and device attributes.
# NOTE(review): recent_knowledge_bases[-1] raises IndexError when no knowledge base has been
# recorded yet (e.g. on a first start) — confirm whether that case needs a fallback.
faiss_service = FaissKBService(kb_name= recent_knowledge_bases[-1], embed_model_path=embed_model_path, device=None)
@app.post("/embeddings/", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    """Get the embedding vector of the request text using the FaissKBService instance's settings.

    Any exception raised while embedding is converted into an HTTP 500 error
    whose detail is the exception text.
    """
    adapter = EmbeddingsFunAdapter(
        faiss_service.embed_model_path,
        faiss_service.device,
    )
    try:
        vector = adapter.embed_query(request.text)
        return EmbeddingResponse(embeddings=vector)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Save the received data as a JSON file and update the knowledge base in the background.

    The items are serialized to a per-request data file, a new knowledge base
    name is generated, and ``update_kb`` is scheduled as a background task
    that consumes and then deletes that file.

    Args:
        question_items: Items to store in the new knowledge base.
        background_tasks: Task queue on which ``update_kb`` is scheduled.

    Returns:
        dict: ``{"status": "success", ...}`` once the update is scheduled, or
        ``{"status": "error", ...}`` when saving or scheduling failed.
    """
    path = None
    try:
        kb_name = str(uuid.uuid4())
        # One data file per request: with a single shared file name, a second
        # request could overwrite the file while the background task of the
        # first one is still reading it, and that task's cleanup would then
        # delete the second request's data.
        path = f"output_{kb_name}.json"
        with open(path, "w", encoding="utf-8") as file:
            json.dump([item.dict() for item in question_items], file, ensure_ascii=False, indent=2)
        qa_service = QAService(kb_name, None)
        background_tasks.add_task(
            update_kb, kb_name, qa_service, path, max_knowledge_bases
        )
        return {"status": "success", "message": "Please wait while the database is being updated<65><64><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"}
    except Exception as e:
        logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
        # No background task will clean up after a failure, so remove the
        # data file here if it was already created.
        if path is not None and os.path.exists(path):
            try:
                os.remove(path)
            except OSError as cleanup_error:
                logger.error(f"Could not remove data file {path}: {cleanup_error}")
        # Failures are deliberately reported in the response body rather than as an HTTP error.
        return {"status": "error", "message": "update task error<6F><72><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"}
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Endpoint for matching a question.

    Runs ``match_query`` for the request's question against the newest
    knowledge base in ``recent_knowledge_bases``.

    Args:
        request: The question text and the score threshold passed on to
            ``match_query``.

    Returns:
        QuestionResponse: ``code=200`` with the match result in ``data`` on
        success; ``code=500`` with an empty ``data`` list when anything
        fails. Failures are logged, never raised.
    """
    try:
        logger.info(f"match_question:Request: {request}")
        start_time = time.perf_counter()
        newest = recent_knowledge_bases[-1]
        top_k = 3  # number of candidates requested from match_query
        qa_service = QAService(newest, None)
        result = match_query(qa_service, request.question, top_k, request.scoreThreshold)
        response = QuestionResponse(code=200, msg="success", data=result)
        duration = time.perf_counter() - start_time
        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response
    except Exception as e:
        logger.error(f"Error matching question: {e}")
        # The failure response must not claim success in its message.
        return QuestionResponse(code=500, msg="error", data=[])
# Bootstrap executed at import time: reload the recorded knowledge-base names,
# then try to create a fresh knowledge base from the data served by the API.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
# A new knowledge base is only created when the fetch succeeded and the data file was written.
if fetch_and_write_data(api_url, path):
    kb_name = str(uuid.uuid4())  # random unique name for the new knowledge base
    device = None
    qa_service = QAService(kb_name, device)
    update_kb(kb_name, qa_service, path, max_knowledge_bases)
# When executed as a script, serve the app with uvicorn on port 8000, bound to all network interfaces.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)