# qa_amend.py
"""Store question lists in a FAISS-backed knowledge base and match queries.

The module loads question data from a JSON file, wraps each question in a
``Document`` tagged with its source file id, stores the documents through
``FaissKBService`` and exposes ``match_query`` to look up similar questions.
"""
import json
from kb_config import logger
from sentence_transformers import CrossEncoder
from faiss_kb_service import FaissKBService
from langchain.docstore.document import Document
from base_kb import KnowledgeFile
from pydantic import BaseModel

# Defaults shared by ``store_data`` and the script entry point.
DEFAULT_KB_NAME = 'my_kb_test'
# The source tag of the QA data, used for data cleaning; must be unique.
DEFAULT_QA_FILE_ID = 'QA_TEST_1'

# Separator between the fields of a question id (see ``create_question_id``).
_ID_SEPARATOR = '@'


class QAService:
    """Thin wrapper around ``FaissKBService`` for one knowledge base."""

    def __init__(self, kb_name, device=None) -> None:
        """Create the FAISS service for *kb_name* and call ``do_create_kb``.

        Args:
            kb_name: Name of the knowledge base.
            device: Passed through to ``FaissKBService`` unchanged.
        """
        self.kb_name = kb_name
        self.device = device
        self.fkbs = FaissKBService(
            kb_name,
            embed_model_path='bge-large-zh-v1.5',
            device=device,
        )
        self.fkbs.do_create_kb()

    def update_qa_doc(self, qa_file_id, doc_list, id_list):
        """Replace the documents stored for *qa_file_id*.

        Deletes whatever is stored for *qa_file_id* first, then adds
        *doc_list* under the ids in *id_list* and saves the vector store.

        Args:
            qa_file_id: Source tag of the QA file being replaced.
            doc_list: Documents to add.
            id_list: Ids passed to ``do_add_doc`` as ``ids``.
        """
        self.delete_qa_file(qa_file_id)
        doc_infos = self.fkbs.do_add_doc(doc_list, ids=id_list)
        logger.info('faiss add docs: ' + str(len(doc_infos)))
        self.fkbs.save_vector_store()

    def delete_qa_file(self, qa_file_id):
        """Delete the documents associated with *qa_file_id*."""
        kb_file = KnowledgeFile(qa_file_id, self.kb_name)
        self.fkbs.do_delete_doc(kb_file, not_refresh_vs_cache=True)

    def search(self, query, top_k=3, score_threshold=0.1, reranked=False):
        """Return the result of ``FaissKBService.do_search`` for *query*.

        Args:
            query: Query text.
            top_k: Forwarded to ``do_search``.
            score_threshold: Forwarded to ``do_search``.
            reranked: Accepted for interface compatibility; currently unused.
        """
        return self.fkbs.do_search(query, top_k, score_threshold)


def create_question_id(question_code, j, test_question):
    """Build the id ``<question_code>@<j>@<test_question>``.

    ``match_query`` recovers the question code as the text before the first
    ``@``, so *question_code* itself should not contain ``@``.
    """
    return f"{question_code}@{j}@{test_question}"


def load_testing_data(file_path):
    """Load questions and their generated ids from a JSON file.

    The file must hold a list of objects with the keys ``questionCode`` and
    ``questionList``.

    Args:
        file_path: Path of the UTF-8 encoded JSON file.

    Returns:
        A ``(question_list, id_list)`` tuple of equally long lists.

    Raises:
        KeyError: If an item lacks ``questionCode`` or ``questionList``.
    """
    question_list = []
    id_list = []

    with open(file_path, encoding='utf-8') as f:
        data = json.load(f)

    for item in data:
        question_code = item['questionCode']
        questions = item['questionList']
        question_list.extend(questions)
        # The index restarts at 0 for every item, so ids are unique per code.
        id_list.extend(
            create_question_id(question_code, j, q)
            for j, q in enumerate(questions)
        )

    return question_list, id_list


def convert_to_doc_list(question_list, id_list, qa_file_id):
    """Wrap each question in a ``Document`` tagged with its source and id.

    Args:
        question_list: Question texts.
        id_list: Ids, one per question, in the same order.
        qa_file_id: Stored as ``metadata['source']`` on every document.

    Returns:
        A list of ``Document`` objects, one per question.

    Raises:
        ValueError: If *question_list* and *id_list* differ in length.
    """
    # zip() would silently drop the surplus entries, leaving the ids that are
    # later passed to the vector store misaligned with the documents.
    if len(question_list) != len(id_list):
        raise ValueError(
            'question_list and id_list must have the same length: '
            f'{len(question_list)} != {len(id_list)}'
        )

    return [
        Document(
            page_content=question,
            metadata={'source': qa_file_id, 'id': question_id},
        )
        for question, question_id in zip(question_list, id_list)
    ]


def store_data(path, kb_name=DEFAULT_KB_NAME, device=None,
               qa_file_id=DEFAULT_QA_FILE_ID):
    """Load the questions at *path* and store them in the knowledge base.

    Args:
        path: Path of the JSON file read by ``load_testing_data``.
        kb_name: Name of the knowledge base to write to.
        device: Passed through to ``QAService``.
        qa_file_id: Source tag for the stored documents; existing documents
            with the same tag are deleted before the new ones are added.

    Returns:
        ``True`` once the data has been stored.
    """
    qa_service = QAService(kb_name, device)

    question_list, id_list = load_testing_data(path)
    print('Loaded data!')

    doc_list = convert_to_doc_list(question_list, id_list, qa_file_id)
    qa_service.update_qa_doc(qa_file_id, doc_list, id_list)

    print("Data stored in the knowledge base successfully!")
    return True


class MatchInfo(BaseModel):
    """One matched question returned by ``match_query``."""

    # Question code recovered from the stored document id.
    matchQuestionCode: str
    # Text of the matched question.
    matchQuestion: str
    # Match score formatted with three decimals.
    matchScore: str


def match_query(qa_service, query, top_k=3, score_threshold=0.1):
    """Search the knowledge base and describe each hit as a ``MatchInfo``.

    The search is issued with ``1 - score_threshold`` and each reported score
    is ``1 - similarity_score``.
    NOTE(review): this presumably converts between a distance returned by the
    vector store and a similarity exposed to callers -- confirm against
    ``FaissKBService.do_search``.

    Args:
        qa_service: Object providing ``search(query, top_k, threshold)``.
        query: Query text.
        top_k: Forwarded to ``search``.
        score_threshold: Threshold as seen by the caller.

    Returns:
        A list of ``MatchInfo``; empty when the search yields nothing.
    """
    docs = qa_service.search(query, top_k, 1 - score_threshold)
    if not docs:
        return []

    response = []
    for doc, similarity_score in docs:
        # The question code is everything before the first separator.
        question_code = doc.metadata['id'].partition(_ID_SEPARATOR)[0]
        response.append(
            MatchInfo(
                matchQuestionCode=question_code,
                matchQuestion=doc.page_content,
                matchScore=f"{1 - similarity_score:.3f}",
            )
        )
    return response


if __name__ == "__main__":
    store_data("test.json")
    qa_service = QAService(DEFAULT_KB_NAME, None)
    top_k = 3
    score_threshold = 0.1
    result = match_query(qa_service, "你好吗?", top_k, score_threshold)
    print(result)