import os
from typing import Union

from langchain.schema import Document
from langchain.vectorstores.faiss import FAISS
# from langchain.docstore.in_memory import InMemoryDocstore

from base_kb import ThreadSafeObject, CachePool
from kb_config import get_vs_path


# # patch FAISS to include doc id in Document.metadata
# def _new_ds_search(self, search: str) -> Union[str, Document]:
#     if search not in self._dict:
#         return f"ID {search} not found."
#     else:
#         doc = self._dict[search]
#         if isinstance(doc, Document):
#             doc.metadata["id"] = search
#         return doc
# InMemoryDocstore.search = _new_ds_search


class ThreadSafeFaiss(ThreadSafeObject):
    """Thread-safe wrapper around a FAISS vector store managed by a cache pool."""

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the wrapped FAISS index to `path`, creating the directory if needed."""
        with self.acquire():
            if not os.path.isdir(path) and create_path:
                os.makedirs(path)
            ret = self._obj.save_local(path)
        return ret

    def clear(self):
        """Remove all documents from the wrapped FAISS index."""
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
            assert len(self._obj.docstore._dict) == 0
        return ret


class _FaissPool(CachePool):
    def new_vector_store(
        self,
        local_model_path: str,
        device: str,
    ) -> FAISS:
        """Create an empty FAISS store backed by a local HuggingFace embedding model."""
        # embeddings = EmbeddingsFunAdapter(embed_model)
        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
        embeddings = HuggingFaceEmbeddings(model_name=local_model_path,
                                           model_kwargs={"device": device})
        # Build the store from a placeholder document, then delete it so the
        # store starts empty but with its embedding dimension already fixed.
        doc = Document(page_content="init", metadata={})
        vector_store = FAISS.from_documents([doc], embeddings,
                                            normalize_L2=True,
                                            distance_strategy="METRIC_INNER_PRODUCT")
        ids = list(vector_store.docstore._dict.keys())
        vector_store.delete(ids)
        return vector_store

    def save_vector_store(self, kb_name: str, path: str = None):
        # NOTE: the lookup key must match the key the store was cached under
        # (KBFaissPool caches entries by the (kb_name, vector_name) tuple).
        if cache := self.get(kb_name):
            return cache.save(path)

    def unload_vector_store(self, kb_name: str):
        if cache := self.get(kb_name):
            self.pop(kb_name)


class KBFaissPool(_FaissPool):
    def load_vector_store(
        self,
        kb_name: str,
        vector_name: str,
        embed_local_model_path: str,
        embed_device: str,
        create: bool = True,
    ) -> ThreadSafeFaiss:
        """Return the cached store for (kb_name, vector_name), loading or creating it on first use."""
        self.atomic.acquire()
        cache = self.get((kb_name, vector_name))
        if cache is None:
            item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
            self.set((kb_name, vector_name), item)
            with item.acquire(msg="initializing"):
                self.atomic.release()
                vs_path = get_vs_path(kb_name, vector_name)
                if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                    # load the existing index from disk with the embedding model
                    embeddings = self.load_kb_embeddings(local_model_path=embed_local_model_path,
                                                         embed_device=embed_device)
                    vector_store = FAISS.load_local(vs_path, embeddings,
                                                    normalize_L2=True,
                                                    distance_strategy="METRIC_INNER_PRODUCT",
                                                    allow_dangerous_deserialization=True)
                elif create:
                    # create an empty vector store and persist it
                    if not os.path.exists(vs_path):
                        os.makedirs(vs_path)
                    vector_store = self.new_vector_store(local_model_path=embed_local_model_path,
                                                         device=embed_device)
                    vector_store.save_local(vs_path)
                else:
                    raise RuntimeError(f"knowledge base {kb_name} does not exist.")
                item.obj = vector_store
                item.finish_loading()
        else:
            self.atomic.release()
        return self.get((kb_name, vector_name))


kb_faiss_pool = KBFaissPool(cache_num=2)
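

# Usage sketch (not part of the original module): a minimal illustration of how
# the pool is typically driven. The knowledge-base name, vector name, embedding
# model path, and device below are hypothetical placeholders, not values defined
# or required by this file.
if __name__ == "__main__":
    vs = kb_faiss_pool.load_vector_store(
        kb_name="samples",                                  # hypothetical KB name
        vector_name="default",                              # hypothetical vector-store name
        embed_local_model_path="/path/to/embedding-model",  # hypothetical local model dir
        embed_device="cpu",
    )
    print(vs)                # repr shows the cache key and current docs_count
    print(vs.docs_count())   # number of documents held by the wrapped FAISS index
    # Persist the index back to its on-disk location.
    vs.save(get_vs_path("samples", "default"))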