From 3724211a126f9dd152489a5318f92918221f303b Mon Sep 17 00:00:00 2001 From: liu Date: Wed, 5 Jun 2024 09:11:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=90=91=E9=87=8F=E5=8C=96?= =?UTF-8?q?=E7=9A=84=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 8 ++ virtual-patient-common/pom.xml | 8 +- .../config/VectorEmbeddingClient.java | 53 ++++++++ .../config/VectorSimilarityConfiguration.java | 26 ++++ .../util/VectorSimilarityUtil.java | 123 ++++++++++++++++++ virtual-patient-web/pom.xml | 1 - .../controller/TestController.java | 56 ++++++-- .../src/main/resources/application.yml | 8 ++ 8 files changed, 271 insertions(+), 12 deletions(-) create mode 100644 virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java create mode 100644 virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java create mode 100644 virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java diff --git a/pom.xml b/pom.xml index 9c375ceb..903e0921 100644 --- a/pom.xml +++ b/pom.xml @@ -58,6 +58,7 @@ 4.9.0 1.0.3 5.3.1 + 5.1.0 @@ -92,6 +93,7 @@ import + org.springframework.cloud @@ -181,6 +183,12 @@ ${okhttp.version} + + redis.clients + jedis + ${jedis.version} + + diff --git a/virtual-patient-common/pom.xml b/virtual-patient-common/pom.xml index 28897208..fa0d8e2d 100644 --- a/virtual-patient-common/pom.xml +++ b/virtual-patient-common/pom.xml @@ -48,20 +48,26 @@ httpclient5 + + io.springboot.ai + spring-ai-redis + + org.springframework.cloud spring-cloud-starter-bootstrap - org.springframework.boot spring-boot-starter-data-redis + redis.clients jedis + com.baomidou diff --git a/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java b/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java new file mode 100644 index 00000000..73cd2037 --- /dev/null +++ 
b/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java
@@ -0,0 +1,97 @@
+package com.supervision.config;
+
+import cn.hutool.http.HttpUtil;
+import cn.hutool.json.JSONUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.*;
+import org.springframework.util.Assert;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Embedding client that vectorises text by POSTing it to an HTTP embedding service.
+ *
+ * <p>Each text is sent as the JSON body {@code {"text": ...}} and the vector is read from the
+ * {@code embeddings} field of the JSON response.
+ */
+public class VectorEmbeddingClient extends AbstractEmbeddingClient {
+
+    /** Endpoint used by the no-argument constructor. */
+    public static final String DEFAULT_EMBEDDING_URL = "http://192.168.10.42:8000/embeddings/";
+
+    private static final Logger logger = LoggerFactory.getLogger(VectorEmbeddingClient.class);
+
+    private final String embeddingUrl;
+
+    /** Creates a client that talks to {@link #DEFAULT_EMBEDDING_URL}. */
+    public VectorEmbeddingClient() {
+        this(DEFAULT_EMBEDDING_URL);
+    }
+
+    /**
+     * Creates a client that talks to the given embedding endpoint.
+     *
+     * @param embeddingUrl URL of the embedding service; must not be blank
+     */
+    public VectorEmbeddingClient(String embeddingUrl) {
+        Assert.hasText(embeddingUrl, "embeddingUrl must not be blank");
+        this.embeddingUrl = embeddingUrl;
+    }
+
+    /**
+     * Embeds the content of a single document.
+     *
+     * @param document document whose content is vectorised
+     * @return the embedding vector produced for that content
+     */
+    @Override
+    public List<Double> embed(Document document) {
+        EmbeddingRequest request =
+                new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY);
+        // call() yields one result per instruction, so the first entry belongs to this document.
+        return call(request).getResults().get(0).getOutput();
+    }
+
+    /**
+     * Embeds every instruction of the request, issuing one HTTP call per text.
+     *
+     * @param request request carrying the texts to embed; must contain at least one
+     * @return response whose embeddings are indexed in instruction order, starting at 0
+     * @throws IllegalArgumentException if the request carries no instructions
+     * @throws IllegalStateException if the service response contains no embedding vector
+     */
+    @Override
+    public EmbeddingResponse call(EmbeddingRequest request) {
+        Assert.notEmpty(request.getInstructions(), "At least one text is required!");
+        if (request.getInstructions().size() != 1) {
+            logger.warn("Batch embedding is not implemented. Will make one API call per text");
+        }
+        List<List<Double>> embeddingList = new ArrayList<>();
+        for (String inputContent : request.getInstructions()) {
+            embeddingList.add(requestEmbedding(inputContent));
+        }
+        var indexCounter = new AtomicInteger(0);
+        List<Embedding> embeddings = embeddingList.stream()
+                .map(e -> new Embedding(e, indexCounter.getAndIncrement()))
+                .toList();
+        return new EmbeddingResponse(embeddings);
+    }
+
+    /** Posts one text to the embedding service and parses the vector from its response. */
+    private List<Double> requestEmbedding(String inputContent) {
+        String responseBody =
+                HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent)));
+        String rawVector = JSONUtil.parseObj(responseBody).getStr("embeddings");
+        List<Double> embedding = JSONUtil.toList(rawVector, Double.class);
+        if (embedding == null || embedding.isEmpty()) {
+            // Fail here rather than hand an unusable vector to the vector store.
+            throw new IllegalStateException("Embedding service returned no vector");
+        }
+        return embedding;
+    }
+}
diff --git a/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java b/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java
new file mode 100644
index 00000000..d2ec3699
--- /dev/null
+++ b/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java
@@ -0,0 +1,47 @@
+package com.supervision.config;
+
+import org.springframework.ai.vectorstore.RedisVectorStore;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+/**
+ * Wires the embedding client and the Redis vector store used for similarity search.
+ */
+@Configuration
+public class VectorSimilarityConfiguration {
+
+    /**
+     * Redis connection URI, read from {@code spring.ai.vectorstore.redis.uri}. The fallback is the
+     * address that used to be hard-coded in {@link #redisVectorStore}.
+     * NOTE(review): the fallback still embeds a password; remove it once every module that loads
+     * this configuration defines the property.
+     */
+    @Value("${spring.ai.vectorstore.redis.uri:redis://:123456@192.168.10.137:6380}")
+    private String redisUri;
+
+    /** Creates the embedding client handed to the vector store. */
+    @Bean
+    public VectorEmbeddingClient vectorEmbeddingClient() {
+        return new VectorEmbeddingClient();
+    }
+
+    /**
+     * Builds the Redis vector store.
+     *
+     * @param vectorEmbeddingClient embedding client passed to the store
+     * @return vector store with {@code medicalId} and {@code type} registered as tag fields
+     */
+    @Bean
+    public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient) {
+        RedisVectorStore.RedisVectorStoreConfig config =
+                RedisVectorStore.RedisVectorStoreConfig.builder()
+                        .withURI(redisUri)
+                        // Metadata fields that search filters may use.
+                        .withMetadataFields(
+                                RedisVectorStore.MetadataField.tag("medicalId"),
+                                RedisVectorStore.MetadataField.tag("type"))
+                        .build();
+        return new RedisVectorStore(config, vectorEmbeddingClient);
+    }
+}
diff --git 
a/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java b/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java
new file mode 100644
index 00000000..2b0542b9
--- /dev/null
+++ b/virtual-patient-common/src/main/java/com/supervision/util/VectorSimilarityUtil.java
@@ -0,0 +1,178 @@
+package com.supervision.util;
+
+import cn.hutool.core.map.MapUtil;
+import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.document.Document;
+
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * In-memory similarity search over documents that are grouped into named stores.
+ *
+ * <p>Search ranks the documents of one store by the cosine similarity between their embedding
+ * and a query embedding.
+ */
+@Slf4j
+public class VectorSimilarityUtil {
+
+    /** Store id -> (document id -> document). */
+    private static final Map<String, Map<String, Document>> storeMap = new ConcurrentHashMap<>();
+
+    /** Minimum cosine similarity a document needs in order to be returned. */
+    private static final Double similarityThreshold = 0.5;
+
+    /** Maximum number of documents a search returns. */
+    private static final Integer topK = 5;
+
+    /**
+     * Adds a document to a store, replacing any stored document with the same id.
+     *
+     * @param storeId  id of the target store; the store is created on first use
+     * @param document document to add; it must already carry a non-empty embedding
+     * @throws IllegalArgumentException if the document has no embedding
+     */
+    public static void add(String storeId, Document document) {
+        // TODO the document still needs to be serialized into an index here.
+        // Until then the caller-supplied embedding is kept. It used to be overwritten with null,
+        // which left nothing for similaritySearch to compare against.
+        List<Double> embedding = document.getEmbedding();
+        if (embedding == null || embedding.isEmpty()) {
+            throw new IllegalArgumentException("Document " + document.getId() + " has no embedding");
+        }
+        storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>())
+                .put(document.getId(), document);
+    }
+
+    /**
+     * Adds several documents to a store; see {@link #add(String, Document)}.
+     *
+     * @param storeId   id of the target store
+     * @param documents documents to add
+     */
+    public static void add(String storeId, List<Document> documents) {
+        for (Document document : documents) {
+            add(storeId, document);
+        }
+    }
+
+    /**
+     * Removes documents from a store. Unknown store ids and unknown document ids are ignored.
+     *
+     * @param storeId id of the store to delete from
+     * @param idList  ids of the documents to remove
+     * @return always an {@code Optional} containing {@code true}
+     */
+    public static Optional<Boolean> delete(String storeId, List<String> idList) {
+        Map<String, Document> store = storeMap.get(storeId);
+        // The guard used to be inverted: it only entered the loop when the store was missing
+        // (and then failed on the null store), so existing documents were never removed.
+        if (store != null) {
+            for (String id : idList) {
+                store.remove(id);
+            }
+        }
+        return Optional.of(true);
+    }
+
+    /**
+     * Returns the documents of a store that are most similar to the query embedding.
+     *
+     * @param storeId            id of the store to search
+     * @param userQueryEmbedding embedding of the query; must have the same length as the
+     *                           embeddings of the stored documents
+     * @return at most {@code topK} documents scoring at least {@code similarityThreshold},
+     *         best match first; an empty list if the store is missing or empty
+     */
+    public static List<Document> similaritySearch(String storeId, List<Double> userQueryEmbedding) {
+        Map<String, Document> store = storeMap.get(storeId);
+        if (MapUtil.isNotEmpty(store)) {
+            return store.values()
+                    .stream()
+                    .map(entry -> new Similarity(entry.getId(),
+                            EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
+                    .filter(s -> s.score >= similarityThreshold)
+                    .sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
+                    .limit(topK)
+                    .map(s -> store.get(s.key))
+                    .toList();
+        }
+        return new ArrayList<>();
+    }
+
+    /** Pairs a document id with its similarity score. */
+    @Data
+    public static class Similarity {
+
+        private String key;
+
+        private double score;
+
+        public Similarity(String key, double score) {
+            this.key = key;
+            this.score = score;
+        }
+
+    }
+
+    /** Vector helpers used by the similarity search. */
+    public static class EmbeddingMath {
+
+        private EmbeddingMath() {
+            throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
+        }
+
+        /**
+         * Computes the cosine similarity of two vectors.
+         *
+         * @throws RuntimeException         if either vector is {@code null}
+         * @throws IllegalArgumentException if the lengths differ or a vector has zero norm
+         */
+        public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) {
+            if (vectorX == null || vectorY == null) {
+                throw new RuntimeException("Vectors must not be null");
+            }
+            if (vectorX.size() != vectorY.size()) {
+                throw new IllegalArgumentException("Vectors lengths must be equal");
+            }
+
+            double dotProduct = dotProduct(vectorX, vectorY);
+            double normX = norm(vectorX);
+            double normY = norm(vectorY);
+
+            if (normX == 0 || normY == 0) {
+                throw new IllegalArgumentException("Vectors cannot have zero norm");
+            }
+
+            return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
+        }
+
+        /**
+         * Computes the dot product of two vectors.
+         *
+         * @throws IllegalArgumentException if the lengths differ
+         */
+        public static double dotProduct(List<Double> vectorX, List<Double> vectorY) {
+            if (vectorX.size() != vectorY.size()) {
+                throw new IllegalArgumentException("Vectors lengths must be equal");
+            }
+
+            double result = 0;
+            for (int i = 0; i < vectorX.size(); ++i) {
+                result += vectorX.get(i) * vectorY.get(i);
+            }
+
+            return result;
+        }
+
+        /**
+         * Returns the squared Euclidean norm of a vector, i.e. its dot product with itself;
+         * {@link #cosineSimilarity} takes the square root.
+         */
+        public static double norm(List<Double> vector) {
+            return dotProduct(vector, vector);
+        }
+
+    }
+}
diff --git a/virtual-patient-web/pom.xml b/virtual-patient-web/pom.xml
index ed1dcb10..57d7d249 100644
--- a/virtual-patient-web/pom.xml
+++ b/virtual-patient-web/pom.xml
@@ -110,7 +110,6 @@ test -
diff --git a/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java b/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java
index 0fe64c12..a45b98a3 100644
--- a/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java
+++ 
b/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java @@ -5,18 +5,27 @@ import cn.hutool.http.HttpUtil; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import com.supervision.exception.BusinessException; +import com.supervision.model.AskTemplateQuestionLibrary; import com.supervision.model.ConfigPhysicalTool; +import com.supervision.service.AskTemplateQuestionLibraryService; import com.supervision.service.ConfigPhysicalToolService; import com.supervision.util.MinioUtil; import lombok.RequiredArgsConstructor; -import org.springframework.web.bind.annotation.*; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.RedisVectorStore; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; -import java.io.FileInputStream; -import java.io.IOException; import java.io.InputStream; import java.util.*; +import java.util.stream.Collectors; +@Slf4j @RestController @RequestMapping("test") @RequiredArgsConstructor @@ -24,6 +33,32 @@ public class TestController { private final ConfigPhysicalToolService configPhysicalToolService; + private final RedisVectorStore redisVectorStore; + + private final AskTemplateQuestionLibraryService askTemplateQuestionLibraryService; + + @GetMapping("testRedisVectorStore") + public void testRedisVectorStore() { + List list = askTemplateQuestionLibraryService.list(); + for (AskTemplateQuestionLibrary askTemplateQuestionLibrary : list) { + String description = askTemplateQuestionLibrary.getDescription(); + redisVectorStore.add(List.of(new Document(description, Map.of("type", "1", "medicalId", "222")))); + List question = 
askTemplateQuestionLibrary.getQuestion(); + for (String s : question) { + redisVectorStore.add(List.of(new Document(s, Map.of("type", "1", "medicalId", "222")))); + } +// log.info("处理完成:{}", description); + } + } + + @GetMapping("testQuestion") + public List> testQuestion(String question) { + List documents = redisVectorStore.similaritySearch(SearchRequest + .query(question) + .withTopK(5)); + return documents.stream().map(document -> Map.of("content", document.getContent(), "id", document.getId())).collect(Collectors.toList()); + } + @GetMapping("testExpireTime") public String testExpireTime() { return "OK"; @@ -36,18 +71,19 @@ public class TestController { /** * 数字人获取房间号 - * @param key 数字人ID,张总那边提供 + * + * @param key 数字人ID,张总那边提供 * @param token 前端每打开一个页面,就给一个新的UUID * @return 房间号 */ @GetMapping("queryRoomId") - public String queryRoomId(String key,String token){ + public String queryRoomId(String key, String token) { Map param = new HashMap<>(); - param.put("token",token); - param.put("key",key); + param.put("token", token); + param.put("key", key); String s = HttpUtil.get("https://digital-human.jd.com/getRoomId", param); JSONObject entries = JSONUtil.parseObj(s); - if (200 != entries.getInt("code")){ + if (200 != entries.getInt("code")) { throw new BusinessException(entries.getStr("data")); } return entries.getStr("data"); @@ -56,7 +92,8 @@ public class TestController { /** * 调用数字人接口进行语音的播放 - * @param text 需要播放的文本 + * + * @param text 需要播放的文本 * @param roomId 房间ID */ @GetMapping("shuZiRenSend") @@ -102,5 +139,4 @@ public class TestController { } - } diff --git a/virtual-patient-web/src/main/resources/application.yml b/virtual-patient-web/src/main/resources/application.yml index 33ae290a..7d38e32e 100644 --- a/virtual-patient-web/src/main/resources/application.yml +++ b/virtual-patient-web/src/main/resources/application.yml @@ -12,6 +12,8 @@ server: # 是否分配的直接内存 direct-buffers: true spring: + main: + allow-bean-definition-overriding: true servlet: multipart: 
max-file-size: 100MB @@ -39,6 +41,12 @@ spring: log-slow-sql: true # 是否开启 慢SQL 记录,默认false slow-sql-millis: 5000 # 慢 SQL 的标准,默认 3000,单位:毫秒 merge-sql: false # 合并多个连接池的监控数据,默认false + ai: + vectorstore: + redis: + uri: redis://:123456@192.168.10.137:6380 + index: 1 + prefix: 'vp:vector:' mybatis-plus: mapper-locations: classpath*:mapper/**/*.xml