from typing import *

from base_kb import ThreadSafeObject, CachePool

from kb_config import get_vs_path
from langchain.vectorstores.faiss import FAISS
# from langchain.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document
import os



# # patch FAISS to include doc id in Document.metadata
# def _new_ds_search(self, search: str) -> Union[str, Document]:
#     if search not in self._dict:
#         return f"ID {search} not found."
#     else:
#         doc = self._dict[search]
#         if isinstance(doc, Document):
#             doc.metadata["id"] = search
#         return doc
# InMemoryDocstore.search = _new_ds_search


class ThreadSafeFaiss(ThreadSafeObject):
    """Lock-guarded wrapper around a langchain ``FAISS`` vector store.

    The wrapped store is kept in ``self._obj`` (attribute managed by
    ``ThreadSafeObject``). ``save`` and ``clear`` touch the store only while
    holding ``self.acquire()``.
    """

    def __repr__(self) -> str:
        cls = type(self).__name__
        # Pool entries are constructed first and get their store assigned only
        # after loading finishes, so ``_obj`` may not be set yet (presumably None
        # by ThreadSafeObject's default -- confirm). A repr must not raise here.
        count = self.docs_count() if self._obj is not None else None
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {count}>"

    def docs_count(self) -> int:
        """Return the number of documents in the store's docstore.

        Relies on the private ``_dict`` mapping of the langchain docstore.
        """
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the wrapped store to ``path`` using ``FAISS.save_local``.

        Args:
            path: Target directory.
            create_path: When true, create ``path`` (including parents) if it
                does not exist yet.

        Returns:
            Whatever ``save_local`` returns.

        Raises:
            TypeError: ``path`` is None (for example when ``save_vector_store``
                is called without an explicit path).
        """
        if path is None:
            # Fail early with a readable message instead of deep inside os.stat.
            raise TypeError("ThreadSafeFaiss.save() requires a target directory path, got None")
        with self.acquire():
            if create_path:
                # exist_ok avoids a check-then-create race with concurrent writers.
                os.makedirs(path, exist_ok=True)
            ret = self._obj.save_local(path)

        return ret

    def clear(self):
        """Delete every document from the wrapped store.

        Returns:
            ``[]`` when the store was already empty, otherwise the value
            returned by ``FAISS.delete``.
        """
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
                # Internal sanity check: deletion must leave the docstore empty.
                assert len(self._obj.docstore._dict) == 0
        return ret


class _FaissPool(CachePool):
    """Cache pool whose entries are ``ThreadSafeFaiss`` vector-store wrappers."""

    def new_vector_store(
        self,
        local_model_path,
        device,
    ) -> FAISS:
        """Build an empty FAISS store backed by a local HuggingFace embedding model.

        A single placeholder document is embedded to construct the store and is
        then deleted again. Presumably this lets FAISS size its index from a real
        embedding -- confirm against the langchain version in use.

        Args:
            local_model_path: Path/name handed to ``HuggingFaceEmbeddings`` as
                ``model_name``.
            device: Value placed in the model's ``model_kwargs['device']``.

        Returns:
            The (now empty) ``FAISS`` store.
        """
        # Imported lazily so the heavy embedding stack loads only when needed.
        from langchain.embeddings.huggingface import HuggingFaceEmbeddings

        model_kwargs = {'device': device}
        embedder = HuggingFaceEmbeddings(model_name=local_model_path, model_kwargs=model_kwargs)

        placeholder = Document(page_content="init", metadata={})
        store = FAISS.from_documents(
            [placeholder],
            embedder,
            normalize_L2=True,
            distance_strategy="METRIC_INNER_PRODUCT",
        )
        # Remove the placeholder so the returned store holds no documents.
        store.delete(list(store.docstore._dict))
        return store

    def save_vector_store(self, kb_name: str, path: str = None):
        """Save the cached store stored under ``kb_name`` to ``path``.

        Args:
            kb_name: The cache key to look up. ``KBFaissPool`` stores entries
                under ``(kb_name, vector_name)`` tuples, so pass that tuple for
                its entries.
            path: Target directory, forwarded unchanged to ``ThreadSafeFaiss.save``.

        Returns:
            The result of ``save``, or None when nothing is cached under the key.
        """
        cached = self.get(kb_name)
        if not cached:
            return None
        return cached.save(path)

    def unload_vector_store(self, kb_name: str):
        """Remove the entry cached under ``kb_name`` from the pool, if any.

        The entry is looked up through ``self.get`` before popping. Presumably
        this waits for an in-flight load to finish -- confirm in ``CachePool``.
        """
        if not self.get(kb_name):
            return
        self.pop(kb_name)


class KBFaissPool(_FaissPool):
    """Pool of knowledge-base FAISS stores keyed by ``(kb_name, vector_name)``."""

    def load_vector_store(
            self,
            kb_name,
            vector_name,
            embed_local_model_path,
            embed_device,
            create=True,
    ) -> ThreadSafeFaiss:
        """Return the pooled store for ``(kb_name, vector_name)``, loading it on a miss.

        On a cache miss the store is read from ``get_vs_path(kb_name, vector_name)``
        when an ``index.faiss`` file exists there. Otherwise, if ``create`` is
        true, an empty store is built and saved to that directory.

        Args:
            kb_name: Knowledge-base name (first element of the cache key).
            vector_name: Vector-store name (second element of the cache key).
            embed_local_model_path: Local embedding-model path used to load or
                create the store.
            embed_device: Device handed to the embedding model.
            create: Build an empty store when none exists on disk.

        Returns:
            The pool entry for the key, as returned by ``self.get``.

        Raises:
            RuntimeError: No store exists on disk and ``create`` is false.
            Any error raised while loading or creating the store. In that case
            the half-initialised entry is removed from the pool before the
            error propagates.
        """
        key = (kb_name, vector_name)

        # The pool lock guards lookup-then-insert so two threads cannot both
        # create an entry for the same key. It is released only once the new
        # entry's own lock is held. ``pool_locked`` records whether this call
        # still owns the pool lock, so every exit path releases it exactly once.
        self.atomic.acquire()
        pool_locked = True
        try:
            if self.get(key) is None:
                item = ThreadSafeFaiss(key, pool=self)
                self.set(key, item)

                with item.acquire(msg="初始化"):  # msg: "initialise"
                    self.atomic.release()
                    pool_locked = False
                    try:
                        item.obj = self._open_vector_store(
                            kb_name, vector_name, embed_local_model_path, embed_device, create
                        )
                    except BaseException:
                        # Do not leave a store-less entry cached after a failed load.
                        self.pop(key)
                        raise
                    finally:
                        # Always mark loading as finished, even on failure.
                        # Presumably this releases threads waiting on the entry
                        # in ``self.get`` -- confirm in ThreadSafeObject/CachePool.
                        item.finish_loading()
        finally:
            if pool_locked:
                self.atomic.release()
        return self.get(key)

    def _open_vector_store(self, kb_name, vector_name, embed_local_model_path, embed_device, create) -> FAISS:
        """Load the on-disk store for the key, or create and save an empty one."""
        vs_path = get_vs_path(kb_name, vector_name)

        # NOTE(review): in the langchain versions I know, ``DistanceStrategy`` has
        # MAX_INNER_PRODUCT but no "METRIC_INNER_PRODUCT" member. This value may
        # therefore fall back to L2 behaviour. Verify against the installed
        # langchain before changing it: existing on-disk indexes were built with it.
        if os.path.isfile(os.path.join(vs_path, "index.faiss")):
            # load the embedding model
            embeddings = self.load_kb_embeddings(local_model_path=embed_local_model_path, embed_device=embed_device)
            return FAISS.load_local(vs_path, embeddings, normalize_L2=True, distance_strategy="METRIC_INNER_PRODUCT")

        if not create:
            raise RuntimeError(f"knowledge base {kb_name} not exist.")

        # create an empty vector store
        os.makedirs(vs_path, exist_ok=True)
        vector_store = self.new_vector_store(local_model_path=embed_local_model_path, device=embed_device)
        vector_store.save_local(vs_path)
        return vector_store

# Shared module-level pool of knowledge-base FAISS stores. cache_num=2
# presumably caps how many stores CachePool keeps loaded at once -- the exact
# semantics live in CachePool; confirm there.
kb_faiss_pool: KBFaissPool = KBFaissPool(cache_num=2)