增加文字转embeddings接口

main
fanpt 11 months ago
parent c4f34445d5
commit 301c3e6999

@ -1,3 +1,4 @@
# coding=gbk
from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi import FastAPI, HTTPException, BackgroundTasks
from qa_Ask import QAService, match_query, store_data from qa_Ask import QAService, match_query, store_data
from pydantic import BaseModel from pydantic import BaseModel
@ -9,27 +10,31 @@ import uuid
import json import json
import shutil import shutil
import yaml import yaml
import sys
import logging import logging
from typing import List
from faiss_kb_service import FaissKBService
from faiss_kb_service import EmbeddingsFunAdapter
app = FastAPI() app = FastAPI()
import sys # 配置日志记录到文件和终端
# 配置日志记录到文件和终端
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[ handlers=[
logging.FileHandler('log/app.log'), logging.FileHandler('log/app.log'),
logging.StreamHandler(sys.stdout) # 这里添加控制台处理程序 logging.StreamHandler(sys.stdout) # 这里添加控制台处理程序
] ]
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class QuestionRequest(BaseModel): class QuestionRequest(BaseModel):
question: str question: str
scoreThreshold: float
class EmbeddingResponse(BaseModel):
embeddings: List[float]
class QuestionResponse(BaseModel): class QuestionResponse(BaseModel):
@ -39,9 +44,11 @@ class QuestionResponse(BaseModel):
class QuestionItem(BaseModel): class QuestionItem(BaseModel):
questionCode: str questionId: str
questionList: list[str] questionList: list[str]
class EmbeddingRequest(BaseModel):
text: str
with open('config/config.yaml', 'r') as config_file: with open('config/config.yaml', 'r') as config_file:
config_data = yaml.safe_load(config_file) config_data = yaml.safe_load(config_file)
@ -53,7 +60,7 @@ max_knowledge_bases = config_data['max_knowledge_bases']
def load_knowledge_bases(): def load_knowledge_bases():
"""加载知识库名称列表""" """加载知识库名称列表"""
if os.path.exists(knowledge_base_file): if os.path.exists(knowledge_base_file):
with open(knowledge_base_file, "r") as file: with open(knowledge_base_file, "r") as file:
return file.read().splitlines() return file.read().splitlines()
@ -62,13 +69,13 @@ def load_knowledge_bases():
def save_knowledge_bases(names): def save_knowledge_bases(names):
"""保存知识库名称列表到文件""" """保存知识库名称列表到文件"""
with open(knowledge_base_file, "w") as file: with open(knowledge_base_file, "w") as file:
file.write("\n".join(names)) file.write("\n".join(names))
def update_kb(kb_name, qa_service, path, max_knowledge_bases): def update_kb(kb_name, qa_service, path, max_knowledge_bases):
"""更新知识库""" """更新知识库"""
store_data(qa_service, path) store_data(qa_service, path)
if len(recent_knowledge_bases) == max_knowledge_bases: if len(recent_knowledge_bases) == max_knowledge_bases:
@ -80,11 +87,11 @@ def update_kb(kb_name, qa_service, path, max_knowledge_bases):
os.remove(path) os.remove(path)
logger.info(f"Knowledge base updated: {kb_name}\n" logger.info(f"Knowledge base updated: {kb_name}\n"
f"Please wait while the database is being updated···") f"Please wait while the database is being updated···")
def fetch_and_write_data(api_url, path): def fetch_and_write_data(api_url, path):
"""从API获取数据并写入文件""" """从API获取数据并写入文件"""
try: try:
response = requests.get(api_url) response = requests.get(api_url)
response_data = response.json() response_data = response.json()
@ -102,11 +109,26 @@ def fetch_and_write_data(api_url, path):
except Exception as e: except Exception as e:
logger.error(f"Error fetching data from API: {e}") logger.error(f"Error fetching data from API: {e}")
return False return False
bge_large_zh_v1_5_config = config_data.get('bge_large_zh_v1_5', {})
embed_model_path = bge_large_zh_v1_5_config.get('embed_model_path', 'default_path_if_not_provided')
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
faiss_service = FaissKBService(kb_name= recent_knowledge_bases[-1], embed_model_path=embed_model_path, device=None)
@app.post("/embeddings/", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
"""使用FaissKBService实例来获取嵌入向量"""
embed_func = EmbeddingsFunAdapter(faiss_service.embed_model_path, faiss_service.device)
try:
embeddings = embed_func.embed_query(request.text)
embeddings_list = embeddings
return EmbeddingResponse(embeddings=embeddings_list)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/updateDatabase") @app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks): async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
"""接收问题数据并异步保存为JSON文件触发后台更新任务""" """接收问题数据并异步保存为JSON文件触发后台更新任务"""
try: try:
json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2) json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
path = "output.json" path = "output.json"
@ -123,16 +145,17 @@ async def save_to_json(question_items: list[QuestionItem], background_tasks: Bac
update_kb, kb_name, qa_service, path, max_knowledge_bases update_kb, kb_name, qa_service, path, max_knowledge_bases
) )
return {"status": "success", "message": "Please wait while the database is being updated···"} return {"status": "success", "message": "Please wait while the database is being updated···"}
except Exception as e: except Exception as e:
logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}") logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
# raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") # raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
return {"status": "error", "message": "update task error···"} return {"status": "error", "message": "update task error···"}
@app.post("/matchQuestion") @app.post("/matchQuestion")
def match_question(request: QuestionRequest): def match_question(request: QuestionRequest):
"""匹配问题的端点""" """匹配问题的端点"""
try: try:
logger.info(f"match_question:Request: {request}") logger.info(f"match_question:Request: {request}")
start_time = time.time() start_time = time.time()
@ -141,12 +164,11 @@ def match_question(request: QuestionRequest):
newest = recent_knowledge_bases[-1] newest = recent_knowledge_bases[-1]
top_k = 3 top_k = 3
score_threshold = 0.1
device = None device = None
qa_service = QAService(newest, device) qa_service = QAService(newest, device)
result = match_query(qa_service, query, top_k, score_threshold) result = match_query(qa_service, query, top_k, request.scoreThreshold)
response = QuestionResponse(code=200, msg="success", data=result) response = QuestionResponse(code=200, msg="success", data=result)
stop_time = time.time() stop_time = time.time()

@ -51,7 +51,7 @@ def load_testing_data(file_path):
with open(file_path, encoding='utf-8') as f: with open(file_path, encoding='utf-8') as f:
data = json.load(f) data = json.load(f)
for item in data: for item in data:
question_code = item['questionCode'] question_code = item['questionId']
question_list.extend(item['questionList']) question_list.extend(item['questionList'])
id_list.extend([create_question_id(question_code, j, q) for j, q in enumerate(item['questionList'])]) id_list.extend([create_question_id(question_code, j, q) for j, q in enumerate(item['questionList'])])

Loading…
Cancel
Save