|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
# coding=gbk
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
|
|
|
from qa_Ask import QAService, match_query, store_data
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
@ -18,13 +19,13 @@ from faiss_kb_service import EmbeddingsFunAdapter
|
|
|
|
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
# 配置日志记录到文件和终端
|
|
|
|
|
# 配置日志记录到文件和终端
|
|
|
|
|
logging.basicConfig(
|
|
|
|
|
level=logging.INFO,
|
|
|
|
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
|
|
|
|
handlers=[
|
|
|
|
|
logging.FileHandler('log/app.log'),
|
|
|
|
|
logging.StreamHandler(sys.stdout) # 这里添加控制台处理程序
|
|
|
|
|
logging.StreamHandler(sys.stdout) # 这里添加控制台处理程序
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@ -60,7 +61,7 @@ max_knowledge_bases = config_data['max_knowledge_bases']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_knowledge_bases():
|
|
|
|
|
"""加载知识库名称列表"""
|
|
|
|
|
"""加载知识库名称列表"""
|
|
|
|
|
if os.path.exists(knowledge_base_file):
|
|
|
|
|
with open(knowledge_base_file, "r") as file:
|
|
|
|
|
return file.read().splitlines()
|
|
|
|
@ -69,13 +70,13 @@ def load_knowledge_bases():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_knowledge_bases(names):
|
|
|
|
|
"""保存知识库名称列表到文件"""
|
|
|
|
|
"""保存知识库名称列表到文件"""
|
|
|
|
|
with open(knowledge_base_file, "w") as file:
|
|
|
|
|
file.write("\n".join(names))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
|
|
|
|
|
"""更新知识库"""
|
|
|
|
|
"""更新知识库"""
|
|
|
|
|
store_data(qa_service, path)
|
|
|
|
|
|
|
|
|
|
if len(recent_knowledge_bases) == max_knowledge_bases:
|
|
|
|
@ -87,11 +88,11 @@ def update_kb(kb_name, qa_service, path, max_knowledge_bases):
|
|
|
|
|
|
|
|
|
|
os.remove(path)
|
|
|
|
|
logger.info(f"Knowledge base updated: {kb_name}\n"
|
|
|
|
|
f"Please wait while the database is being updated···")
|
|
|
|
|
f"Please wait while the database is being updated···")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_and_write_data(api_url, path):
|
|
|
|
|
"""从API获取数据并写入文件"""
|
|
|
|
|
"""从API获取数据并写入文件"""
|
|
|
|
|
try:
|
|
|
|
|
response = requests.get(api_url)
|
|
|
|
|
response_data = response.json()
|
|
|
|
@ -117,7 +118,7 @@ faiss_service = FaissKBService(kb_name= recent_knowledge_bases[-1], embed_model_
|
|
|
|
|
|
|
|
|
|
@app.post("/embeddings/", response_model=EmbeddingResponse)
|
|
|
|
|
async def get_embeddings(request: EmbeddingRequest):
|
|
|
|
|
"""使用FaissKBService实例来获取嵌入向量"""
|
|
|
|
|
"""使用FaissKBService实例来获取嵌入向量"""
|
|
|
|
|
embed_func = EmbeddingsFunAdapter(faiss_service.embed_model_path, faiss_service.device)
|
|
|
|
|
try:
|
|
|
|
|
embeddings = embed_func.embed_query(request.text)
|
|
|
|
@ -128,7 +129,7 @@ async def get_embeddings(request: EmbeddingRequest):
|
|
|
|
|
|
|
|
|
|
@app.post("/updateDatabase")
|
|
|
|
|
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
|
|
|
|
|
"""接收问题数据并异步保存为JSON文件,触发后台更新任务"""
|
|
|
|
|
"""接收问题数据并异步保存为JSON文件,触发后台更新任务"""
|
|
|
|
|
try:
|
|
|
|
|
json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
|
|
|
|
|
path = "output.json"
|
|
|
|
@ -145,17 +146,17 @@ async def save_to_json(question_items: list[QuestionItem], background_tasks: Bac
|
|
|
|
|
update_kb, kb_name, qa_service, path, max_knowledge_bases
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return {"status": "success", "message": "Please wait while the database is being updated···"}
|
|
|
|
|
return {"status": "success", "message": "Please wait while the database is being updated···"}
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
|
|
|
|
|
# raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
|
|
|
|
|
return {"status": "error", "message": "update task error···"}
|
|
|
|
|
return {"status": "error", "message": "update task error···"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/matchQuestion")
|
|
|
|
|
def match_question(request: QuestionRequest):
|
|
|
|
|
"""匹配问题的端点"""
|
|
|
|
|
"""匹配问题的端点"""
|
|
|
|
|
try:
|
|
|
|
|
logger.info(f"match_question:Request: {request}")
|
|
|
|
|
start_time = time.time()
|
|
|
|
|