增加向量化的库
parent
ed7673de24
commit
3724211a12
@ -0,0 +1,53 @@
|
|||||||
|
package com.supervision.config;
|
||||||
|
|
||||||
|
import cn.hutool.http.HttpUtil;
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.ai.document.Document;
|
||||||
|
import org.springframework.ai.embedding.*;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
public class VectorEmbeddingClient extends AbstractEmbeddingClient {
|
||||||
|
|
||||||
|
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Double> embed(Document document) {
|
||||||
|
List<List<Double>> list = this.call(new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY))
|
||||||
|
.getResults()
|
||||||
|
.stream()
|
||||||
|
.map(Embedding::getOutput)
|
||||||
|
.toList();
|
||||||
|
return list.iterator().next();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public EmbeddingResponse call(EmbeddingRequest request) {
|
||||||
|
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
|
||||||
|
if (request.getInstructions().size() != 1) {
|
||||||
|
logger.warn(
|
||||||
|
"Ollama Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
|
||||||
|
}
|
||||||
|
List<List<Double>> embeddingList = new ArrayList<>();
|
||||||
|
for (String inputContent : request.getInstructions()) {
|
||||||
|
// TODO 这里需要吧inputContent转化为向量数据
|
||||||
|
String post = HttpUtil.post("http://192.168.10.42:8000/embeddings/", JSONUtil.toJsonStr(Map.of("text", inputContent)));
|
||||||
|
String o = JSONUtil.parseObj(post).getStr("embeddings");
|
||||||
|
List<Double> embedding = JSONUtil.toList(o, Double.class);
|
||||||
|
//List<Double> embedding = List.of(1.0, 2.0, 3.0);
|
||||||
|
embeddingList.add(embedding);
|
||||||
|
}
|
||||||
|
var indexCounter = new AtomicInteger(0);
|
||||||
|
List<Embedding> embeddings = embeddingList.stream()
|
||||||
|
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
|
||||||
|
.toList();
|
||||||
|
return new EmbeddingResponse(embeddings);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
package com.supervision.config;
|
||||||
|
|
||||||
|
import org.springframework.ai.vectorstore.RedisVectorStore;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
|
@Configuration
|
||||||
|
public class VectorSimilarityConfiguration {
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public VectorEmbeddingClient vectorEmbeddingClient() {
|
||||||
|
return new VectorEmbeddingClient();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient) {
|
||||||
|
RedisVectorStore.RedisVectorStoreConfig config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
||||||
|
.withURI("redis://:123456@192.168.10.137:6380")
|
||||||
|
// 定义搜索过滤器使用的元数据字段
|
||||||
|
.withMetadataFields(
|
||||||
|
RedisVectorStore.MetadataField.tag("medicalId"),
|
||||||
|
RedisVectorStore.MetadataField.tag("type"))
|
||||||
|
.build();
|
||||||
|
return new RedisVectorStore(config, vectorEmbeddingClient);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,123 @@
|
|||||||
|
package com.supervision.util;
|
||||||
|
|
||||||
|
import cn.hutool.core.map.MapUtil;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.ai.document.Document;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class VectorSimilarityUtil {
|
||||||
|
|
||||||
|
|
||||||
|
private static final Map<String, Map<String, Document>> storeMap = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
private static final Double similarityThreshold = 0.5;
|
||||||
|
|
||||||
|
private static final Integer topK = 5;
|
||||||
|
|
||||||
|
|
||||||
|
public static void add(String storeId, Document document) {
|
||||||
|
// TODO 需要序列化成为索引
|
||||||
|
List<Double> embedding = null;
|
||||||
|
document.setEmbedding(embedding);
|
||||||
|
Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
|
||||||
|
store.put(document.getId(), document);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void add(String storeId, List<Document> documents) {
|
||||||
|
for (Document document : documents) {
|
||||||
|
// TODO 需要序列化成为索引
|
||||||
|
List<Double> embedding = null;
|
||||||
|
document.setEmbedding(embedding);
|
||||||
|
Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
|
||||||
|
store.put(document.getId(), document);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Optional<Boolean> delete(String storeId, List<String> idList) {
|
||||||
|
if (!storeMap.containsKey(storeId)) {
|
||||||
|
for (String id : idList) {
|
||||||
|
storeMap.get(storeId).remove(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Optional.of(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<Document> similaritySearch(String storeId, List<Double> userQueryEmbedding) {
|
||||||
|
Map<String, Document> store = storeMap.get(storeId);
|
||||||
|
if (MapUtil.isNotEmpty(store)) {
|
||||||
|
return store.values()
|
||||||
|
.stream()
|
||||||
|
.map(entry -> new Similarity(entry.getId(),
|
||||||
|
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
|
||||||
|
.filter(s -> s.score >= similarityThreshold)
|
||||||
|
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
|
||||||
|
.limit(topK)
|
||||||
|
.map(s -> store.get(s.key))
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Similarity {
|
||||||
|
|
||||||
|
private String key;
|
||||||
|
|
||||||
|
private double score;
|
||||||
|
|
||||||
|
public Similarity(String key, double score) {
|
||||||
|
this.key = key;
|
||||||
|
this.score = score;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class EmbeddingMath {
|
||||||
|
|
||||||
|
private EmbeddingMath() {
|
||||||
|
throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) {
|
||||||
|
if (vectorX == null || vectorY == null) {
|
||||||
|
throw new RuntimeException("Vectors must not be null");
|
||||||
|
}
|
||||||
|
if (vectorX.size() != vectorY.size()) {
|
||||||
|
throw new IllegalArgumentException("Vectors lengths must be equal");
|
||||||
|
}
|
||||||
|
|
||||||
|
double dotProduct = dotProduct(vectorX, vectorY);
|
||||||
|
double normX = norm(vectorX);
|
||||||
|
double normY = norm(vectorY);
|
||||||
|
|
||||||
|
if (normX == 0 || normY == 0) {
|
||||||
|
throw new IllegalArgumentException("Vectors cannot have zero norm");
|
||||||
|
}
|
||||||
|
|
||||||
|
return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double dotProduct(List<Double> vectorX, List<Double> vectorY) {
|
||||||
|
if (vectorX.size() != vectorY.size()) {
|
||||||
|
throw new IllegalArgumentException("Vectors lengths must be equal");
|
||||||
|
}
|
||||||
|
|
||||||
|
double result = 0;
|
||||||
|
for (int i = 0; i < vectorX.size(); ++i) {
|
||||||
|
result += vectorX.get(i) * vectorY.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double norm(List<Double> vector) {
|
||||||
|
return dotProduct(vector, vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue