bugfix
parent
083cd38c29
commit
74909f7927
@ -1,123 +0,0 @@
|
||||
package com.supervision.util;
|
||||
|
||||
import cn.hutool.core.map.MapUtil;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.document.Document;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@Slf4j
|
||||
public class VectorSimilarityUtil {
|
||||
|
||||
|
||||
private static final Map<String, Map<String, Document>> storeMap = new ConcurrentHashMap<>();
|
||||
|
||||
private static final Double similarityThreshold = 0.5;
|
||||
|
||||
private static final Integer topK = 5;
|
||||
|
||||
|
||||
public static void add(String storeId, Document document) {
|
||||
// TODO 需要序列化成为索引
|
||||
List<Double> embedding = null;
|
||||
document.setEmbedding(embedding);
|
||||
Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
|
||||
store.put(document.getId(), document);
|
||||
}
|
||||
|
||||
public static void add(String storeId, List<Document> documents) {
|
||||
for (Document document : documents) {
|
||||
// TODO 需要序列化成为索引
|
||||
List<Double> embedding = null;
|
||||
document.setEmbedding(embedding);
|
||||
Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
|
||||
store.put(document.getId(), document);
|
||||
}
|
||||
}
|
||||
|
||||
public static Optional<Boolean> delete(String storeId, List<String> idList) {
|
||||
if (!storeMap.containsKey(storeId)) {
|
||||
for (String id : idList) {
|
||||
storeMap.get(storeId).remove(id);
|
||||
}
|
||||
}
|
||||
return Optional.of(true);
|
||||
}
|
||||
|
||||
public static List<Document> similaritySearch(String storeId, List<Double> userQueryEmbedding) {
|
||||
Map<String, Document> store = storeMap.get(storeId);
|
||||
if (MapUtil.isNotEmpty(store)) {
|
||||
return store.values()
|
||||
.stream()
|
||||
.map(entry -> new Similarity(entry.getId(),
|
||||
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
|
||||
.filter(s -> s.score >= similarityThreshold)
|
||||
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
|
||||
.limit(topK)
|
||||
.map(s -> store.get(s.key))
|
||||
.toList();
|
||||
}
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
|
||||
@Data
|
||||
public static class Similarity {
|
||||
|
||||
private String key;
|
||||
|
||||
private double score;
|
||||
|
||||
public Similarity(String key, double score) {
|
||||
this.key = key;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class EmbeddingMath {
|
||||
|
||||
private EmbeddingMath() {
|
||||
throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
|
||||
}
|
||||
|
||||
public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) {
|
||||
if (vectorX == null || vectorY == null) {
|
||||
throw new RuntimeException("Vectors must not be null");
|
||||
}
|
||||
if (vectorX.size() != vectorY.size()) {
|
||||
throw new IllegalArgumentException("Vectors lengths must be equal");
|
||||
}
|
||||
|
||||
double dotProduct = dotProduct(vectorX, vectorY);
|
||||
double normX = norm(vectorX);
|
||||
double normY = norm(vectorY);
|
||||
|
||||
if (normX == 0 || normY == 0) {
|
||||
throw new IllegalArgumentException("Vectors cannot have zero norm");
|
||||
}
|
||||
|
||||
return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
|
||||
}
|
||||
|
||||
public static double dotProduct(List<Double> vectorX, List<Double> vectorY) {
|
||||
if (vectorX.size() != vectorY.size()) {
|
||||
throw new IllegalArgumentException("Vectors lengths must be equal");
|
||||
}
|
||||
|
||||
double result = 0;
|
||||
for (int i = 0; i < vectorX.size(); ++i) {
|
||||
result += vectorX.get(i) * vectorY.get(i);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public static double norm(List<Double> vector) {
|
||||
return dotProduct(vector, vector);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue