pull/1/head
liu 11 months ago
parent 083cd38c29
commit 74909f7927

@ -26,7 +26,7 @@ public class VectorSimilarityConfiguration {
.withURI(redisVectorProperties.getUri())
.withPrefix(redisVectorProperties.getPrefix())
.withIndexName(redisVectorProperties.getIndexName())
// 定义搜索过滤器使用的元数据字段(!!!!!!!!千万重要,数据类型一定和实际插入的一致,否则会导致查询不到!!!!!!!!)
// 定义搜索过滤器使用的元数据字段(!!!!!!!!千万重要,数据类型一定要用字符串,否则会导致查询不到!!!!!!!!)
.withMetadataFields(
// 问题的ID
RedisVectorStore.MetadataField.tag("questionId"),
@ -35,7 +35,7 @@ public class VectorSimilarityConfiguration {
// 标准问ID
RedisVectorStore.MetadataField.tag("standardQuestionId"),
// 类型 1标准问 2相似问 3自定义
RedisVectorStore.MetadataField.numeric("type"))
RedisVectorStore.MetadataField.tag("type"))
.build();
return new RedisVectorStore(config, vectorEmbeddingClient);
}

@ -45,7 +45,7 @@ public class SimilarityUtil {
qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent());
qaSimilarityQuestionAnswer.setDictId(String.valueOf(document.getMetadata().get("dictId")));
qaSimilarityQuestionAnswer.setMatchQuestionCode(String.valueOf(document.getMetadata().get("standardQuestionId")));
qaSimilarityQuestionAnswer.setMatchScore(Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score"))));
qaSimilarityQuestionAnswer.setMatchScore(1 - Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score"))));
return qaSimilarityQuestionAnswer;
// 排序,降序,取最高的
}).sorted(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore).reversed()).toList();
@ -54,4 +54,6 @@ public class SimilarityUtil {
return new ArrayList<>();
}
}
}

@ -1,123 +0,0 @@
package com.supervision.util;
import cn.hutool.core.map.MapUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class VectorSimilarityUtil {
private static final Map<String, Map<String, Document>> storeMap = new ConcurrentHashMap<>();
private static final Double similarityThreshold = 0.5;
private static final Integer topK = 5;
public static void add(String storeId, Document document) {
// TODO 需要序列化成为索引
List<Double> embedding = null;
document.setEmbedding(embedding);
Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
store.put(document.getId(), document);
}
public static void add(String storeId, List<Document> documents) {
for (Document document : documents) {
// TODO 需要序列化成为索引
List<Double> embedding = null;
document.setEmbedding(embedding);
Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
store.put(document.getId(), document);
}
}
public static Optional<Boolean> delete(String storeId, List<String> idList) {
if (!storeMap.containsKey(storeId)) {
for (String id : idList) {
storeMap.get(storeId).remove(id);
}
}
return Optional.of(true);
}
public static List<Document> similaritySearch(String storeId, List<Double> userQueryEmbedding) {
Map<String, Document> store = storeMap.get(storeId);
if (MapUtil.isNotEmpty(store)) {
return store.values()
.stream()
.map(entry -> new Similarity(entry.getId(),
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
.filter(s -> s.score >= similarityThreshold)
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
.limit(topK)
.map(s -> store.get(s.key))
.toList();
}
return new ArrayList<>();
}
@Data
public static class Similarity {
private String key;
private double score;
public Similarity(String key, double score) {
this.key = key;
this.score = score;
}
}
public static class EmbeddingMath {
private EmbeddingMath() {
throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
}
public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) {
if (vectorX == null || vectorY == null) {
throw new RuntimeException("Vectors must not be null");
}
if (vectorX.size() != vectorY.size()) {
throw new IllegalArgumentException("Vectors lengths must be equal");
}
double dotProduct = dotProduct(vectorX, vectorY);
double normX = norm(vectorX);
double normY = norm(vectorY);
if (normX == 0 || normY == 0) {
throw new IllegalArgumentException("Vectors cannot have zero norm");
}
return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
}
public static double dotProduct(List<Double> vectorX, List<Double> vectorY) {
if (vectorX.size() != vectorY.size()) {
throw new IllegalArgumentException("Vectors lengths must be equal");
}
double result = 0;
for (int i = 0; i < vectorX.size(); ++i) {
result += vectorX.get(i) * vectorY.get(i);
}
return result;
}
public static double norm(List<Double> vector) {
return dotProduct(vector, vector);
}
}
}

@ -86,14 +86,14 @@ public class TestController {
Map.of("type", "1",
"standardQuestionId", askTemplateQuestionLibrary.getId(),
"questionId", askTemplateQuestionLibrary.getId(),
"dictId", askTemplateQuestionLibrary.getDictId()))));
"dictId", String.valueOf(askTemplateQuestionLibrary.getDictId())))));
List<String> question = askTemplateQuestionLibrary.getQuestion();
for (String s : question) {
redisVectorStore.add(List.of(new Document(s,
Map.of("type", "2",
"standardQuestionId", askTemplateQuestionLibrary.getId(),
"questionId", askTemplateQuestionLibrary.getId(),
"dictId", askTemplateQuestionLibrary.getDictId()))));
"dictId", String.valueOf(askTemplateQuestionLibrary.getDictId())))));
}
}
}

Loading…
Cancel
Save