bugfix
parent
083cd38c29
commit
74909f7927
@ -1,123 +0,0 @@
|
|||||||
package com.supervision.util;
|
|
||||||
|
|
||||||
import cn.hutool.core.map.MapUtil;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.ai.document.Document;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class VectorSimilarityUtil {
|
|
||||||
|
|
||||||
|
|
||||||
private static final Map<String, Map<String, Document>> storeMap = new ConcurrentHashMap<>();
|
|
||||||
|
|
||||||
private static final Double similarityThreshold = 0.5;
|
|
||||||
|
|
||||||
private static final Integer topK = 5;
|
|
||||||
|
|
||||||
|
|
||||||
public static void add(String storeId, Document document) {
|
|
||||||
// TODO 需要序列化成为索引
|
|
||||||
List<Double> embedding = null;
|
|
||||||
document.setEmbedding(embedding);
|
|
||||||
Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
|
|
||||||
store.put(document.getId(), document);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void add(String storeId, List<Document> documents) {
|
|
||||||
for (Document document : documents) {
|
|
||||||
// TODO 需要序列化成为索引
|
|
||||||
List<Double> embedding = null;
|
|
||||||
document.setEmbedding(embedding);
|
|
||||||
Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
|
|
||||||
store.put(document.getId(), document);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Optional<Boolean> delete(String storeId, List<String> idList) {
|
|
||||||
if (!storeMap.containsKey(storeId)) {
|
|
||||||
for (String id : idList) {
|
|
||||||
storeMap.get(storeId).remove(id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Optional.of(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static List<Document> similaritySearch(String storeId, List<Double> userQueryEmbedding) {
|
|
||||||
Map<String, Document> store = storeMap.get(storeId);
|
|
||||||
if (MapUtil.isNotEmpty(store)) {
|
|
||||||
return store.values()
|
|
||||||
.stream()
|
|
||||||
.map(entry -> new Similarity(entry.getId(),
|
|
||||||
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
|
|
||||||
.filter(s -> s.score >= similarityThreshold)
|
|
||||||
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
|
|
||||||
.limit(topK)
|
|
||||||
.map(s -> store.get(s.key))
|
|
||||||
.toList();
|
|
||||||
}
|
|
||||||
return new ArrayList<>();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public static class Similarity {
|
|
||||||
|
|
||||||
private String key;
|
|
||||||
|
|
||||||
private double score;
|
|
||||||
|
|
||||||
public Similarity(String key, double score) {
|
|
||||||
this.key = key;
|
|
||||||
this.score = score;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class EmbeddingMath {
|
|
||||||
|
|
||||||
private EmbeddingMath() {
|
|
||||||
throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
|
|
||||||
}
|
|
||||||
|
|
||||||
public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) {
|
|
||||||
if (vectorX == null || vectorY == null) {
|
|
||||||
throw new RuntimeException("Vectors must not be null");
|
|
||||||
}
|
|
||||||
if (vectorX.size() != vectorY.size()) {
|
|
||||||
throw new IllegalArgumentException("Vectors lengths must be equal");
|
|
||||||
}
|
|
||||||
|
|
||||||
double dotProduct = dotProduct(vectorX, vectorY);
|
|
||||||
double normX = norm(vectorX);
|
|
||||||
double normY = norm(vectorY);
|
|
||||||
|
|
||||||
if (normX == 0 || normY == 0) {
|
|
||||||
throw new IllegalArgumentException("Vectors cannot have zero norm");
|
|
||||||
}
|
|
||||||
|
|
||||||
return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static double dotProduct(List<Double> vectorX, List<Double> vectorY) {
|
|
||||||
if (vectorX.size() != vectorY.size()) {
|
|
||||||
throw new IllegalArgumentException("Vectors lengths must be equal");
|
|
||||||
}
|
|
||||||
|
|
||||||
double result = 0;
|
|
||||||
for (int i = 0; i < vectorX.size(); ++i) {
|
|
||||||
result += vectorX.get(i) * vectorY.get(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static double norm(List<Double> vector) {
|
|
||||||
return dotProduct(vector, vector);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue