增加向量化的库

pull/1/head
liu 11 months ago
parent 343bfe57fb
commit 786110d130

@ -2,10 +2,11 @@ package com.supervision.config;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import org.slf4j.Logger; import lombok.Data;
import org.slf4j.LoggerFactory; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*; import org.springframework.ai.embedding.*;
import org.springframework.core.SpringProperties;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.util.ArrayList; import java.util.ArrayList;
@ -13,9 +14,9 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
public class VectorEmbeddingClient extends AbstractEmbeddingClient { public class VectorEmbeddingClient extends AbstractEmbeddingClient {
private final Logger logger = LoggerFactory.getLogger(getClass());
@Override @Override
public List<Double> embed(Document document) { public List<Double> embed(Document document) {
@ -30,17 +31,13 @@ public class VectorEmbeddingClient extends AbstractEmbeddingClient {
@Override @Override
public EmbeddingResponse call(EmbeddingRequest request) { public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!"); Assert.notEmpty(request.getInstructions(), "At least one text is required!");
if (request.getInstructions().size() != 1) {
logger.warn(
"Ollama Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
}
List<List<Double>> embeddingList = new ArrayList<>(); List<List<Double>> embeddingList = new ArrayList<>();
String embeddingUrl = SpringProperties.getProperty("embeddingUrl");
for (String inputContent : request.getInstructions()) { for (String inputContent : request.getInstructions()) {
// TODO 这里需要把inputContent转化为向量数据 // 这里需要把inputContent转化为向量数据
String post = HttpUtil.post("http://192.168.10.42:8000/embeddings/", JSONUtil.toJsonStr(Map.of("text", inputContent))); String post = HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent)));
String o = JSONUtil.parseObj(post).getStr("embeddings"); EmbeddingData bean = JSONUtil.toBean(post, EmbeddingData.class);
List<Double> embedding = JSONUtil.toList(o, Double.class); embeddingList.add(bean.embeddings);
embeddingList.add(embedding);
} }
var indexCounter = new AtomicInteger(0); var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = embeddingList.stream() List<Embedding> embeddings = embeddingList.stream()
@ -48,4 +45,9 @@ public class VectorEmbeddingClient extends AbstractEmbeddingClient {
.toList(); .toList();
return new EmbeddingResponse(embeddings); return new EmbeddingResponse(embeddings);
} }
/**
 * Holder for the JSON body returned by the embedding endpoint.
 * Populated in call() via JSONUtil.toBean(post, EmbeddingData.class); only the
 * "embeddings" property is mapped here.
 */
@Data
private static class EmbeddingData {
// Embedding vector for a single input text; call() reads it directly as bean.embeddings.
private List<Double> embeddings;
}
} }

@ -3,6 +3,7 @@ package com.supervision.config;
import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.core.SpringProperties;
@Configuration @Configuration
public class VectorSimilarityConfiguration { public class VectorSimilarityConfiguration {
@ -14,11 +15,14 @@ public class VectorSimilarityConfiguration {
@Bean @Bean
public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient) { public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient) {
String property = SpringProperties.getProperty("spring.ai.vectorstore.redis.uri");
RedisVectorStore.RedisVectorStoreConfig config = RedisVectorStore.RedisVectorStoreConfig.builder() RedisVectorStore.RedisVectorStoreConfig config = RedisVectorStore.RedisVectorStoreConfig.builder()
.withURI("redis://:123456@192.168.10.137:6380") .withURI(property)
// 定义搜索过滤器使用的元数据字段 // 定义搜索过滤器使用的元数据字段
.withMetadataFields( .withMetadataFields(
RedisVectorStore.MetadataField.tag("medicalId"), // 问题的ID
RedisVectorStore.MetadataField.tag("questionId"),
// 类型 1标准问 2相似问 3自定义
RedisVectorStore.MetadataField.tag("type")) RedisVectorStore.MetadataField.tag("type"))
.build(); .build();
return new RedisVectorStore(config, vectorEmbeddingClient); return new RedisVectorStore(config, vectorEmbeddingClient);

@ -5,8 +5,10 @@ import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.cloud.client.discovery.EnableDiscoveryClient; import org.springframework.cloud.client.discovery.EnableDiscoveryClient;
import org.springframework.cloud.openfeign.EnableFeignClients; import org.springframework.cloud.openfeign.EnableFeignClients;
import org.springframework.context.annotation.EnableAspectJAutoProxy;
import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.scheduling.annotation.EnableScheduling;
@EnableAspectJAutoProxy(proxyTargetClass = true)
@SpringBootApplication @SpringBootApplication
@MapperScan(basePackages = {"com.supervision.**.mapper"}) @MapperScan(basePackages = {"com.supervision.**.mapper"})
@EnableScheduling @EnableScheduling

@ -49,6 +49,7 @@ public class TestController {
private final OllamaChatClient chatClient; private final OllamaChatClient chatClient;
@GetMapping("testMatchQuestion") @GetMapping("testMatchQuestion")
public String test(String question) { public String test(String question) {
String template = """ String template = """
@ -70,8 +71,15 @@ public class TestController {
return call.getResult().getOutput().getContent(); return call.getResult().getOutput().getContent();
} }
/**
 * Test endpoint that calls ftConfigGet on the vector store's Jedis client for
 * the "spring-ai-index" index with an empty option string. Nothing is returned
 * to the HTTP caller.
 */
@GetMapping("testJedis")
public void testJedis() {
// NOTE(review): the returned map is never read — presumably kept as a debugger breakpoint target; confirm or remove.
Map<String, Object> stringObjectMap = redisVectorStore.getJedis().ftConfigGet("spring-ai-index", "");
// Prints the literal 1 to stdout; it carries no information from the call above.
System.out.println(1);
}
@GetMapping("testRedisVectorStore") @GetMapping("testRedisVectorStore")
public void testRedisVectorStore() { public void testRedisVectorStore() {
List<AskTemplateQuestionLibrary> list = askTemplateQuestionLibraryService.list(); List<AskTemplateQuestionLibrary> list = askTemplateQuestionLibraryService.list();
for (AskTemplateQuestionLibrary askTemplateQuestionLibrary : list) { for (AskTemplateQuestionLibrary askTemplateQuestionLibrary : list) {
String description = askTemplateQuestionLibrary.getDescription(); String description = askTemplateQuestionLibrary.getDescription();
@ -84,14 +92,20 @@ public class TestController {
} }
} }
@GetMapping("testQuestion") @GetMapping("testQuestion")
public List<Map<String, String>> testQuestion(String question) { public List<Document> testQuestion(String question) {
List<Document> documents = redisVectorStore.similaritySearch(SearchRequest List<Document> documents = redisVectorStore.similaritySearch(SearchRequest
.query(question) .query(question)
.withTopK(5)); .withTopK(5).withSimilarityThreshold(0.5));
return documents.stream().map(document -> Map.of("content", document.getContent(), "id", document.getId())).collect(Collectors.toList()); // return documents.stream().map(document -> Map.of("content", document.getContent(), "id", document.getId())).collect(Collectors.toList());
documents.forEach(e -> {
double v = Double.parseDouble(String.valueOf(e.getMetadata().get("vector_score")));
// 降序
e.getMetadata().put("originalScore", 1 - (v * 2));
});
return documents;
} }
@GetMapping("testExpireTime") @GetMapping("testExpireTime")
public String testExpireTime() { public String testExpireTime() {
return "OK"; return "OK";

Loading…
Cancel
Save