"""Thread-safe cache pool for embedding models, plus embedding helpers."""

import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Tuple, Union

import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS  # noqa: F401  (kept: existing file-level import)

logger = logging.getLogger(__name__)


class ThreadSafeObject:
    """Wrap an arbitrary object with a re-entrant lock and a "loaded" event.

    The wrapper is meant to be stored in a :class:`CachePool`. The ``_loaded``
    event starts unset; ``CachePool.get`` blocks on it, so whoever populates
    the wrapper must call :meth:`finish_loading` once ``obj`` is ready.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        """The cache key this wrapper was created with."""
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = "") -> Iterator[Any]:
        """Hold this object's lock for the duration of the ``with`` body.

        Yields the wrapped object. When the wrapper belongs to a pool, its
        key is first moved to the end of the pool's ordered cache.

        ``owner`` and ``msg`` are accepted for interface compatibility; this
        implementation does not use them.
        """
        with self._lock:
            if self._pool is not None:
                try:
                    self._pool._cache.move_to_end(self.key)
                except KeyError:
                    # The entry is no longer in the pool (e.g. it was evicted);
                    # the wrapped object itself is still usable.
                    pass
            yield self._obj

    def start_loading(self):
        """Mark the object as not (yet) loaded."""
        self._loaded.clear()

    def finish_loading(self):
        """Mark the object as loaded and wake every thread waiting on it."""
        self._loaded.set()

    def wait_for_loading(self):
        """Block until :meth:`finish_loading` has been called."""
        self._loaded.wait()

    @property
    def obj(self):
        """The wrapped object (``None`` until one is assigned)."""
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val


class CachePool:
    """Ordered cache of :class:`ThreadSafeObject` entries with optional size cap.

    Entries are kept in an ``OrderedDict``; when ``cache_num`` is a positive
    int, the entries at the front of the dict are dropped once the cap is
    exceeded.
    """

    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num
        self._cache = OrderedDict()
        # Pool-wide lock used by subclasses to make "check then insert" atomic.
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        """Return a snapshot list of the cached keys, in cache order."""
        return list(self._cache.keys())

    def _check_count(self):
        """Drop front entries until the cache respects ``cache_num``."""
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: str) -> Optional[ThreadSafeObject]:
        """Return the entry for ``key`` after waiting for it to finish loading.

        Returns ``None`` when the key is absent (or maps to a falsy value).
        """
        if cache := self._cache.get(key):
            cache.wait_for_loading()
            return cache
        return None

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        """Store ``obj`` under ``key``, enforce the size cap, and return ``obj``."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> Union[ThreadSafeObject, Tuple[Any, ThreadSafeObject], None]:
        """Remove an entry from the cache.

        With ``key=None`` the front entry is removed and returned as a
        ``(key, value)`` tuple; otherwise the value stored under ``key`` is
        removed and returned. In both cases ``None`` is returned when there is
        nothing to remove.
        """
        if key is None:
            if not self._cache:
                # Mirror the keyed branch: nothing to pop is not an error.
                return None
            return self._cache.popitem(last=False)
        return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Return a context manager that locks the entry stored under ``key``.

        Raises:
            RuntimeError: if no entry exists for ``key``.
        """
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        if isinstance(cache, ThreadSafeObject):
            self._cache.move_to_end(key)
            return cache.acquire(owner=owner, msg=msg)
        return cache

    def load_kb_embeddings(
            self,
            local_model_path,
            embed_device
    ) -> Embeddings:
        """Load embeddings through the module-level ``embeddings_pool``."""
        return embeddings_pool.load_embeddings(local_model_path=local_model_path, device=embed_device)


class EmbeddingsPool(CachePool):
    """Cache pool whose entries are HuggingFace embedding models."""

    def load_embeddings(self, local_model_path, device) -> Embeddings:
        """Return the embeddings model for ``(local_model_path, device)``.

        The model is created on first request and cached. The pool-wide
        ``atomic`` lock only guards the "look up / insert placeholder" step;
        it is released before the model is constructed, while the entry's own
        lock is held.

        If construction fails, the placeholder is removed from the cache and
        its waiters are released before the exception propagates, so a later
        call can retry. Threads that were already waiting on that placeholder
        receive its ``obj``, which is ``None`` in that case.
        """
        key = (local_model_path, device)
        self.atomic.acquire()
        atomic_held = True
        try:
            item = self.get(key)
            if item is None:
                item = ThreadSafeObject(key, pool=self)
                self.set(key, item)
                with item.acquire(msg="初始化"):
                    # Placeholder is registered and locked; let other keys proceed.
                    self.atomic.release()
                    atomic_held = False
                    try:
                        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                        item.obj = HuggingFaceEmbeddings(
                            model_name=local_model_path,
                            model_kwargs={'device': device},
                        )
                    except BaseException:
                        # Do not take `atomic` here: a waiter may be holding it
                        # while blocked on this entry's loaded event.
                        if self._cache.get(key) is item:
                            self._cache.pop(key, None)
                        raise
                    finally:
                        # Always wake waiters, even when loading failed.
                        item.finish_loading()
        finally:
            if atomic_held:
                self.atomic.release()
        # Use the local reference: the key may already have been evicted.
        return item.obj


embeddings_pool = EmbeddingsPool(cache_num=1)


def normalize(embeddings: List[List[float]]) -> np.ndarray:
    """L2-normalize each row of ``embeddings``.

    Stand-in for ``sklearn.preprocessing.normalize`` (L2 norm) that avoids
    installing scipy / scikit-learn. Rows whose norm is zero are returned
    unchanged instead of producing NaN, and an empty input yields an empty
    ``(0, 0)`` array.
    """
    arr = np.asarray(embeddings)
    if arr.ndim == 1 and arr.size == 0:
        return arr.reshape(0, 0)
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    # Avoid division by zero for all-zero rows.
    norms[norms == 0] = 1
    return np.divide(arr, norms)


class EmbeddingsFunAdapter(Embeddings):
    """``Embeddings`` implementation backed by the shared ``embeddings_pool``.

    Vectors returned by both methods are L2-normalized.
    """

    def __init__(self, embed_local_model_path, device):
        self.embed_local_model_path = embed_local_model_path
        self.device = device

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` and return one normalized vector per text."""
        model = embeddings_pool.load_embeddings(
            local_model_path=self.embed_local_model_path,
            device=self.device,
        )
        embeddings = model.embed_documents(texts)
        return normalize(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its normalized vector."""
        embeddings = embeddings_pool.load_embeddings(
            local_model_path=self.embed_local_model_path,
            device=self.device
        ).embed_documents([text])
        query_embed = embeddings[0]
        # Turn the 1-D vector into a 2-D (1, dim) array for normalize().
        query_embed_2d = np.reshape(query_embed, (1, -1))
        normalized_query_embed = normalize(query_embed_2d)
        # Back to a 1-D list for the caller.
        return normalized_query_embed[0].tolist()


def embed_documents(
        docs,
        local_model_path, device
):
    """Embed the ``page_content`` of each document in ``docs``.

    Returns a dict with keys ``"texts"``, ``"embeddings"`` and ``"metadatas"``,
    or ``None`` when the model returns ``None``. The embeddings are returned
    as produced by the model (they are not passed through ``normalize``).
    """
    texts = [x.page_content for x in docs]
    metadatas = [x.metadata for x in docs]
    embeddings_model = embeddings_pool.load_embeddings(local_model_path=local_model_path, device=device)
    embeddings = embeddings_model.embed_documents(texts)
    if embeddings is not None:
        return {
            "texts": texts,
            "embeddings": embeddings,
            "metadatas": metadatas,
        }
    return None


class KnowledgeFile:
    """Pair a file name with the name of the knowledge base it belongs to."""

    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        self.kb_name = knowledge_base_name
        self.filename = filename


def torch_gc():
    """Best-effort release of memory cached by torch (CUDA or MPS).

    Never raises: any failure, including torch being unavailable, is logged
    and otherwise ignored.
    """
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            try:
                from torch.mps import empty_cache
                empty_cache()
            except Exception as e:
                msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
                       "以支持及时清理 torch 产生的内存占用。")
                logger.error("%s: %s", type(e).__name__, msg)
    except Exception:
        # Deliberately best-effort; keep a trace for debugging only.
        logger.debug("torch_gc skipped", exc_info=True)