You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

206 lines
6.4 KiB
Python

from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS
import threading
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
class ThreadSafeObject:
    """Wrap an arbitrary payload with a reentrant lock and a "loaded" event.

    Used by CachePool to guard cached resources (e.g. embedding models)
    against concurrent access while they are still being initialized.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        self._obj = obj                    # the guarded payload
        self._key = key                    # cache key (string or tuple)
        self._pool = pool                  # owning pool, used for LRU bookkeeping
        self._lock = threading.RLock()
        self._loaded = threading.Event()   # set once the payload is ready to use

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = "") -> Any:
        # FIX: the original annotated the return as FAISS, but this yields
        # whatever payload is stored (any object), not necessarily a FAISS
        # vector store.
        owner = owner or f"thread {threading.get_native_id()}"
        try:
            self._lock.acquire()
            if self._pool is not None:
                # refresh this entry's LRU position in the owning pool
                self._pool._cache.move_to_end(self.key)
            yield self._obj
        finally:
            self._lock.release()

    def start_loading(self):
        """Mark the payload as not ready; get() callers will block."""
        self._loaded.clear()

    def finish_loading(self):
        """Mark the payload as ready; wake all waiters."""
        self._loaded.set()

    def wait_for_loading(self):
        """Block until the payload has finished loading."""
        self._loaded.wait()

    @property
    def obj(self):
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
class CachePool:
    """Ordered pool of ThreadSafeObject entries with optional LRU eviction.

    Entries keep insertion/usage order; when cache_num > 0 the
    least-recently-used entries are evicted once the bound is exceeded.
    """

    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num          # <= 0 means unbounded
        self._cache = OrderedDict()
        self.atomic = threading.RLock()      # guards check-then-insert in subclasses

    def keys(self) -> List[Union[str, Tuple]]:
        # FIX: keys may be tuples (see EmbeddingsPool), not only strings.
        return list(self._cache.keys())

    def _check_count(self):
        """Evict LRU entries until the configured size bound is respected."""
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: Union[str, Tuple]) -> ThreadSafeObject:
        """Return the entry for *key*, blocking until it finished loading.

        Returns None when the key is absent.
        """
        if cache := self._cache.get(key):
            cache.wait_for_loading()
            return cache

    def set(self, key: Union[str, Tuple], obj: ThreadSafeObject) -> ThreadSafeObject:
        """Insert *obj* under *key*, enforcing the size bound, and return it."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: Union[str, Tuple] = None) -> ThreadSafeObject:
        """Remove and return an entry (None if *key* is absent).

        NOTE: with key=None the oldest entry is removed and a
        (key, value) tuple -- not a ThreadSafeObject -- is returned,
        matching OrderedDict.popitem.
        """
        if key is None:
            return self._cache.popitem(last=False)
        return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Return the locked context manager of the cached object for *key*.

        Raises RuntimeError when the key is unknown.
        """
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        elif isinstance(cache, ThreadSafeObject):
            self._cache.move_to_end(key)
            return cache.acquire(owner=owner, msg=msg)
        else:
            return cache

    def load_kb_embeddings(
            self,
            local_model_path,
            embed_device
    ) -> Embeddings:
        """Convenience wrapper delegating to the module-level embeddings_pool."""
        return embeddings_pool.load_embeddings(local_model_path=local_model_path, device=embed_device)
class EmbeddingsPool(CachePool):
    """Caches loaded HuggingFace embedding models keyed by (model_path, device)."""

    def load_embeddings(self, local_model_path, device) -> Embeddings:
        """Return the embedding model for (local_model_path, device),
        loading and caching it on first use.

        Thread-safe: `atomic` serializes the check-then-insert, and the
        placeholder's loaded-event makes concurrent callers wait for the
        first loader to finish.
        """
        self.atomic.acquire()
        key = (local_model_path, device)
        if not self.get(key):
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                # release the pool lock before the slow model load so other
                # keys can still be served concurrently
                self.atomic.release()
                try:
                    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                    embeddings = HuggingFaceEmbeddings(model_name=local_model_path,
                                                       model_kwargs={'device': device})
                    item.obj = embeddings
                except Exception:
                    # FIX: drop the broken placeholder so a retry can re-load
                    # instead of returning a cached entry with obj=None
                    self.pop(key)
                    raise
                finally:
                    # FIX: always set the loaded event; the original skipped it
                    # on failure, leaving every later get(key) blocked forever
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key).obj
# Module-wide singleton: keep at most one embedding model resident at a time.
embeddings_pool = EmbeddingsPool(cache_num=1)
import numpy as np
def normalize(embeddings: List[List[float]]) -> np.ndarray:
    """L2-normalize each row vector.

    Replacement for sklearn.preprocessing.normalize so that scipy /
    scikit-learn need not be installed.

    Args:
        embeddings: 2-D array-like of shape (n_vectors, dim).

    Returns:
        np.ndarray of the same shape with every row scaled to unit L2 norm.
        A zero row still divides by zero (nan/inf), matching the original.
    """
    arr = np.asarray(embeddings, dtype=float)
    # keepdims=True lets broadcasting divide row-wise; the original built an
    # explicit (n, dim) copy of the norms with reshape + tile, which is
    # equivalent but wastes memory.
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    return arr / norms
class EmbeddingsFunAdapter(Embeddings):
    """Adapt the pooled HuggingFace model to the langchain Embeddings
    interface, L2-normalizing every vector it returns."""

    def __init__(self, embed_local_model_path, device):
        self.embed_local_model_path = embed_local_model_path
        self.device = device

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* and return unit-norm vectors as plain lists."""
        # FIX: the original passed `model=`, which is not a parameter of
        # load_embeddings(local_model_path, device) and raised TypeError
        # on every call.
        embeddings = embeddings_pool.load_embeddings(
            local_model_path=self.embed_local_model_path,
            device=self.device).embed_documents(texts)
        return normalize(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return a flat unit-norm vector."""
        embeddings = embeddings_pool.load_embeddings(
            local_model_path=self.embed_local_model_path,
            device=self.device
        ).embed_documents([text])
        query_embed = embeddings[0]
        # normalize() expects 2-D input, so lift the vector to shape (1, dim)
        query_embed_2d = np.reshape(query_embed, (1, -1))
        normalized_query_embed = normalize(query_embed_2d)
        return normalized_query_embed[0].tolist()  # back to a flat 1-D list
def embed_documents(
    docs,
    local_model_path,
    device
):
    """Embed the page_content of each document and bundle the texts,
    their vectors and their metadata into one dict.

    Returns None (implicitly) when the model produces no embeddings.
    """
    texts, metadatas = [], []
    for doc in docs:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)
    model = embeddings_pool.load_embeddings(local_model_path=local_model_path, device=device)
    vectors = model.embed_documents(texts)
    if vectors is None:
        return None
    return {
        "texts": texts,
        "embeddings": vectors,
        "metadatas": metadatas,
    }
class KnowledgeFile:
    """Lightweight record pairing a file name with its knowledge base."""

    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        # Attribute names are part of the public contract -- keep them.
        self.filename = filename
        self.kb_name = knowledge_base_name
def torch_gc():
    """Best-effort release of accelerator memory cached by PyTorch.

    Silently does nothing when torch is not installed or no CUDA/MPS
    backend is available.
    """
    try:
        import torch
        if torch.cuda.is_available():
            # with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            try:
                from torch.mps import empty_cache
                empty_cache()
            except Exception:
                # FIX: the hint below was built but never used/shown in the
                # original (dead code); surface it so macOS users see it.
                msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
                       "以支持及时清理 torch 产生的内存占用。")
                print(msg)
    except Exception:
        # Deliberate best-effort: torch missing or backend probing failed.
        ...