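# NOTE: the lines below are page residue captured from the repository web viewer;
# they are not part of the program.
#   "You cannot select more than 25 topics"
#   "Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long."
#   File info: 195 lines, 5.4 KiB, Python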
|
|
import sys
|
|
|
|
from kb_config import logger
|
|
from sentence_transformers import CrossEncoder
|
|
from faiss_kb_service import FaissKBService, DocumentWithVectorStoreId
|
|
from langchain.docstore.document import Document
|
|
|
|
class QAReranker():
    """Score (query, document) pairs with a cross-encoder and keep the best ones.

    Instance state set by ``__init__``:
      * ``_model``     -- the ``CrossEncoder`` used for scoring.
      * ``top_n``      -- upper bound on the number of results ``rank`` returns.
      * ``batch_size`` -- batch size forwarded to ``CrossEncoder.predict``.
    """

    def __init__(self,
                 model_name_or_path: str,
                 top_n: int = 3,
                 device: str = "cuda:0",
                 max_length: int = 1024,
                 batch_size: int = 32
                 ):
        """Load the cross-encoder.

        Args:
            model_name_or_path: Model identifier or local path handed to
                ``CrossEncoder``.
            top_n: Maximum number of documents ``rank`` returns.
            device: Device string forwarded to ``CrossEncoder``.
            max_length: ``max_length`` forwarded to ``CrossEncoder``.
            batch_size: Batch size used when predicting scores.
        """
        # The model is passed positionally on purpose: the first parameter is
        # called `model_name` in older sentence-transformers releases and
        # `model_name_or_path` in newer ones, so a keyword would tie this
        # code to one naming.
        self._model = CrossEncoder(model_name_or_path, max_length=max_length, device=device)
        self.top_n = top_n
        self.batch_size = batch_size

    def rank(
            self,
            documents,
            query,
    ):
        """Return the highest-scoring documents for ``query``.

        Args:
            documents: Sized, indexable collection of candidates. Each item is
                paired with ``query`` and passed to the model as-is
                (presumably plain text -- confirm against callers).
            query: The query each document is scored against.

        Returns:
            A list of ``(document, score)`` tuples, at most ``top_n`` long, in
            the order produced by ``topk`` on the score tensor. ``score`` is
            the element yielded by iterating the ``topk`` values (not
            converted to a Python float). Empty input yields ``[]``.
        """
        if len(documents) == 0:
            return []

        sentence_pairs = [[query, candidate] for candidate in documents]
        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
                                      convert_to_tensor=True
                                      )

        # Never ask topk for more entries than there are scores.
        top_k = min(self.top_n, len(results))
        scores, indices = results.topk(top_k)

        # int() turns the index element into a plain Python int so that
        # `documents` can be any ordinary sequence.
        return [
            (documents[int(index)], score)
            for score, index in zip(scores, indices)
        ]
|
|
|
|
|
|
from base_kb import KnowledgeFile
|
|
|
|
class QAService():
    """Small facade over a ``FaissKBService`` knowledge base for QA lookup.

    Instance state set by ``__init__``:
      * ``fkbs``    -- the ``FaissKBService`` all operations delegate to.
      * ``kb_name`` -- knowledge-base name, reused when building
        ``KnowledgeFile`` objects.
    """

    def __init__(self, kb_name, device) -> None:
        """Create (or open) the knowledge base ``kb_name``.

        Args:
            kb_name: Name passed to ``FaissKBService`` and ``KnowledgeFile``.
            device: Device forwarded to ``FaissKBService``.
        """
        embed_model_path = 'bge-large-zh-v1.5'

        fkbs = FaissKBService(kb_name, embed_model_path=embed_model_path, device=device)
        fkbs.do_create_kb()
        self.fkbs = fkbs

        # NOTE: reranking is currently disabled. The previous (commented-out)
        # code built a QAReranker from
        # '/export/zt/chatchat/model/bge-reranker-large' on the same device
        # and stored it as `self.reranker_model`.

        self.kb_name = kb_name

    def delete_qa_file(self, qa_file_id):
        """Remove every document that belongs to ``qa_file_id`` from the store.

        The deletion is issued with ``not_refresh_vs_cache=True``; presumably
        this defers refreshing/persisting the vector store to the caller --
        confirm against ``FaissKBService.do_delete_doc``.
        """
        kb_file = KnowledgeFile(qa_file_id, self.kb_name)
        self.fkbs.do_delete_doc(kb_file, not_refresh_vs_cache=True)

    def update_qa_doc(self, qa_file_id, doc_list, id_list):
        """Replace the documents of ``qa_file_id`` with ``doc_list``.

        Steps, in order: delete the existing documents of the file, add
        ``doc_list`` under the ids in ``id_list``, log how many entries the
        add call reported, then save the vector store.
        """
        self.delete_qa_file(qa_file_id)

        doc_infos = self.fkbs.do_add_doc(doc_list, ids=id_list)
        # Message text kept byte-identical to the original log line.
        logger.info(f'fassi add docs: {len(doc_infos)}')

        self.fkbs.save_vector_store()

    def search(self,
               query,
               top_k=3,
               score_threshold=0.1,
               reranked=False):
        """Search the knowledge base for ``query``.

        Args:
            query: Query text.
            top_k: Forwarded to ``FaissKBService.do_search``.
            score_threshold: Forwarded to ``FaissKBService.do_search``.
            reranked: Accepted for interface compatibility but currently
                ignored -- the reranking step is disabled.

        Returns:
            Whatever ``FaissKBService.do_search`` returns, unmodified.
        """
        return self.fkbs.do_search(query, top_k, score_threshold)
|
|
|
|
|
|
import json
|
|
|
|
|
|
def create_question_id(intent_code, j, test_question):
    """Build the document id ``"<intent_code>@<j>@<test_question>"``.

    The intent code comes first so that it can be recovered by splitting the
    id on ``'@'`` and taking the first field.
    """
    return f"{intent_code}@{j}@{test_question}"
|
|
|
|
|
|
|
|
def load_testing_data(file_path):
    """Load the JSON test set and flatten it into parallel lists.

    The file must contain a JSON array of objects with the keys
    ``'testQuestion'``, ``'expectIntentCode'`` and
    ``'expectIntentQuestionExample'`` (an iterable of example questions).

    Args:
        file_path: Path of the UTF-8 encoded JSON file.

    Returns:
        A tuple ``(test_data_list, question_list, id_list)``:
          * ``test_data_list`` -- one ``(test_question, intent_code)`` pair
            per item.
          * ``question_list``  -- every example question, in file order.
          * ``id_list``        -- the id built by ``create_question_id`` for
            the example at the same position in ``question_list``.
    """
    # Read everything first so the file is closed before processing starts.
    with open(file_path, encoding='utf-8') as f:
        data = json.load(f)

    test_data_list = []
    question_list = []
    id_list = []

    for item in data:
        test_question = item['testQuestion']
        intent_code = item['expectIntentCode']
        test_data_list.append((test_question, intent_code))

        for j, example in enumerate(item['expectIntentQuestionExample']):
            question_list.append(example)
            id_list.append(create_question_id(intent_code, j, test_question))

    return test_data_list, question_list, id_list
|
|
|
|
|
|
def convert_to_doc_list(question_list, id_list, qa_file_id):
    """Wrap each question in a ``Document`` tagged with its id and source.

    Args:
        question_list: Texts that become ``page_content``.
        id_list: Ids paired positionally with ``question_list``; pairing uses
            ``zip``, so surplus entries in the longer list are ignored.
        qa_file_id: Stored under the ``'source'`` metadata key of every
            document.

    Returns:
        A list of ``Document`` objects whose metadata is
        ``{'source': qa_file_id, 'id': <id>}``.
    """
    # `doc_id` rather than `id`, to avoid shadowing the builtin.
    return [
        Document(page_content=question,
                 metadata={'source': qa_file_id, 'id': doc_id})
        for question, doc_id in zip(question_list, id_list)
    ]
|
|
|
|
|
|
def work(kb_name='my_kb_test',
         data_path=r'test_data/testing_data.json',
         qa_file_id='QA_TEST_2',
         device=None):
    """Index the test questions, query each one and print the hit count.

    All parameters default to the values that used to be hard-coded, so
    calling ``work()`` behaves as before.

    Args:
        kb_name: Knowledge-base name handed to ``QAService``.
        data_path: JSON test file read by ``load_testing_data``.
        qa_file_id: Source id of the QA documents; it is used to delete the
            previous documents of this source before re-adding, so it should
            be unique per data set.
        device: Device handed to ``QAService``.

    Prints ``'Loaded data!'`` after loading and finally ``'<hits>/<total>'``,
    where ``hits`` counts the queries for which ``do_test`` returned a truthy
    value.
    """
    qa_service = QAService(kb_name, device)

    test_data_list, question_list, id_list = load_testing_data(data_path)
    print('Loaded data!')

    doc_list = convert_to_doc_list(question_list, id_list, qa_file_id)
    qa_service.update_qa_doc(qa_file_id, doc_list, id_list)

    cnt = sum(
        1
        for query, code in test_data_list
        if do_test(query, code, qa_service.search(query))
    )

    print(f'{cnt}/{len(test_data_list)}')
|
|
|
|
|
|
def do_test(query, expected_intent_code, rst):
    """Check whether the top search hit carries the expected intent code.

    Args:
        query: The query text (used for the printed report only).
        expected_intent_code: Intent code the top hit should have.
        rst: Search result; a sequence whose items are indexable with the
            document at position 0. The document must expose
            ``page_content`` and a ``metadata['id']`` of the form
            ``"<intent_code>@..."``.

    Returns:
        ``False`` when ``rst`` is ``None`` or empty; otherwise ``True`` only
        if the intent code parsed from the top hit's id equals
        ``expected_intent_code``.

    Note:
        The comparison used to be commented out and the function returned
        ``True`` for every non-empty result, which made the hit count
        reported by the caller meaningless. The comparison is restored here.
    """
    if not rst:
        print('Empty: ' + query)
        return False

    top_doc = rst[0][0]
    page_content = top_doc.page_content
    # Ids look like "<intent_code>@<index>@<question>"; the code is field 0.
    intent_code = top_doc.metadata['id'].split('@')[0]
    print(query + ' vs ' + page_content + ' : ' + expected_intent_code + ' vs ' + intent_code)
    return expected_intent_code == intent_code
|
|
|
|
# Script entry point: run the index-and-query smoke test.
# NOTE: the call is unguarded, so it also executes when this module is imported.
work()
|
|
|