You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
3.6 KiB
Python
120 lines
3.6 KiB
Python
# qa_operations.py
|
|
import json
|
|
from sentence_transformers import CrossEncoder
|
|
from faiss_kb_service import FaissKBService
|
|
from langchain.docstore.document import Document
|
|
from base_kb import KnowledgeFile
|
|
from pydantic import BaseModel
|
|
|
|
|
|
import yaml
|
|
|
|
# Read the configuration file. The explicit encoding keeps parsing independent
# of the platform's default locale (the JSON loader in this file also reads
# its input as UTF-8).
with open('config/config.yaml', 'r', encoding='utf-8') as config_file:
    # safe_load returns None for an empty document; fall back to an empty
    # mapping so the lookups below cannot fail with AttributeError.
    config = yaml.safe_load(config_file) or {}

# Access the 'bge-large-zh-v1.5' configuration. `or {}` also covers a key that
# is present in the YAML but has no value (parsed as None).
bge_large_zh_v1_5_config = config.get('bge_large_zh_v1_5') or {}
# Placeholder string is used when the config does not name a model path.
embed_model_path = bge_large_zh_v1_5_config.get('embed_model_path', 'default_path_if_not_provided')
|
|
|
|
class QAService:
    """Wrapper around a ``FaissKBService`` that stores and searches QA questions.

    Attributes:
        kb_name: Name of the knowledge base this service is bound to.
        device: Value forwarded unchanged to ``FaissKBService``.
        fkbs: The underlying ``FaissKBService`` instance.
    """

    def __init__(self, kb_name, device=None) -> None:
        """Create the backing service and call ``do_create_kb()`` on it.

        Args:
            kb_name: Name of the knowledge base.
            device: Forwarded to ``FaissKBService`` together with the
                module-level ``embed_model_path``.
        """
        self.kb_name = kb_name
        self.device = device
        self.fkbs = FaissKBService(kb_name, embed_model_path=embed_model_path, device=device)
        self.fkbs.do_create_kb()

    def update_qa_doc(self, qa_file_id, doc_list, id_list):
        """Replace the documents stored for ``qa_file_id``.

        Deletes the file's existing entries, adds ``doc_list`` under the ids
        in ``id_list``, then saves the vector store.

        Args:
            qa_file_id: Identifier of the QA source file being replaced.
            doc_list: Documents to add.
            id_list: Ids passed to ``do_add_doc`` as ``ids``.
        """
        # Remove the previous version first so the add below does not pile
        # new entries on top of old ones for the same source.
        self.delete_qa_file(qa_file_id)
        # The value returned by do_add_doc is not needed here.
        self.fkbs.do_add_doc(doc_list, ids=id_list)
        self.fkbs.save_vector_store()

    def delete_qa_file(self, qa_file_id):
        """Delete the documents associated with ``qa_file_id`` from the store.

        Args:
            qa_file_id: Identifier used to build the ``KnowledgeFile`` that is
                handed to ``do_delete_doc``.
        """
        kb_file = KnowledgeFile(qa_file_id, self.kb_name)
        # NOTE(review): not_refresh_vs_cache=True presumably defers the cache
        # refresh to the caller (update_qa_doc saves afterwards) - confirm
        # against FaissKBService.
        self.fkbs.do_delete_doc(kb_file, not_refresh_vs_cache=True)

    def search(self, query, top_k=3, score_threshold=0.1, reranked=False):
        """Search the knowledge base and return ``do_search``'s result as-is.

        Args:
            query: Query text.
            top_k: Passed positionally to ``do_search``.
            score_threshold: Passed positionally to ``do_search``.
            reranked: Accepted for interface compatibility; currently ignored.

        Returns:
            Whatever ``FaissKBService.do_search`` returns.
        """
        return self.fkbs.do_search(query, top_k, score_threshold)
|
|
|
|
|
|
def create_question_id(question_code, j, test_question):
    """Build the composite id ``<question_code>@<j>@<test_question>``.

    Args:
        question_code: Code of the question group.
        j: Index of the question within its group.
        test_question: The question text.

    Returns:
        The three parts joined with ``@``.
    """
    # match_query in this file recovers the code as the text before the first
    # '@', so a question_code that itself contains '@' would be truncated.
    return "{}@{}@{}".format(question_code, j, test_question)
|
|
|
|
|
|
def load_testing_data(file_path):
    """Load questions and their generated ids from a UTF-8 JSON file.

    The file must contain a list of objects, each with a ``questionId`` and a
    ``questionList`` (the question texts belonging to that id).

    Args:
        file_path: Path of the JSON file.

    Returns:
        A ``(question_list, id_list)`` tuple of two parallel lists: every
        question text, and the id built for it by ``create_question_id``.
    """
    with open(file_path, encoding='utf-8') as handle:
        records = json.load(handle)

    questions = []
    ids = []
    for record in records:
        code = record['questionId']
        variants = record['questionList']
        questions += variants
        # The index restarts at 0 for every record, so ids are unique only in
        # combination with the question code.
        for index, text in enumerate(variants):
            ids.append(create_question_id(code, index, text))

    return questions, ids
|
|
|
|
|
|
def convert_to_doc_list(question_list, id_list, qa_file_id):
    """Wrap each question in a ``Document`` tagged with its id and source.

    Args:
        question_list: Question texts; each becomes a document's
            ``page_content``.
        id_list: Ids paired positionally with ``question_list``.
        qa_file_id: Stored as ``metadata['source']`` on every document.

    Returns:
        A list of ``Document`` objects, one per (question, id) pair. Pairing
        uses ``zip``, so it stops at the shorter of the two inputs.
    """
    return [
        Document(page_content=text, metadata={'source': qa_file_id, 'id': doc_id})
        for text, doc_id in zip(question_list, id_list)
    ]
|
|
|
|
|
|
def store_data(qa_service, path, qa_file_id='QA_TEST_1'):
    """Load questions from ``path`` and store them in the knowledge base.

    Args:
        qa_service: Object exposing ``update_qa_doc(qa_file_id, doc_list,
            id_list)``, e.g. a ``QAService``.
        path: Path of the JSON file read by ``load_testing_data``.
        qa_file_id: Source tag of the stored QA data. It is used for data
            cleaning (existing entries for this id are replaced), so it must
            be unique per source. Defaults to the previously hard-coded
            ``'QA_TEST_1'``.
    """
    question_list, id_list = load_testing_data(path)
    print('Loaded data!')

    doc_list = convert_to_doc_list(question_list, id_list, qa_file_id)
    qa_service.update_qa_doc(qa_file_id, doc_list, id_list)

    print("Data stored in the knowledge base successfully!")
|
|
|
|
|
|
|
|
class MatchInfo(BaseModel):
    """One matched question, as returned by ``match_query``."""

    # Question code: the part of the stored document id before the first '@'.
    matchQuestionCode: str
    # Text of the matched question (the document's page content).
    matchQuestion: str
    # Match score rendered as a string; match_query formats it to 3 decimals.
    matchScore: str
|
|
|
|
|
|
def match_query(qa_service, query, top_k=3, score_threshold=0.1):
    """Search the knowledge base and convert the hits to ``MatchInfo`` objects.

    Args:
        qa_service: Object exposing ``search(query, top_k, score_threshold)``
            that returns ``(document, score)`` pairs.
        query: Query text.
        top_k: Forwarded to ``search``.
        score_threshold: Threshold on the reported score scale; ``search``
            receives ``1 - score_threshold``.

    Returns:
        A list of ``MatchInfo``; empty when the search returns nothing.
    """
    # NOTE(review): both the threshold and the reported score are flipped with
    # `1 - x`, presumably because the store returns a distance where lower
    # means closer - confirm against FaissKBService.do_search.
    hits = qa_service.search(query, top_k, 1 - score_threshold)
    if not hits:
        return []

    return [
        MatchInfo(
            # Document ids look like "<code>@<index>@<question>".
            matchQuestionCode=hit.metadata['id'].split('@')[0],
            matchQuestion=hit.page_content,
            matchScore=f"{1 - raw_score:.3f}",  # returned field
        )
        for hit, raw_score in hits
    ]
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: index the test questions, then issue one sample query.
    device = None
    path = "test.json"

    # Populate the first knowledge base from the JSON test file.
    kb_name = 'my_kb_test'
    qa_service = QAService(kb_name, device)
    store_data(qa_service, path)

    # NOTE(review): the query below runs against a different knowledge base
    # than the one just populated, and its result is discarded - confirm
    # whether that is intended.
    kb_name = 'my_kb_test_2'
    qa_service = QAService(kb_name, device)
    top_k = 3
    score_threshold = 0.1
    match_query(qa_service, "你好吗?", top_k, score_threshold)