import json

from kb_config import logger
from sentence_transformers import CrossEncoder
from base_kb import KnowledgeFile
from faiss_kb_service import FaissKBService, DocumentWithVectorStoreId
from langchain.docstore.document import Document

class QAReranker:
    """Rerank candidate documents against a query with a cross-encoder."""

    def __init__(self,
                 model_name_or_path: str,
                 top_n: int = 3,
                 device: str = "cuda:0",
                 max_length: int = 1024,
                 batch_size: int = 32
                 ):

        # Pass the model path positionally: the keyword name differs across
        # sentence-transformers versions (model_name vs. model_name_or_path).
        self._model = CrossEncoder(model_name_or_path, max_length=max_length, device=device)
        self.top_n = top_n
        self.batch_size = batch_size

    def rank(self, documents, query):
        if len(documents) == 0:
            return []
        # The cross-encoder scores (query, text) string pairs, so unwrap
        # Document objects to their page_content when needed.
        sentence_pairs = [
            [query, doc.page_content if isinstance(doc, Document) else doc]
            for doc in documents
        ]
        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
                                      convert_to_tensor=True)
        # Keep at most top_n hits.
        top_k = min(self.top_n, len(results))

        scores, indices = results.topk(top_k)
        final_results = []
        for score, index in zip(scores, indices):
            # Convert tensor scalars back to plain Python values.
            final_results.append((documents[int(index)], float(score)))
        return final_results
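
# Hedged usage sketch (not executed): how QAReranker is meant to be called.
# The model path and `docs` (a list of langchain Documents) are placeholders
# for illustration only.
#
# reranker = QAReranker(model_name_or_path='/path/to/bge-reranker-large', top_n=3)
# for doc, score in reranker.rank(documents=docs, query='some user question'):
#     print(score, doc.page_content)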


class QAService:
    """FAISS-backed QA service: embeds example questions and searches them."""
    def __init__(self, kb_name, device) -> None:

        embed_model_path = 'bge-large-zh-v1.5'
        reranker_model_path = '/export/zt/chatchat/model/bge-reranker-large'

        fkbs = FaissKBService(kb_name, embed_model_path=embed_model_path, device=device)
        fkbs.do_create_kb()
        self.fkbs = fkbs

        # Reranking is currently disabled; uncomment here (and the rerank
        # branch in search()) to enable it.
        # self.reranker_model = QAReranker(
        #     device=device,
        #     model_name_or_path=reranker_model_path,
        # )

        self.kb_name = kb_name

    def delete_qa_file(self, qa_file_id):
        kb_file = KnowledgeFile(qa_file_id, self.kb_name)
        self.fkbs.do_delete_doc(kb_file, not_refresh_vs_cache=True)

    def update_qa_doc(self, qa_file_id, doc_list, id_list):
        # Replace-by-delete: drop the old docs for this file id, then re-add.
        self.delete_qa_file(qa_file_id)

        doc_infos = self.fkbs.do_add_doc(doc_list, ids=id_list)
        logger.info('faiss add docs: %s', len(doc_infos))

        self.fkbs.save_vector_store()


    def search(self,
               query,
               top_k=3,
               score_threshold=0.1,
               reranked=False):

        docs = self.fkbs.do_search(query, top_k, score_threshold)

        # Optional rerank path (disabled; requires self.reranker_model in __init__):
        # docs = [DocumentWithVectorStoreId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id"))
        #         for x in docs]
        # if reranked:
        #     docs = self.reranker_model.rank(documents=docs, query=query)

        # To return plain strings instead of Documents:
        # return [doc.page_content for doc in docs]
        return docs
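
# Hedged usage sketch (not executed): do_search appears to return a list of
# (Document, score) tuples, judging by how do_test() below unpacks rst[0][0].
#
# qa_service = QAService('my_kb', device=None)
# for doc, score in qa_service.search('some user question', top_k=3):
#     print(score, doc.metadata.get('id'), doc.page_content)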


def create_question_id(intent_code, j, test_question):
    """Compose a unique, '@'-separated id for one example question."""
    return f"{intent_code}@{j}@{test_question}"
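
# Illustrative values only:
#   create_question_id('faq.reset_password', 0, 'how do I reset my password')
#   -> 'faq.reset_password@0@how do I reset my password'
# do_test() later recovers the intent code as the part before the first '@'.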



def load_testing_data(file_path):

    test_data_list = []
    question_list = []
    id_list = []

    with open(file_path, encoding='utf-8') as f:
        data = json.load(f)
        for item in data:
            test_question = item['testQuestion']
            intent_code = item['expectIntentCode']
            test_data_list.append((test_question, intent_code))

            q_list = item['expectIntentQuestionExample']
            for j, q in enumerate(q_list):
                q_id = create_question_id(intent_code, j, test_question)
                question_list.append(q)
                id_list.append(q_id)
    return test_data_list, question_list, id_list
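
# Expected shape of testing_data.json, inferred from the keys read above
# (values are illustrative):
#
# [
#   {
#     "testQuestion": "...",
#     "expectIntentCode": "some_intent",
#     "expectIntentQuestionExample": ["example question 1", "example question 2"]
#   }
# ]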


def convert_to_doc_list(question_list, id_list, qa_file_id):
    doc_list = []
    for question, q_id in zip(question_list, id_list):
        metadata = {
            'source': qa_file_id,
            'id': q_id
        }
        doc_list.append(Document(page_content=question, metadata=metadata))
    return doc_list


def work():

    kb_name = 'my_kb_test'
    device = None
    qa_service = QAService(kb_name, device)

    test_data_list, question_list, id_list = load_testing_data(r'test_data/testing_data.json')
    print('Loaded data!')

    # Source tag for these QA docs; used for cleanup, so it must be unique.
    qa_file_id = 'QA_TEST_2'

    doc_list = convert_to_doc_list(question_list, id_list, qa_file_id)

    qa_service.update_qa_doc(qa_file_id, doc_list, id_list)

    # Smoke tests:
    # print(qa_service.search(test_data_list[0][0]))
    # print(qa_service.search(test_data_list[1][0]))

    cnt = 0
    for query, code in test_data_list:
        rst = qa_service.search(query)
        if do_test(query, code, rst):
            cnt += 1

    print(f'{cnt}/{len(test_data_list)}')


def do_test(query, expected_intent_code, rst):
    if rst is None or len(rst) == 0:
        print('Empty: ' + query)
        return False

    # rst[0] is the top (Document, score) tuple.
    top_doc = rst[0][0]
    intent_code = top_doc.metadata['id'].split('@')[0]
    print(f'{query} vs {top_doc.page_content} : {expected_intent_code} vs {intent_code}')
    # Currently logs the comparison only; re-enable the line below to score accuracy.
    # return expected_intent_code == intent_code
    return True

if __name__ == '__main__':
    work()