import json

from kb_config import logger
from sentence_transformers import CrossEncoder
from langchain.docstore.document import Document

from base_kb import KnowledgeFile
from faiss_kb_service import FaissKBService, DocumentWithVectorStoreId


class QAReranker:
    """Cross-encoder reranker: scores (query, candidate) pairs and keeps the top n."""

    def __init__(self,
                 model_name_or_path: str,
                 top_n: int = 3,
                 device: str = "cuda:0",
                 max_length: int = 1024,
                 batch_size: int = 32):
        # Pass the model path positionally: the first CrossEncoder parameter is
        # named model_name in older sentence-transformers releases and
        # model_name_or_path in newer ones.
        self._model = CrossEncoder(model_name_or_path, max_length=max_length, device=device)
        self.top_n = top_n
        self.batch_size = batch_size

    def rank(self, documents, query):
        """Score each document against the query; `documents` is a list of strings."""
        if len(documents) == 0:
            return []
        sentence_pairs = [[query, _doc] for _doc in documents]
        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
                                      convert_to_tensor=True)
        top_k = min(self.top_n, len(results))
        scores, indices = results.topk(top_k)
        final_results = []
        for score, index in zip(scores, indices):
            final_results.append((documents[index], score))
        return final_results


class QAService:
    def __init__(self, kb_name, device) -> None:
        embed_model_path = 'bge-large-zh-v1.5'
        reranker_model_path = '/export/zt/chatchat/model/bge-reranker-large'
        fkbs = FaissKBService(kb_name, embed_model_path=embed_model_path, device=device)
        fkbs.do_create_kb()
        self.fkbs = fkbs
        # self.reranker_model = QAReranker(
        #     device=device,
        #     model_name_or_path=reranker_model_path
        # )
        self.kb_name = kb_name

    def delete_qa_file(self, qa_file_id):
        kb_file = KnowledgeFile(qa_file_id, self.kb_name)
        self.fkbs.do_delete_doc(kb_file, not_refresh_vs_cache=True)

    def update_qa_doc(self, qa_file_id, doc_list, id_list):
        # Drop any docs previously added under this source, then re-add and persist.
        self.delete_qa_file(qa_file_id)
        doc_infos = self.fkbs.do_add_doc(doc_list, ids=id_list)
        logger.info('faiss add docs: ' + str(len(doc_infos)))
        self.fkbs.save_vector_store()

    def search(self, query, top_k=3, score_threshold=0.1, reranked=False):
        docs = self.fkbs.do_search(query, top_k, score_threshold)
        # if reranked:
        #     docs = [DocumentWithVectorStoreId(**x[0].dict(), score=x[1],
        #                                       id=x[0].metadata.get("id"))
        #             for x in docs]
        #     docs = self.reranker_model.rank(documents=docs, query=query)
        return docs


def create_question_id(intent_code, j, test_question):
    return f"{intent_code}@{j}@{test_question}"


def load_testing_data(file_path):
    test_data_list = []
    question_list = []
    id_list = []
    with open(file_path, encoding='utf-8') as f:
        data = json.load(f)
    for item in data:
        test_question = item['testQuestion']
        intent_code = item['expectIntentCode']
        test_data_list.append((test_question, intent_code))
        for j, q in enumerate(item['expectIntentQuestionExample']):
            question_list.append(q)
            id_list.append(create_question_id(intent_code, j, test_question))
    return test_data_list, question_list, id_list


def convert_to_doc_list(question_list, id_list, qa_file_id):
    doc_list = []
    for question, q_id in zip(question_list, id_list):
        metadata = {'source': qa_file_id, 'id': q_id}
        doc_list.append(Document(page_content=question, metadata=metadata))
    return doc_list


def work():
    kb_name = 'my_kb_test'
    device = None
    qa_service = QAService(kb_name, device)
    test_data_list, question_list, id_list = load_testing_data(r'test_data/testing_data.json')
    print('Loaded data!')
    # Source tag for these QA docs; used for later cleanup, so keep it unique.
    qa_file_id = 'QA_TEST_2'
    doc_list = convert_to_doc_list(question_list, id_list, qa_file_id)
    qa_service.update_qa_doc(qa_file_id, doc_list, id_list)
    # rst = qa_service.search(test_data_list[0][0])
    # print(rst)
    # rst = qa_service.search(test_data_list[1][0])
    # print(rst)
    cnt = 0
    for query, code in test_data_list:
        rst = qa_service.search(query)
        if do_test(query, code, rst):
            cnt += 1
    print(str(cnt) + '/' + str(len(test_data_list)))


def do_test(query, expected_intent_code, rst):
    if rst is None or len(rst) == 0:
        print('Empty: ' + query)
        return False
    # do_search returns (Document, score) pairs; take the top hit.
    top_doc, _score = rst[0]
    page_content = top_doc.page_content
    # The intent code is the first '@'-separated field of the doc id.
    intent_code = top_doc.metadata['id'].split('@')[0]
    print(query + ' vs ' + page_content + ' : ' + expected_intent_code + ' vs ' + intent_code)
    return expected_intent_code == intent_code


if __name__ == '__main__':
    work()
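
# Expected shape of test_data/testing_data.json, inferred from the keys read in
# load_testing_data above. The field values here are illustrative placeholders,
# not real data:
#
# [
#     {
#         "testQuestion": "how do I reset my password",
#         "expectIntentCode": "RESET_PASSWORD",
#         "expectIntentQuestionExample": [
#             "forgot my password",
#             "change my account password"
#         ]
#     }
# ]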