from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Iterator, List, Tuple, Union
import threading

import numpy as np

from langchain.embeddings.base import Embeddings


class ThreadSafeObject:
    """Wrap an object with an RLock plus a "loaded" Event so that it can be
    shared across threads from a CachePool."""

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = "") -> Iterator[Any]:
        # owner and msg are reserved for logging hooks; the lock alone
        # guards access to the wrapped object.
        owner = owner or f"thread {threading.get_native_id()}"
        try:
            self._lock.acquire()
            if self._pool is not None:
                # Refresh this key's LRU position in the owning pool.
                self._pool._cache.move_to_end(self.key)
            yield self._obj
        finally:
            self._lock.release()

    def start_loading(self):
        self._loaded.clear()

    def finish_loading(self):
        self._loaded.set()

    def wait_for_loading(self):
        self._loaded.wait()

    @property
    def obj(self):
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
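
# A minimal usage sketch (names are illustrative, not part of the original API):
#
#     item = ThreadSafeObject("demo-key", obj=[])
#     item.finish_loading()                  # mark the object as ready
#     with item.acquire(owner="worker-1") as obj:
#         obj.append(42)                     # mutate safely under the lock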


class CachePool:
    """An LRU cache of ThreadSafeObject instances, bounded by cache_num."""

    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num  # maximum entries; values <= 0 mean unbounded
        self._cache = OrderedDict()
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        return list(self._cache.keys())

    def _check_count(self):
        # Evict least-recently-used entries once the cache exceeds its limit.
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        if cache := self._cache.get(key):
            cache.wait_for_loading()
            return cache

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        if key is None:
            # Note: popitem returns a (key, value) tuple rather than the bare object.
            return self._cache.popitem(last=False)
        else:
            return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"The requested resource {key} does not exist")
        elif isinstance(cache, ThreadSafeObject):
            self._cache.move_to_end(key)
            return cache.acquire(owner=owner, msg=msg)
        else:
            return cache

    def load_kb_embeddings(
            self,
            local_model_path,
            embed_device
    ) -> Embeddings:
        return embeddings_pool.load_embeddings(local_model_path=local_model_path, device=embed_device)
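
# LRU behavior in brief, assuming a pool capped at two entries (hypothetical
# keys, for illustration only):
#
#     pool = CachePool(cache_num=2)
#     pool.set("a", ThreadSafeObject("a", obj=1))
#     pool.set("b", ThreadSafeObject("b", obj=2))
#     pool.set("c", ThreadSafeObject("c", obj=3))  # evicts "a", the oldest entry
#     pool.keys()                                  # -> ["b", "c"]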


class EmbeddingsPool(CachePool):
    def load_embeddings(self, local_model_path, device) -> Embeddings:
        # Double-checked locking: the pool-wide lock only protects the cache
        # lookup and insert, while the per-item lock serializes the model load
        # so concurrent callers wait on the same instance instead of loading twice.
        self.atomic.acquire()
        key = (local_model_path, device)
        if not self.get(key):
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="initialize"):
                self.atomic.release()
                from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                embeddings = HuggingFaceEmbeddings(model_name=local_model_path,
                                                   model_kwargs={'device': device})
                item.obj = embeddings
                item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key).obj


embeddings_pool = EmbeddingsPool(cache_num=1)
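
# Typical call (the model path is a placeholder): repeated calls with the same
# (path, device) key return the cached HuggingFaceEmbeddings instance, and
# concurrent callers block in wait_for_loading() until the first load finishes.
#
#     emb = embeddings_pool.load_embeddings("/path/to/embedding-model", "cuda")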


def normalize(embeddings: List[List[float]]) -> np.ndarray:
    '''
    A drop-in replacement for sklearn.preprocessing.normalize (L2 norm),
    avoiding the scipy and scikit-learn dependencies.
    '''
    norm = np.linalg.norm(embeddings, axis=1)
    norm = np.reshape(norm, (norm.shape[0], 1))
    norm = np.tile(norm, (1, len(embeddings[0])))
    return np.divide(embeddings, norm)
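
# Worked example: the row [3.0, 4.0] has L2 norm 5.0, so it normalizes to
# [0.6, 0.8]:
#
#     normalize([[3.0, 4.0]])  # -> array([[0.6, 0.8]])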


class EmbeddingsFunAdapter(Embeddings):
    def __init__(self, embed_local_model_path, device):
        self.embed_local_model_path = embed_local_model_path
        self.device = device

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = embeddings_pool.load_embeddings(
            local_model_path=self.embed_local_model_path,
            device=self.device).embed_documents(texts)
        return normalize(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        embeddings = embeddings_pool.load_embeddings(
            local_model_path=self.embed_local_model_path,
            device=self.device
        ).embed_documents([text])
        query_embed = embeddings[0]
        query_embed_2d = np.reshape(query_embed, (1, -1))  # lift the 1-D vector to 2-D
        normalized_query_embed = normalize(query_embed_2d)
        return normalized_query_embed[0].tolist()  # flatten back to 1-D and return
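
# The adapter can be passed anywhere LangChain expects an Embeddings object
# (the model path is a placeholder):
#
#     adapter = EmbeddingsFunAdapter("/path/to/embedding-model", "cpu")
#     doc_vecs = adapter.embed_documents(["hello", "world"])
#     query_vec = adapter.embed_query("hello")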


def embed_documents(
        docs,
        local_model_path,
        device
):
    """Embed a list of Documents, returning their texts, vectors, and metadata."""
    texts = [x.page_content for x in docs]
    metadatas = [x.metadata for x in docs]

    embeddings_model = embeddings_pool.load_embeddings(local_model_path=local_model_path, device=device)
    embeddings = embeddings_model.embed_documents(texts)
    if embeddings is not None:
        return {
            "texts": texts,
            "embeddings": embeddings,
            "metadatas": metadatas,
        }
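
# Expected input is a list of LangChain Document objects (placeholder values):
#
#     from langchain.docstore.document import Document
#     result = embed_documents(
#         docs=[Document(page_content="hello", metadata={"source": "a.txt"})],
#         local_model_path="/path/to/embedding-model",
#         device="cpu",
#     )
#     # result maps "texts", "embeddings", and "metadatas" to parallel lists.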


class KnowledgeFile:
    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        self.kb_name = knowledge_base_name
        self.filename = filename


def torch_gc():
    try:
        import torch
        if torch.cuda.is_available():
            # with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            try:
                from torch.mps import empty_cache
                empty_cache()
            except Exception:
                msg = ("If you are running on macOS, upgrading PyTorch to 2.0.0 "
                       "or later is recommended so that memory allocated by "
                       "torch can be released promptly.")
                print(msg)
    except Exception:
        # torch is not installed or backend probing failed; nothing to clean up.
        ...
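
# A small self-check, runnable without torch or any model downloads; it only
# exercises the pure-numpy normalize() helper above.
if __name__ == "__main__":
    vecs = normalize([[3.0, 4.0], [1.0, 0.0]])
    assert np.allclose(np.linalg.norm(vecs, axis=1), 1.0)
    print(vecs)
    torch_gc()  # no-op when torch is absent or no accelerator is available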