import os
import shutil


from base_kb import EmbeddingsFunAdapter, KnowledgeFile, torch_gc, embed_documents
from kb_config import get_kb_path, get_vs_path
from faiss_cache import ThreadSafeFaiss, kb_faiss_pool

from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple


class DocumentWithVectorStoreId(Document):
    """
    A document after vectorization.
    """
    # Identifier carried alongside the document; presumably the id assigned
    # by the vector store (going by the class name) -- TODO confirm.
    id: str = None
    # Score attached to the document; defaults to 3.0.
    # NOTE(review): the meaning of this default is not visible in this file.
    score: float = 3.0


class FaissKBService:
    """Knowledge-base service backed by a FAISS vector store.

    Vector stores are obtained from the shared ``kb_faiss_pool`` and are
    keyed by ``(kb_name, vector_name)``. Every access to the underlying
    store goes through ``ThreadSafeFaiss.acquire()``.
    """

    vs_path: str
    kb_path: str

    def __init__(self, kb_name, embed_model_path, device) -> None:
        """Store the configuration and resolve the on-disk paths.

        Args:
            kb_name: Name of the knowledge base.
            embed_model_path: Path of the local embedding model.
            device: Device on which embeddings are computed.
        """
        self.kb_name = kb_name
        self.embed_model_path = embed_model_path
        self.device = device
        self.do_init()

    def get_vs_path(self):
        """Return the vector-store path for this knowledge base."""
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        """Return the root path of this knowledge base."""
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Fetch this knowledge base's vector store from the shared pool."""
        return kb_faiss_pool.load_vector_store(
            kb_name=self.kb_name,
            vector_name=self.vector_name,
            embed_local_model_path=self.embed_model_path,
            embed_device=self.device,
        )

    def save_vector_store(self):
        """Save the pooled vector store to ``self.vs_path``."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Optional[Document]]:
        """Look up documents by id in the vector store's docstore.

        Returns:
            One entry per requested id, in the same order; an entry is
            ``None`` when the id is not present in the docstore.
        """
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the given ids from the vector store.

        Returns:
            ``True`` once the delete call has completed. Any exception
            raised by the vector store propagates to the caller.
        """
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        return True

    def do_init(self):
        """Set the vector-store name and derive the kb / vs paths."""
        # vector_name must be set first: get_vs_path() reads it.
        self.vector_name = 'FAISS'
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Ensure the vector-store directory exists and load the store."""
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(self.vs_path, exist_ok=True)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store, then remove the knowledge-base folder."""
        self.do_clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            # Deliberately best-effort: a folder that is missing or cannot
            # be removed must not make the drop operation fail.
            pass

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold,
                  ) -> List[Tuple[Document, float]]:
        """Embed ``query`` and run a similarity search on the vector store.

        Args:
            query: Text to search for.
            top_k: Passed to the vector store as ``k``.
            score_threshold: Passed through to the vector store unchanged.

        Returns:
            Whatever ``similarity_search_with_score_by_vector`` returns.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model_path, self.device)

        # Embed before acquiring the store so the lock is held only for
        # the search itself.
        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
            docs = vs.similarity_search_with_score_by_vector(
                embeddings, k=top_k, score_threshold=score_threshold)
        return docs

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        """Embed ``docs`` with the configured model and device."""
        return embed_documents(docs, self.embed_model_path, self.device)

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Embed ``docs`` and add them to the vector store.

        Keyword Args:
            ids: Optional ids forwarded to ``add_embeddings``.
            not_refresh_vs_cache: When truthy, the store is not written
                to ``self.vs_path`` after the insert.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        # Doing the embedding separately, before acquiring the store,
        # reduces the time the vector store stays locked.
        data = self._docs_to_embeddings(docs)

        with self.load_vector_store().acquire() as vs:
            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
                                    metadatas=data["metadatas"],
                                    ids=kwargs.get("ids"))
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata}
                     for doc_id, doc in zip(ids, docs)]
        torch_gc()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete every stored document whose source matches ``kb_file``.

        The ``"source"`` metadata value is compared with
        ``kb_file.filename`` case-insensitively. Documents without a
        ``"source"`` value are skipped.

        Keyword Args:
            not_refresh_vs_cache: When truthy, the store is not written
                to ``self.vs_path`` after the deletion.

        Returns:
            The list of ids that matched (possibly empty).
        """
        target = kb_file.filename.lower()
        with self.load_vector_store().acquire() as vs:
            ids = []
            for doc_id, doc in vs.docstore._dict.items():
                source = doc.metadata.get("source")
                # A missing/empty source cannot match any file; guarding
                # here avoids calling .lower() on None.
                if source and source.lower() == target:
                    ids.append(doc_id)
            if ids:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Evict the store from the pool and reset its directory to empty."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            # Best-effort: the directory may not exist yet.
            pass
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Report where ``file_name`` is known to this knowledge base.

        Returns:
            ``"in_db"`` when a base class in the MRO provides ``exist_doc``
            and it reports the file; ``"in_folder"`` when the file exists
            under ``<kb_path>/content``; otherwise ``False``.
        """
        # This class declares no base that defines exist_doc, so an
        # unconditional super().exist_doc(...) raises AttributeError.
        # Look it up defensively so a cooperative base is still honoured.
        parent_exist_doc = getattr(super(), "exist_doc", None)
        if callable(parent_exist_doc) and parent_exist_doc(file_name):
            return "in_db"

        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        return False