import os
import shutil
from base_kb import EmbeddingsFunAdapter, KnowledgeFile, torch_gc, embed_documents
from kb_config import get_kb_path, get_vs_path
from faiss_cache import ThreadSafeFaiss, kb_faiss_pool
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple


class DocumentWithVectorStoreId(Document):
    """A document after vectorization."""

    # presumably the id of the document inside the vector store — confirm
    id: Optional[str] = None
    # NOTE(review): the meaning of the 3.0 default is not evident from this file
    score: float = 3.0


class FaissKBService():
    """Knowledge-base service that keeps its vectors in a pooled FAISS store.

    The store is obtained from ``kb_faiss_pool`` on every operation and is
    accessed through its ``acquire()`` context manager.
    """

    vs_path: str
    kb_path: str

    def __init__(self, kb_name, embed_model_path, device) -> None:
        """Remember the knowledge-base settings and resolve its paths.

        Args:
            kb_name: Name of the knowledge base; used to resolve paths and as
                part of the pool key.
            embed_model_path: Passed through to the embedding helpers and to
                the vector store loader.
            device: Passed through to the embedding helpers and to the vector
                store loader.
        """
        self.kb_name = kb_name
        self.embed_model_path = embed_model_path
        self.device = device
        self.do_init()

    def get_vs_path(self):
        """Return the vector store path for this knowledge base."""
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        """Return the root path of this knowledge base."""
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Fetch this knowledge base's vector store from the shared pool."""
        return kb_faiss_pool.load_vector_store(
            kb_name=self.kb_name,
            vector_name=self.vector_name,
            embed_local_model_path=self.embed_model_path,
            embed_device=self.device,
        )

    def save_vector_store(self):
        """Save the pooled vector store to ``self.vs_path``."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Optional[Document]]:
        """Look up documents by id.

        Returns:
            One entry per requested id, in order; ``None`` for ids that are
            not present in the docstore.
        """
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the documents with the given ids from the vector store.

        Returns:
            ``True`` once the delete call has completed (the signature has
            always declared ``bool``; previously nothing was returned).
        """
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        return True

    def do_init(self):
        """Set the vector store name and resolve the knowledge-base paths."""
        # vector_name must be set first: get_vs_path() reads it.
        self.vector_name = 'FAISS'
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Ensure the vector store directory exists and load the store."""
        # exist_ok avoids the check-then-create race and matches do_clear_vs.
        os.makedirs(self.vs_path, exist_ok=True)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store, then remove the knowledge-base directory."""
        # This class only defines do_clear_vs; prefer a clear_vs supplied by a
        # subclass or mixin when there is one, otherwise fall back instead of
        # failing with AttributeError.
        clear_vs = getattr(self, "clear_vs", self.do_clear_vs)
        clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            # Best effort: the directory may already be gone.
            pass

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold,
                  ) -> List[Tuple[Document, float]]:
        """Embed ``query`` and run a similarity search against the store.

        Args:
            query: Text to search for.
            top_k: Forwarded to the store as ``k``.
            score_threshold: Forwarded unchanged to the store.

        Returns:
            Whatever ``similarity_search_with_score_by_vector`` returns.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model_path, self.device)
        query_vector = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
            docs = vs.similarity_search_with_score_by_vector(
                query_vector, k=top_k, score_threshold=score_threshold)
        return docs

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        """Embed ``docs`` with the configured model and device."""
        return embed_documents(docs, self.embed_model_path, self.device)

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Embed ``docs`` and add them to the vector store.

        Keyword Args:
            ids: Optional ids forwarded to ``add_embeddings``.
            not_refresh_vs_cache: When truthy, the store is not saved to
                ``self.vs_path`` after the insert.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        # Embedding before acquiring the store shortens the time it is locked.
        data = self._docs_to_embeddings(docs)

        with self.load_vector_store().acquire() as vs:
            ids = vs.add_embeddings(
                text_embeddings=zip(data["texts"], data["embeddings"]),
                metadatas=data["metadatas"],
                ids=kwargs.get("ids"),
            )
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata}
                     for doc_id, doc in zip(ids, docs)]
        torch_gc()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete every stored document whose source matches ``kb_file``.

        The ``"source"`` metadata value is compared case-insensitively with
        ``kb_file.filename``. Documents without a ``"source"`` entry are
        skipped rather than raising.

        Keyword Args:
            not_refresh_vs_cache: When truthy, the store is not saved to
                ``self.vs_path`` afterwards.

        Returns:
            The ids of the documents that matched.
        """
        target = kb_file.filename.lower()
        with self.load_vector_store().acquire() as vs:
            ids = []
            for doc_id, doc in vs.docstore._dict.items():
                source = doc.metadata.get("source")
                if source is not None and source.lower() == target:
                    ids.append(doc_id)
            if ids:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Drop the pooled store and recreate an empty store directory."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            # Best effort: the directory may not exist yet.
            pass
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Report where ``file_name`` is known.

        Returns:
            ``"in_db"`` when a parent class's ``exist_doc`` reports it,
            ``"in_folder"`` when the file exists under ``<kb_path>/content``,
            otherwise ``False``.
        """
        # The class declares no base, so a parent exist_doc only exists when
        # another class in the MRO provides one; skip the check otherwise
        # instead of failing with AttributeError.
        parent_exist_doc = getattr(super(), "exist_doc", None)
        if parent_exist_doc is not None and parent_exist_doc(file_name):
            return "in_db"

        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        return False