You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

198 lines
6.2 KiB
Python

# -*- coding: utf-8 -*-
from fastapi import FastAPI, HTTPException, BackgroundTasks
from qa_Ask import QAService, match_query, store_data
from pydantic import BaseModel
from collections import deque
import requests
import os
import time
import uuid
import json
import shutil
import yaml
import sys
import logging
from typing import List
from faiss_kb_service import FaissKBService
from faiss_kb_service import EmbeddingsFunAdapter
app = FastAPI()

# Configure logging to both a file and the terminal.
# logging.FileHandler does not create missing parent directories, so make sure
# the log directory exists first; otherwise importing this module fails with
# FileNotFoundError on a fresh checkout.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: log messages may contain non-ASCII text (e.g. user
        # questions), which the platform default encoding may not represent.
        logging.FileHandler('log/app.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),  # console handler
    ],
)
logger = logging.getLogger(__name__)
class QuestionRequest(BaseModel):
    """Request body accepted by ``/matchQuestion``."""

    # Free-text question to look up in the newest knowledge base.
    question: str
    # Threshold forwarded to ``match_query``; whether it is a minimum
    # similarity or a maximum distance is defined there — TODO confirm.
    scoreThreshold: float
class EmbeddingResponse(BaseModel):
    """Response body of ``/embeddings/``: the vector computed for the input text."""

    embeddings: list[float]
class QuestionResponse(BaseModel):
    """Envelope returned by ``/matchQuestion``."""

    code: int  # status carried in the body: 200 on success, 500 on failure
    msg: str  # short status message
    data: list  # match results as produced by ``match_query``
class QuestionItem(BaseModel):
    """One entry of the ``/updateDatabase`` payload."""

    # Identifier of the question entry.
    questionId: str
    # Question texts stored under that id (presumably alternative phrasings
    # of the same question — verify against qa_Ask.store_data).
    questionList: list[str]
class EmbeddingRequest(BaseModel):
    """Request body accepted by ``/embeddings/``."""

    text: str  # text to embed
# Load runtime settings. Read as UTF-8 explicitly so a config containing
# non-ASCII text parses identically regardless of the platform locale.
with open('config/config.yaml', 'r', encoding='utf-8') as config_file:
    config_data = yaml.safe_load(config_file)

# Text file listing knowledge-base names, one per line.
knowledge_base_file = config_data['knowledge_base_file']
# Endpoint fetched at startup to seed a fresh knowledge base.
api_url = config_data['api']['url']
# File the fetched data is written to before it is indexed.
path = config_data['output_file_path']
# How many knowledge bases to keep before the oldest is deleted.
max_knowledge_bases = config_data['max_knowledge_bases']
def load_knowledge_bases(file_path=None):
    """Load the list of knowledge-base names, oldest first.

    Args:
        file_path: File to read. Defaults to the module-level
            ``knowledge_base_file`` taken from the config.

    Returns:
        One name per non-blank line of the file, or an empty list when the
        file does not exist.
    """
    if file_path is None:
        file_path = knowledge_base_file
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            lines = file.read().splitlines()
    except FileNotFoundError:
        return []
    # Skip blank lines (e.g. a trailing newline added by hand): an empty name
    # would otherwise be treated as the newest knowledge base.
    return [line.strip() for line in lines if line.strip()]
def save_knowledge_bases(names, file_path=None):
    """Persist knowledge-base names to disk, one per line.

    The text is first written to a sibling temporary file and then moved over
    the target with ``os.replace``. A crash mid-write therefore cannot leave a
    truncated list behind; the list is reloaded on the next start to find the
    newest knowledge base.

    Args:
        names: Iterable of name strings (e.g. the ``recent_knowledge_bases``
            deque).
        file_path: Destination file. Defaults to the module-level
            ``knowledge_base_file`` taken from the config.
    """
    if file_path is None:
        file_path = knowledge_base_file
    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        file.write("\n".join(names))
    os.replace(tmp_path, file_path)
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Index a JSON file into a new knowledge base and register it as newest.

    Args:
        kb_name: Name of the new knowledge base; evicted names are removed
            from disk as ``knowledge_base/<name>``.
        qa_service: ``QAService`` bound to ``kb_name``, passed to
            ``store_data`` together with ``path``.
        path: JSON file with the question data; deleted once indexed.
        max_knowledge_bases: Number of knowledge bases to keep. The oldest
            ones are popped and their folders deleted to make room.
    """
    store_data(qa_service, path)
    # Evict until there is room for the new entry. A while/>= loop (instead of
    # a single == check) also copes with a list that is already at or over the
    # limit.
    while recent_knowledge_bases and len(recent_knowledge_bases) >= max_knowledge_bases:
        folder_to_delete = recent_knowledge_bases.popleft()
        try:
            shutil.rmtree(f"knowledge_base/{folder_to_delete}")
        except FileNotFoundError:
            # Folder already gone. Keep going: raising here would leave the
            # old name popped but the new one never appended or saved.
            logger.warning(f"Knowledge base folder not found, skipping delete: {folder_to_delete}")
    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)
    os.remove(path)
    logger.info(f"Knowledge base updated: {kb_name}\n"
                f"Please wait while the database is being updated···")
def fetch_and_write_data(api_url, path, timeout=30):
    """Fetch question data from the API and write it to a JSON file.

    Args:
        api_url: Endpoint expected to answer with a JSON object whose
            ``"code"`` is 200 and whose ``"data"`` holds the question items.
        path: Destination file for ``response["data"]`` (UTF-8 JSON).
        timeout: Seconds to wait for the server. Without a timeout
            ``requests`` can block indefinitely, and this function is called
            while the module is being imported.

    Returns:
        True if the data was written, False on any failure (the failure is
        logged; the startup caller then skips rebuilding the knowledge base).
    """
    try:
        response = requests.get(api_url, timeout=timeout)
        # Decode JSON only for a 200 response: error pages are often not JSON,
        # and logging the raw body is more useful than a decode error.
        if response.status_code == 200:
            response_data = response.json()
        else:
            response_data = response.text
        if response.status_code == 200 and response_data["code"] == 200:
            question_items = response_data["data"]
            with open(path, "w", encoding="utf-8") as file:
                json.dump(question_items, file, ensure_ascii=False, indent=2)
            return True
        logger.error(f"Failed to fetch data from API. Status code: {response.status_code}, Response data: {response_data}")
        return False
    except Exception as e:
        # Best effort by design: report failure through the return value.
        logger.error(f"Error fetching data from API: {e}")
        return False
# Embedding-model settings from the ``bge_large_zh_v1_5`` config section.
bge_large_zh_v1_5_config = config_data.get('bge_large_zh_v1_5', {})
embed_model_path = bge_large_zh_v1_5_config.get('embed_model_path', 'default_path_if_not_provided')
# Known knowledge-base names, oldest first, capped at max_knowledge_bases.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
if not recent_knowledge_bases:
    # The newest name is needed to construct FaissKBService below. Without
    # this check an empty or missing list file crashes import with a bare
    # IndexError that does not say what is missing.
    raise RuntimeError(
        f"No knowledge base listed in {knowledge_base_file!r}; "
        "at least one is required to initialise FaissKBService."
    )
# Shared service; the /embeddings/ endpoint reads its model path and device.
faiss_service = FaissKBService(kb_name=recent_knowledge_bases[-1], embed_model_path=embed_model_path, device=None)
@app.post("/embeddings/", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    """Return the embedding vector of ``request.text``.

    Uses the embedding model path and device of the module-level
    ``faiss_service``.

    Raises:
        HTTPException: status 500 carrying the error text if building the
            adapter or embedding the text fails.
    """
    import asyncio

    try:
        embed_func = EmbeddingsFunAdapter(faiss_service.embed_model_path, faiss_service.device)
        # embed_query is a synchronous call; run it in a worker thread so it
        # does not block the event loop (and every other request) meanwhile.
        embeddings = await asyncio.to_thread(embed_func.embed_query, request.text)
        return EmbeddingResponse(embeddings=embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Save posted question data and rebuild the knowledge base in the background.

    The payload is written to a JSON file named after a fresh knowledge-base
    id. ``update_kb`` is then scheduled as a background task to index it (and
    delete the file).

    Returns:
        ``{"status": "success", ...}`` once the task is scheduled, or
        ``{"status": "error", ...}`` if writing the file or creating the
        ``QAService`` failed. The error is logged and the HTTP status stays 200.
    """
    path = None
    try:
        kb_name = str(uuid.uuid4())
        json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
        # One file per request. A shared "output.json" let concurrent updates
        # overwrite each other's data before their background tasks ran, and
        # the later task then failed because the file was already removed.
        path = f"output_{kb_name}.json"
        with open(path, "w", encoding="utf-8") as file:
            file.write(json_data)
        device = None
        qa_service = QAService(kb_name, device)
        background_tasks.add_task(
            update_kb, kb_name, qa_service, path, max_knowledge_bases
        )
        return {"status": "success", "message": "Please wait while the database is being updated···"}
    except Exception as e:
        logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
        if path is not None:
            # No task will consume the file now; best-effort cleanup.
            try:
                os.remove(path)
            except OSError:
                pass
        return {"status": "error", "message": "update task error···"}
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match a question against the newest knowledge base.

    Returns:
        ``QuestionResponse`` with ``code=200`` and the results of
        ``match_query`` on success. On any error it returns ``code=500``, an
        error ``msg`` and empty ``data``, and logs the traceback. The HTTP
        status is 200 in both cases.
    """
    top_k = 3  # number of candidates requested from match_query
    try:
        logger.info(f"match_question:Request: {request}")
        # perf_counter is monotonic, unlike time.time, so it suits durations.
        start_time = time.perf_counter()
        newest = recent_knowledge_bases[-1]
        qa_service = QAService(newest, None)
        result = match_query(qa_service, request.question, top_k, request.scoreThreshold)
        response = QuestionResponse(code=200, msg="success", data=result)
        duration = time.perf_counter() - start_time
        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response
    except Exception as e:
        logger.exception(f"Error matching question: {e}")
        # The failure must be visible in msg too, not only in code.
        return QuestionResponse(code=500, msg="error", data=[])
# At import time, seed a fresh knowledge base from the API. The list of
# knowledge bases was already loaded into the global recent_knowledge_bases
# (from which faiss_service was built), so it is deliberately not reloaded
# and rebound here.
if fetch_and_write_data(api_url, path):
    kb_name = str(uuid.uuid4())
    device = None
    qa_service = QAService(kb_name, device)
    update_kb(kb_name, qa_service, path, max_knowledge_bases)
if __name__ == "__main__":
    import uvicorn

    # When executed as a script, serve the app on every interface, port 8000.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)