from typing import *
from base_kb import ThreadSafeObject, CachePool
from kb_config import get_vs_path
from langchain.vectorstores.faiss import FAISS
# from langchain.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document
import os


# Disabled monkey patch, kept for reference:
# # patch FAISS to include doc id in Document.metadata
# def _new_ds_search(self, search: str) -> Union[str, Document]:
#     if search not in self._dict:
#         return f"ID {search} not found."
#     else:
#         doc = self._dict[search]
#         if isinstance(doc, Document):
#             doc.metadata["id"] = search
#         return doc
# InMemoryDocstore.search = _new_ds_search


class ThreadSafeFaiss(ThreadSafeObject):
    """Cache entry wrapping a FAISS vector store.

    Mutating operations run inside ``self.acquire()``, which is provided by
    ``ThreadSafeObject`` (presumably a per-entry lock -- confirm in base_kb).
    """

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        """Return the number of documents held in the store's docstore."""
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the wrapped store to ``path`` via ``save_local``.

        Args:
            path: Target directory.
            create_path: When true, create ``path`` if it does not exist yet.

        Returns:
            Whatever ``FAISS.save_local`` returns.
        """
        with self.acquire():
            if create_path:
                # exist_ok avoids the check-then-create race between threads
                # or processes saving into the same directory.
                os.makedirs(path, exist_ok=True)
            return self._obj.save_local(path)

    def clear(self):
        """Delete every document from the wrapped store.

        Returns:
            The result of ``FAISS.delete`` for the removed ids, or an empty
            list when the store was already empty.
        """
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
                # Sanity check only; note that asserts are stripped under -O.
                assert len(self._obj.docstore._dict) == 0
        return ret


class _FaissPool(CachePool):
    """Shared helpers for pools that cache FAISS vector stores."""

    def new_vector_store(
        self,
        local_model_path,
        device,
    ) -> FAISS:
        """Create an empty FAISS store bound to a HuggingFace embedding model.

        Args:
            local_model_path: Passed as ``model_name`` to
                ``HuggingFaceEmbeddings``.
            device: Passed to the model as ``model_kwargs={'device': device}``.

        Returns:
            A FAISS store configured with ``normalize_L2=True`` and the
            ``METRIC_INNER_PRODUCT`` distance strategy, containing no documents.
        """
        # embeddings = EmbeddingsFunAdapter(embed_model)
        from langchain.embeddings.huggingface import HuggingFaceEmbeddings

        embeddings = HuggingFaceEmbeddings(
            model_name=local_model_path,
            model_kwargs={'device': device},
        )
        # The store is built from one placeholder document, which is deleted
        # again right away so the returned store is empty.
        doc = Document(page_content="init", metadata={})
        vector_store = FAISS.from_documents(
            [doc],
            embeddings,
            normalize_L2=True,
            distance_strategy="METRIC_INNER_PRODUCT",
        )
        ids = list(vector_store.docstore._dict.keys())
        vector_store.delete(ids)
        return vector_store

    def save_vector_store(self, kb_name: str, path: Optional[str] = None):
        """Save the cached store registered under ``kb_name``, if there is one.

        NOTE(review): ``KBFaissPool.load_vector_store`` registers entries under
        a ``(kb_name, vector_name)`` tuple, so callers presumably have to pass
        that same tuple here -- confirm against call sites.
        NOTE(review): ``path`` is forwarded unchanged to ``ThreadSafeFaiss.save``,
        which needs a real path; ``None`` will fail there.

        Returns:
            The result of ``ThreadSafeFaiss.save``, or ``None`` when nothing is
            cached under ``kb_name``.
        """
        if cache := self.get(kb_name):
            return cache.save(path)

    def unload_vector_store(self, kb_name: str):
        """Remove the entry registered under ``kb_name`` from the pool, if any.

        NOTE(review): same key caveat as ``save_vector_store``.
        """
        if self.get(kb_name):
            self.pop(kb_name)


class KBFaissPool(_FaissPool):
    """Pool of knowledge-base FAISS stores keyed by ``(kb_name, vector_name)``."""

    def _load_or_create_store(
        self,
        kb_name,
        vector_name,
        embed_local_model_path,
        embed_device,
        create,
    ) -> FAISS:
        """Load the store from disk, or create and persist an empty one."""
        vs_path = get_vs_path(kb_name, vector_name)

        if os.path.isfile(os.path.join(vs_path, "index.faiss")):
            # load the embedding model
            embeddings = self.load_kb_embeddings(
                local_model_path=embed_local_model_path,
                embed_device=embed_device,
            )
            # SECURITY: allow_dangerous_deserialization=True unpickles the
            # docstore from disk; vs_path must only contain trusted data.
            return FAISS.load_local(
                vs_path,
                embeddings,
                normalize_L2=True,
                distance_strategy="METRIC_INNER_PRODUCT",
                allow_dangerous_deserialization=True,
            )

        if create:
            # create an empty vector store
            os.makedirs(vs_path, exist_ok=True)
            vector_store = self.new_vector_store(
                local_model_path=embed_local_model_path,
                device=embed_device,
            )
            vector_store.save_local(vs_path)
            return vector_store

        raise RuntimeError(f"knowledge base {kb_name} not exist.")

    def load_vector_store(
        self,
        kb_name,
        vector_name,
        embed_local_model_path,
        embed_device,
        create=True,
    ) -> ThreadSafeFaiss:
        """Return the cached store for ``(kb_name, vector_name)``, loading it on a miss.

        On a cache miss a placeholder entry is registered while the pool-wide
        ``atomic`` lock is held. The lock is then released and the store is
        loaded (or created) under the entry's own lock, so loading one
        knowledge base does not hold the pool lock.

        Args:
            kb_name: Knowledge base name.
            vector_name: Name of the vector store inside the knowledge base.
            embed_local_model_path: Embedding model path used for loading or
                creating the store.
            embed_device: Device the embedding model should run on.
            create: When true, create and persist an empty store if no
                ``index.faiss`` exists on disk.

        Returns:
            The pool entry for ``(kb_name, vector_name)``.

        Raises:
            RuntimeError: The store does not exist on disk and ``create`` is false.
        """
        cache_key = (kb_name, vector_name)

        self.atomic.acquire()
        atomic_held = True
        try:
            cache = self.get(cache_key)
            if cache is None:
                item = ThreadSafeFaiss(cache_key, pool=self)
                self.set(cache_key, item)
                try:
                    with item.acquire(msg="初始化"):
                        # The entry lock is held now, so the pool lock can go.
                        self.atomic.release()
                        atomic_held = False
                        item.obj = self._load_or_create_store(
                            kb_name,
                            vector_name,
                            embed_local_model_path,
                            embed_device,
                            create,
                        )
                        item.finish_loading()
                except Exception:
                    # Drop the half-initialised placeholder so a later call
                    # retries the load instead of finding an empty entry.
                    # NOTE(review): threads already waiting on this entry are
                    # not woken here -- confirm ThreadSafeObject semantics.
                    try:
                        self.pop(cache_key)
                    except KeyError:
                        pass  # entry was already evicted from the pool
                    raise
        finally:
            # Covers the cache-hit path and any failure before the hand-over
            # to the entry lock; previously such a failure leaked the lock.
            if atomic_held:
                self.atomic.release()
        return self.get(cache_key)


kb_faiss_pool = KBFaissPool(cache_num=2)