From 74909f7927d81f6dc9cb1bebe5f8997e511dd52b Mon Sep 17 00:00:00 2001 From: liu <liujiatong112@163.com> Date: Thu, 6 Jun 2024 15:37:28 +0800 Subject: [PATCH] bugfix --- .../config/VectorSimilarityConfiguration.java | 4 +- .../com/supervision/util/SimilarityUtil.java | 4 +- .../util/VectorSimilarityUtil.java | 123 ------------------ .../controller/TestController.java | 4 +- 4 files changed, 7 insertions(+), 128 deletions(-) delete mode 100644 virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java diff --git a/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java b/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java index 9030e03c..beffc151 100644 --- a/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java +++ b/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java @@ -26,7 +26,7 @@ public class VectorSimilarityConfiguration { .withURI(redisVectorProperties.getUri()) .withPrefix(redisVectorProperties.getPrefix()) .withIndexName(redisVectorProperties.getIndexName()) - // 定义搜索过滤器使用的元数据字段(!!!!!!!!千万重要,数据类型一定和实际插入的一致,否则会导致查询不到!!!!!!!!) + // 定义搜索过滤器使用的元数据字段(!!!!!!!!千万重要,数据类型一定要用字符串,否则会导致查询不到!!!!!!!!) .withMetadataFields( // 问题的ID RedisVectorStore.MetadataField.tag("questionId"), @@ -35,7 +35,7 @@ public class VectorSimilarityConfiguration { // 标准问ID RedisVectorStore.MetadataField.tag("standardQuestionId"), // 类型 1标准问 2相似问 3自定义 - RedisVectorStore.MetadataField.numeric("type")) + RedisVectorStore.MetadataField.tag("type")) .build(); return new RedisVectorStore(config, vectorEmbeddingClient); } diff --git a/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java b/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java index 5d711e92..4c2a0725 100644 --- a/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java +++ b/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java @@ -45,7 +45,7 @@ public class SimilarityUtil { qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent()); qaSimilarityQuestionAnswer.setDictId(String.valueOf(document.getMetadata().get("dictId"))); qaSimilarityQuestionAnswer.setMatchQuestionCode(String.valueOf(document.getMetadata().get("standardQuestionId"))); - qaSimilarityQuestionAnswer.setMatchScore(Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score")))); + qaSimilarityQuestionAnswer.setMatchScore(1 - Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score")))); return qaSimilarityQuestionAnswer; // 排序,降序,取最高的 }).sorted(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore).reversed()).toList(); @@ -54,4 +54,6 @@ public class SimilarityUtil { return new ArrayList<>(); } } + + } diff --git a/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java b/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java deleted file mode 100644 index 2b0542b9..00000000 --- a/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java +++ /dev/null @@ -1,123 +0,0 @@ -package com.supervision.util; - -import cn.hutool.core.map.MapUtil; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.document.Document; - -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; - -@Slf4j -public class VectorSimilarityUtil { - - - private static final Map<String, Map<String, Document>> storeMap = new ConcurrentHashMap<>(); - - private static final Double similarityThreshold = 0.5; - - private static final Integer topK = 5; - - - public static void add(String storeId, Document document) { - // TODO 需要序列化成为索引 - List<Double> embedding = null; - document.setEmbedding(embedding); - Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>()); - store.put(document.getId(), document); - } - - public static void add(String storeId, List<Document> documents) { - for (Document document : documents) { - // TODO 需要序列化成为索引 - List<Double> embedding = null; - document.setEmbedding(embedding); - Map<String, Document> store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>()); - store.put(document.getId(), document); - } - } - - public static Optional<Boolean> delete(String storeId, List<String> idList) { - if (!storeMap.containsKey(storeId)) { - for (String id : idList) { - storeMap.get(storeId).remove(id); - } - } - return Optional.of(true); - } - - public static List<Document> similaritySearch(String storeId, List<Double> userQueryEmbedding) { - Map<String, Document> store = storeMap.get(storeId); - if (MapUtil.isNotEmpty(store)) { - return store.values() - .stream() - .map(entry -> new Similarity(entry.getId(), - EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding()))) - .filter(s -> s.score >= similarityThreshold) - .sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed()) - .limit(topK) - .map(s -> store.get(s.key)) - .toList(); - } - return new ArrayList<>(); - } - - - @Data - public static class Similarity { - - private String key; - - private double score; - - public Similarity(String key, double score) { - this.key = key; - this.score = score; - } - - } - - public static class EmbeddingMath { - - private EmbeddingMath() { - throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); - } - - public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) { - if (vectorX == null || vectorY == null) { - throw new RuntimeException("Vectors must not be null"); - } - if (vectorX.size() != vectorY.size()) { - throw new IllegalArgumentException("Vectors lengths must be equal"); - } - - double dotProduct = dotProduct(vectorX, vectorY); - double normX = norm(vectorX); - double normY = norm(vectorY); - - if (normX == 0 || normY == 0) { - throw new IllegalArgumentException("Vectors cannot have zero norm"); - } - - return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY)); - } - - public static double dotProduct(List<Double> vectorX, List<Double> vectorY) { - if (vectorX.size() != vectorY.size()) { - throw new IllegalArgumentException("Vectors lengths must be equal"); - } - - double result = 0; - for (int i = 0; i < vectorX.size(); ++i) { - result += vectorX.get(i) * vectorY.get(i); - } - - return result; - } - - public static double norm(List<Double> vector) { - return dotProduct(vector, vector); - } - - } -} diff --git a/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java b/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java index ec739a39..2e552746 100644 --- a/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java +++ b/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java @@ -86,14 +86,14 @@ public class TestController { Map.of("type", "1", "standardQuestionId", askTemplateQuestionLibrary.getId(), "questionId", askTemplateQuestionLibrary.getId(), - "dictId", askTemplateQuestionLibrary.getDictId())))); + "dictId", String.valueOf(askTemplateQuestionLibrary.getDictId()))))); List<String> question = askTemplateQuestionLibrary.getQuestion(); for (String s : question) { redisVectorStore.add(List.of(new Document(s, Map.of("type", "2", "standardQuestionId", askTemplateQuestionLibrary.getId(), "questionId", askTemplateQuestionLibrary.getId(), - "dictId", askTemplateQuestionLibrary.getDictId())))); + "dictId", String.valueOf(askTemplateQuestionLibrary.getDictId()))))); } } }