From 786110d13021720014ffe606512bb3d2d02b0580 Mon Sep 17 00:00:00 2001 From: liu Date: Wed, 5 Jun 2024 15:50:50 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=90=91=E9=87=8F=E5=8C=96?= =?UTF-8?q?=E7=9A=84=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../config/VectorEmbeddingClient.java | 26 ++++++++++--------- .../config/VectorSimilarityConfiguration.java | 8 ++++-- .../VirtualPatientApplication.java | 2 ++ .../controller/TestController.java | 22 +++++++++++++--- 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java b/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java index e2af3491..3d35b843 100644 --- a/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java +++ b/virtual-patient-common/src/main/java/com/supervision/config/VectorEmbeddingClient.java @@ -2,10 +2,11 @@ package com.supervision.config; import cn.hutool.http.HttpUtil; import cn.hutool.json.JSONUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.*; +import org.springframework.core.SpringProperties; import org.springframework.util.Assert; import java.util.ArrayList; @@ -13,9 +14,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +@Slf4j public class VectorEmbeddingClient extends AbstractEmbeddingClient { - private final Logger logger = LoggerFactory.getLogger(getClass()); @Override public List embed(Document document) { @@ -30,17 +31,13 @@ public class VectorEmbeddingClient extends AbstractEmbeddingClient { @Override public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); - if (request.getInstructions().size() != 1) { - logger.warn( - "Ollama Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); - } List> embeddingList = new ArrayList<>(); + String embeddingUrl = SpringProperties.getProperty("embeddingUrl"); for (String inputContent : request.getInstructions()) { - // TODO 这里需要吧inputContent转化为向量数据 - String post = HttpUtil.post("http://192.168.10.42:8000/embeddings/", JSONUtil.toJsonStr(Map.of("text", inputContent))); - String o = JSONUtil.parseObj(post).getStr("embeddings"); - List embedding = JSONUtil.toList(o, Double.class); - embeddingList.add(embedding); + // 这里需要吧inputContent转化为向量数据 + String post = HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent))); + EmbeddingData bean = JSONUtil.toBean(post, EmbeddingData.class); + embeddingList.add(bean.embeddings); } var indexCounter = new AtomicInteger(0); List embeddings = embeddingList.stream() @@ -48,4 +45,9 @@ public class VectorEmbeddingClient extends AbstractEmbeddingClient { .toList(); return new EmbeddingResponse(embeddings); } + + @Data + private static class EmbeddingData { + private List embeddings; + } } diff --git a/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java b/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java index d2ec3699..d5701e9c 100644 --- a/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java +++ b/virtual-patient-common/src/main/java/com/supervision/config/VectorSimilarityConfiguration.java @@ -3,6 +3,7 @@ package com.supervision.config; import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.core.SpringProperties; @Configuration public class VectorSimilarityConfiguration { @@ -14,11 +15,14 @@ public class VectorSimilarityConfiguration { @Bean public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient) { + String property = SpringProperties.getProperty("spring.ai.vectorstore.redis.uri"); RedisVectorStore.RedisVectorStoreConfig config = RedisVectorStore.RedisVectorStoreConfig.builder() - .withURI("redis://:123456@192.168.10.137:6380") + .withURI(property) // 定义搜索过滤器使用的元数据字段 .withMetadataFields( - RedisVectorStore.MetadataField.tag("medicalId"), + // 问题的ID + RedisVectorStore.MetadataField.tag("questionId"), + // 类型 1标准问 2相似问 3自定义 RedisVectorStore.MetadataField.tag("type")) .build(); return new RedisVectorStore(config, vectorEmbeddingClient); diff --git a/virtual-patient-web/src/main/java/com/supervision/VirtualPatientApplication.java b/virtual-patient-web/src/main/java/com/supervision/VirtualPatientApplication.java index 1d74a13b..4a44d419 100644 --- a/virtual-patient-web/src/main/java/com/supervision/VirtualPatientApplication.java +++ b/virtual-patient-web/src/main/java/com/supervision/VirtualPatientApplication.java @@ -5,8 +5,10 @@ import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cloud.client.discovery.EnableDiscoveryClient; import org.springframework.cloud.openfeign.EnableFeignClients; +import org.springframework.context.annotation.EnableAspectJAutoProxy; import org.springframework.scheduling.annotation.EnableScheduling; +@EnableAspectJAutoProxy(proxyTargetClass = true) @SpringBootApplication @MapperScan(basePackages = {"com.supervision.**.mapper"}) @EnableScheduling diff --git a/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java b/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java index b4a5a276..6b94464c 100644 --- a/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java +++ b/virtual-patient-web/src/main/java/com/supervision/controller/TestController.java @@ -49,6 +49,7 @@ public class TestController { private final OllamaChatClient chatClient; + @GetMapping("testMatchQuestion") public String test(String question) { String template = """ @@ -70,8 +71,15 @@ public class TestController { return call.getResult().getOutput().getContent(); } + @GetMapping("testJedis") + public void testJedis() { + Map stringObjectMap = redisVectorStore.getJedis().ftConfigGet("spring-ai-index", ""); + System.out.println(1); + } + @GetMapping("testRedisVectorStore") public void testRedisVectorStore() { + List list = askTemplateQuestionLibraryService.list(); for (AskTemplateQuestionLibrary askTemplateQuestionLibrary : list) { String description = askTemplateQuestionLibrary.getDescription(); @@ -84,14 +92,20 @@ public class TestController { } } + @GetMapping("testQuestion") - public List> testQuestion(String question) { + public List testQuestion(String question) { List documents = redisVectorStore.similaritySearch(SearchRequest .query(question) - .withTopK(5)); - return documents.stream().map(document -> Map.of("content", document.getContent(), "id", document.getId())).collect(Collectors.toList()); + .withTopK(5).withSimilarityThreshold(0.5)); +// return documents.stream().map(document -> Map.of("content", document.getContent(), "id", document.getId())).collect(Collectors.toList()); + documents.forEach(e -> { + double v = Double.parseDouble(String.valueOf(e.getMetadata().get("vector_score"))); + // 降序 + e.getMetadata().put("originalScore", 1 - (v * 2)); + }); + return documents; } - @GetMapping("testExpireTime") public String testExpireTime() { return "OK";