diff --git a/pom.xml b/pom.xml
index 9c375ceb..903e0921 100644
--- a/pom.xml
+++ b/pom.xml
@@ -58,6 +58,7 @@
4.9.0
1.0.3
5.3.1
+ 5.1.0
@@ -92,6 +93,7 @@
import
+
org.springframework.cloud
@@ -181,6 +183,12 @@
${okhttp.version}
+
+ redis.clients
+ jedis
+ ${jedis.version}
+
+
diff --git a/virtual-patient-common/pom.xml b/virtual-patient-common/pom.xml
index 28897208..fa0d8e2d 100644
--- a/virtual-patient-common/pom.xml
+++ b/virtual-patient-common/pom.xml
@@ -48,20 +48,26 @@
httpclient5
+
+ io.springboot.ai
+ spring-ai-redis
+
+
org.springframework.cloud
spring-cloud-starter-bootstrap
-
org.springframework.boot
spring-boot-starter-data-redis
+
redis.clients
jedis
+
com.baomidou
diff --git a/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java b/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java
new file mode 100644
index 00000000..73cd2037
--- /dev/null
+++ b/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java
@@ -0,0 +1,53 @@
+package com.supervision.config;
+
+import cn.hutool.http.HttpUtil;
+import cn.hutool.json.JSONUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.*;
+import org.springframework.util.Assert;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class VectorEmbeddingClient extends AbstractEmbeddingClient {
+
+ private final Logger logger = LoggerFactory.getLogger(getClass());
+
+
+ @Override
+ public List embed(Document document) {
+ List> list = this.call(new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY))
+ .getResults()
+ .stream()
+ .map(Embedding::getOutput)
+ .toList();
+ return list.iterator().next();
+ }
+
+ @Override
+ public EmbeddingResponse call(EmbeddingRequest request) {
+ Assert.notEmpty(request.getInstructions(), "At least one text is required!");
+ if (request.getInstructions().size() != 1) {
+ logger.warn(
+ "Ollama Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
+ }
+ List> embeddingList = new ArrayList<>();
+ for (String inputContent : request.getInstructions()) {
+ // TODO 这里需要吧inputContent转化为向量数据
+ String post = HttpUtil.post("http://192.168.10.42:8000/embeddings/", JSONUtil.toJsonStr(Map.of("text", inputContent)));
+ String o = JSONUtil.parseObj(post).getStr("embeddings");
+ List embedding = JSONUtil.toList(o, Double.class);
+ //List embedding = List.of(1.0, 2.0, 3.0);
+ embeddingList.add(embedding);
+ }
+ var indexCounter = new AtomicInteger(0);
+ List embeddings = embeddingList.stream()
+ .map(e -> new Embedding(e, indexCounter.getAndIncrement()))
+ .toList();
+ return new EmbeddingResponse(embeddings);
+ }
+}
diff --git a/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java b/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java
new file mode 100644
index 00000000..d2ec3699
--- /dev/null
+++ b/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java
@@ -0,0 +1,26 @@
+package com.supervision.config;
+
+import org.springframework.ai.vectorstore.RedisVectorStore;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+@Configuration
+public class VectorSimilarityConfiguration {
+
+ @Bean
+ public VectorEmbeddingClient vectorEmbeddingClient() {
+ return new VectorEmbeddingClient();
+ }
+
+ @Bean
+ public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient) {
+ RedisVectorStore.RedisVectorStoreConfig config = RedisVectorStore.RedisVectorStoreConfig.builder()
+ .withURI("redis://:123456@192.168.10.137:6380")
+ // 定义搜索过滤器使用的元数据字段
+ .withMetadataFields(
+ RedisVectorStore.MetadataField.tag("medicalId"),
+ RedisVectorStore.MetadataField.tag("type"))
+ .build();
+ return new RedisVectorStore(config, vectorEmbeddingClient);
+ }
+}
diff --git a/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java b/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java
new file mode 100644
index 00000000..2b0542b9
--- /dev/null
+++ b/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java
@@ -0,0 +1,123 @@
+package com.supervision.util;
+
+import cn.hutool.core.map.MapUtil;
+import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.document.Document;
+
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+
+@Slf4j
+public class VectorSimilarityUtil {
+
+
+ private static final Map> storeMap = new ConcurrentHashMap<>();
+
+ private static final Double similarityThreshold = 0.5;
+
+ private static final Integer topK = 5;
+
+
+ public static void add(String storeId, Document document) {
+ // TODO 需要序列化成为索引
+ List embedding = null;
+ document.setEmbedding(embedding);
+ Map store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
+ store.put(document.getId(), document);
+ }
+
+ public static void add(String storeId, List documents) {
+ for (Document document : documents) {
+ // TODO 需要序列化成为索引
+ List embedding = null;
+ document.setEmbedding(embedding);
+ Map store = storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>());
+ store.put(document.getId(), document);
+ }
+ }
+
+ public static Optional delete(String storeId, List idList) {
+ if (!storeMap.containsKey(storeId)) {
+ for (String id : idList) {
+ storeMap.get(storeId).remove(id);
+ }
+ }
+ return Optional.of(true);
+ }
+
+ public static List similaritySearch(String storeId, List userQueryEmbedding) {
+ Map store = storeMap.get(storeId);
+ if (MapUtil.isNotEmpty(store)) {
+ return store.values()
+ .stream()
+ .map(entry -> new Similarity(entry.getId(),
+ EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
+ .filter(s -> s.score >= similarityThreshold)
+ .sorted(Comparator.comparingDouble(s -> s.score).reversed())
+ .limit(topK)
+ .map(s -> store.get(s.key))
+ .toList();
+ }
+ return new ArrayList<>();
+ }
+
+
+ @Data
+ public static class Similarity {
+
+ private String key;
+
+ private double score;
+
+ public Similarity(String key, double score) {
+ this.key = key;
+ this.score = score;
+ }
+
+ }
+
+ public static class EmbeddingMath {
+
+ private EmbeddingMath() {
+ throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
+ }
+
+ public static double cosineSimilarity(List vectorX, List vectorY) {
+ if (vectorX == null || vectorY == null) {
+ throw new RuntimeException("Vectors must not be null");
+ }
+ if (vectorX.size() != vectorY.size()) {
+ throw new IllegalArgumentException("Vectors lengths must be equal");
+ }
+
+ double dotProduct = dotProduct(vectorX, vectorY);
+ double normX = norm(vectorX);
+ double normY = norm(vectorY);
+
+ if (normX == 0 || normY == 0) {
+ throw new IllegalArgumentException("Vectors cannot have zero norm");
+ }
+
+ return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
+ }
+
+ public static double dotProduct(List vectorX, List vectorY) {
+ if (vectorX.size() != vectorY.size()) {
+ throw new IllegalArgumentException("Vectors lengths must be equal");
+ }
+
+ double result = 0;
+ for (int i = 0; i < vectorX.size(); ++i) {
+ result += vectorX.get(i) * vectorY.get(i);
+ }
+
+ return result;
+ }
+
+ public static double norm(List vector) {
+ return dotProduct(vector, vector);
+ }
+
+ }
+}
diff --git a/virtual-patient-web/pom.xml b/virtual-patient-web/pom.xml
index ed1dcb10..57d7d249 100644
--- a/virtual-patient-web/pom.xml
+++ b/virtual-patient-web/pom.xml
@@ -110,7 +110,6 @@
test
-
diff --git a/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java b/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java
index 0fe64c12..a45b98a3 100644
--- a/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java
+++ b/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java
@@ -5,18 +5,27 @@ import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException;
+import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.ConfigPhysicalTool;
+import com.supervision.service.AskTemplateQuestionLibraryService;
import com.supervision.service.ConfigPhysicalToolService;
import com.supervision.util.MinioUtil;
import lombok.RequiredArgsConstructor;
-import org.springframework.web.bind.annotation.*;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.vectorstore.RedisVectorStore;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
-import java.io.FileInputStream;
-import java.io.IOException;
import java.io.InputStream;
import java.util.*;
+import java.util.stream.Collectors;
+@Slf4j
@RestController
@RequestMapping("test")
@RequiredArgsConstructor
@@ -24,6 +33,32 @@ public class TestController {
private final ConfigPhysicalToolService configPhysicalToolService;
+ private final RedisVectorStore redisVectorStore;
+
+ private final AskTemplateQuestionLibraryService askTemplateQuestionLibraryService;
+
+ @GetMapping("testRedisVectorStore")
+ public void testRedisVectorStore() {
+ List list = askTemplateQuestionLibraryService.list();
+ for (AskTemplateQuestionLibrary askTemplateQuestionLibrary : list) {
+ String description = askTemplateQuestionLibrary.getDescription();
+ redisVectorStore.add(List.of(new Document(description, Map.of("type", "1", "medicalId", "222"))));
+ List question = askTemplateQuestionLibrary.getQuestion();
+ for (String s : question) {
+ redisVectorStore.add(List.of(new Document(s, Map.of("type", "1", "medicalId", "222"))));
+ }
+// log.info("处理完成:{}", description);
+ }
+ }
+
+ @GetMapping("testQuestion")
+ public List