增加向量化的库

pull/1/head
liu 11 months ago
parent ed7673de24
commit 3724211a12

@ -58,6 +58,7 @@
<okhttp.version>4.9.0</okhttp.version> <okhttp.version>4.9.0</okhttp.version>
<springboot.ai>1.0.3</springboot.ai> <springboot.ai>1.0.3</springboot.ai>
<httpclient5.version>5.3.1</httpclient5.version> <httpclient5.version>5.3.1</httpclient5.version>
<jedis.version>5.1.0</jedis.version>
</properties> </properties>
<dependencyManagement> <dependencyManagement>
@ -92,6 +93,7 @@
<scope>import</scope> <scope>import</scope>
</dependency> </dependency>
<!--spring-cloud--> <!--spring-cloud-->
<dependency> <dependency>
<groupId>org.springframework.cloud</groupId> <groupId>org.springframework.cloud</groupId>
@ -181,6 +183,12 @@
<version>${okhttp.version}</version> <version>${okhttp.version}</version>
</dependency> </dependency>
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>${jedis.version}</version>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>

@ -48,20 +48,26 @@
<artifactId>httpclient5</artifactId> <artifactId>httpclient5</artifactId>
</dependency> </dependency>
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-redis</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.cloud</groupId> <groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-starter-bootstrap</artifactId> <artifactId>spring-cloud-starter-bootstrap</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId> <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>redis.clients</groupId> <groupId>redis.clients</groupId>
<artifactId>jedis</artifactId> <artifactId>jedis</artifactId>
</dependency> </dependency>
<!--redis分布式锁 https://gitee.com/baomidou/lock4j --> <!--redis分布式锁 https://gitee.com/baomidou/lock4j -->
<dependency> <dependency>
<groupId>com.baomidou</groupId> <groupId>com.baomidou</groupId>

@ -0,0 +1,53 @@
package com.supervision.config;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Embedding client that obtains vectors from an HTTP embedding service.
 *
 * <p>Each input text is sent in its own HTTP POST as the JSON body {@code {"text": "..."}}.
 * The response is parsed as a JSON object whose {@code embeddings} field holds the vector
 * as a JSON array of numbers.
 */
public class VectorEmbeddingClient extends AbstractEmbeddingClient {

    /** Endpoint used by the no-arg constructor. */
    private static final String DEFAULT_EMBEDDING_URL = "http://192.168.10.42:8000/embeddings/";

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** URL of the embedding service that every request is posted to. */
    private final String embeddingUrl;

    /**
     * Creates a client that talks to the default embedding endpoint.
     */
    public VectorEmbeddingClient() {
        this(DEFAULT_EMBEDDING_URL);
    }

    /**
     * Creates a client that talks to the given embedding endpoint.
     *
     * @param embeddingUrl URL of the embedding service; must not be blank
     * @throws IllegalArgumentException if {@code embeddingUrl} is null or blank
     */
    public VectorEmbeddingClient(String embeddingUrl) {
        Assert.hasText(embeddingUrl, "embeddingUrl must not be blank");
        this.embeddingUrl = embeddingUrl;
    }

    /**
     * Embeds the content of a single document.
     *
     * @param document document whose content is embedded
     * @return the embedding vector of the document content
     * @throws IllegalStateException if the embedding call produced no result
     */
    @Override
    public List<Double> embed(Document document) {
        EmbeddingRequest request = new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY);
        return this.call(request)
                .getResults()
                .stream()
                .map(Embedding::getOutput)
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("Embedding service returned no result"));
    }

    /**
     * Embeds every instruction of the request, one HTTP call per instruction.
     *
     * @param request request carrying at least one text to embed
     * @return a response containing one embedding per instruction, indexed in request order
     * @throws IllegalArgumentException if the request contains no instructions
     * @throws IllegalStateException    if a service response has no {@code embeddings} field
     */
    @Override
    public EmbeddingResponse call(EmbeddingRequest request) {
        Assert.notEmpty(request.getInstructions(), "At least one text is required!");
        if (request.getInstructions().size() != 1) {
            logger.warn(
                    "Embedding service does not support batch embedding. Will make one HTTP call per input text");
        }

        List<Embedding> embeddings = new ArrayList<>();
        for (String inputContent : request.getInstructions()) {
            String responseBody = HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent)));
            String embeddingsJson = JSONUtil.parseObj(responseBody).getStr("embeddings");
            if (embeddingsJson == null) {
                // Fail loudly instead of passing a missing vector on to the caller.
                throw new IllegalStateException(
                        "Embedding service response has no 'embeddings' field: " + responseBody);
            }
            List<Double> vector = JSONUtil.toList(embeddingsJson, Double.class);
            // The index of each embedding is its position in the request's instruction list.
            embeddings.add(new Embedding(vector, embeddings.size()));
        }
        return new EmbeddingResponse(embeddings);
    }
}

@ -0,0 +1,26 @@
package com.supervision.config;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class VectorSimilarityConfiguration {
@Bean
public VectorEmbeddingClient vectorEmbeddingClient() {
return new VectorEmbeddingClient();
}
@Bean
public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient) {
RedisVectorStore.RedisVectorStoreConfig config = RedisVectorStore.RedisVectorStoreConfig.builder()
.withURI("redis://:123456@192.168.10.137:6380")
// 定义搜索过滤器使用的元数据字段
.withMetadataFields(
RedisVectorStore.MetadataField.tag("medicalId"),
RedisVectorStore.MetadataField.tag("type"))
.build();
return new RedisVectorStore(config, vectorEmbeddingClient);
}
}

@ -0,0 +1,123 @@
package com.supervision.util;
import cn.hutool.core.map.MapUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * In-memory vector store utility: documents are grouped into named stores and can be
 * searched by cosine similarity against a query embedding.
 *
 * <p>Stores are held in a static map keyed by store id; each store maps document id to document.
 */
@Slf4j
public class VectorSimilarityUtil {

    /** All stores, keyed by store id; each store maps document id to document. */
    private static final Map<String, Map<String, Document>> storeMap = new ConcurrentHashMap<>();

    /** Minimum cosine similarity a document needs to be returned by a search. */
    private static final Double similarityThreshold = 0.5;

    /** Maximum number of documents returned by a search. */
    private static final Integer topK = 5;

    /**
     * Adds (or replaces, by document id) a document in the given store.
     *
     * @param storeId  id of the store; the store is created if it does not exist yet
     * @param document document to add
     */
    public static void add(String storeId, Document document) {
        // TODO: compute the embedding for the document here. Until then the embedding
        // already present on the document is kept as-is rather than being overwritten.
        storeMap.computeIfAbsent(storeId, k -> new ConcurrentHashMap<>())
                .put(document.getId(), document);
    }

    /**
     * Adds every document of the list to the given store.
     *
     * @param storeId   id of the store; the store is created if it does not exist yet
     * @param documents documents to add
     */
    public static void add(String storeId, List<Document> documents) {
        for (Document document : documents) {
            add(storeId, document);
        }
    }

    /**
     * Removes the documents with the given ids from a store.
     * Does nothing when the store does not exist.
     *
     * @param storeId id of the store
     * @param idList  ids of the documents to remove
     * @return always {@code Optional.of(true)}
     */
    public static Optional<Boolean> delete(String storeId, List<String> idList) {
        Map<String, Document> store = storeMap.get(storeId);
        // Only an existing store can have documents removed from it.
        if (store != null) {
            for (String id : idList) {
                store.remove(id);
            }
        }
        return Optional.of(true);
    }

    /**
     * Returns the documents of a store that are most similar to the query embedding.
     *
     * <p>Documents scoring below the similarity threshold are dropped; the rest are ordered
     * by descending score and limited to the top-K.
     *
     * @param storeId            id of the store to search
     * @param userQueryEmbedding embedding of the query
     * @return matching documents, best match first; an empty list when the store is missing or empty
     */
    public static List<Document> similaritySearch(String storeId, List<Double> userQueryEmbedding) {
        Map<String, Document> store = storeMap.get(storeId);
        if (MapUtil.isNotEmpty(store)) {
            return store.values()
                    .stream()
                    .map(entry -> new Similarity(entry.getId(),
                            EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
                    .filter(s -> s.score >= similarityThreshold)
                    .sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
                    .limit(topK)
                    .map(s -> store.get(s.key))
                    // A document may be removed between scoring and lookup; never return null elements.
                    .filter(Objects::nonNull)
                    .toList();
        }
        return new ArrayList<>();
    }

    /**
     * Pairs a document id with its similarity score.
     */
    @Data
    public static class Similarity {
        private String key;
        private double score;

        public Similarity(String key, double score) {
            this.key = key;
            this.score = score;
        }
    }

    /**
     * Vector math helpers used for similarity scoring.
     */
    public static class EmbeddingMath {

        private EmbeddingMath() {
            throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
        }

        /**
         * Computes the cosine similarity of two vectors.
         *
         * @param vectorX first vector
         * @param vectorY second vector
         * @return dot product divided by the product of both Euclidean lengths
         * @throws RuntimeException         if either vector is null
         * @throws IllegalArgumentException if the lengths differ or either vector has zero norm
         */
        public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) {
            if (vectorX == null || vectorY == null) {
                throw new RuntimeException("Vectors must not be null");
            }
            if (vectorX.size() != vectorY.size()) {
                throw new IllegalArgumentException("Vectors lengths must be equal");
            }
            double dotProduct = dotProduct(vectorX, vectorY);
            // norm() yields the squared length, hence the square roots below.
            double normX = norm(vectorX);
            double normY = norm(vectorY);
            if (normX == 0 || normY == 0) {
                throw new IllegalArgumentException("Vectors cannot have zero norm");
            }
            return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
        }

        /**
         * Computes the dot product of two vectors of equal length.
         *
         * @param vectorX first vector
         * @param vectorY second vector
         * @return sum of the element-wise products
         * @throws IllegalArgumentException if the lengths differ
         */
        public static double dotProduct(List<Double> vectorX, List<Double> vectorY) {
            if (vectorX.size() != vectorY.size()) {
                throw new IllegalArgumentException("Vectors lengths must be equal");
            }
            double result = 0;
            for (int i = 0; i < vectorX.size(); ++i) {
                result += vectorX.get(i) * vectorY.get(i);
            }
            return result;
        }

        /**
         * Returns the squared Euclidean norm of a vector (its dot product with itself).
         *
         * @param vector the vector
         * @return the sum of the squared elements
         */
        public static double norm(List<Double> vector) {
            return dotProduct(vector, vector);
        }
    }
}

@ -110,7 +110,6 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<build> <build>

@ -5,18 +5,27 @@ import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONObject; import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException; import com.supervision.exception.BusinessException;
import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.ConfigPhysicalTool; import com.supervision.model.ConfigPhysicalTool;
import com.supervision.service.AskTemplateQuestionLibraryService;
import com.supervision.service.ConfigPhysicalToolService; import com.supervision.service.ConfigPhysicalToolService;
import com.supervision.util.MinioUtil; import com.supervision.util.MinioUtil;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@RestController @RestController
@RequestMapping("test") @RequestMapping("test")
@RequiredArgsConstructor @RequiredArgsConstructor
@ -24,6 +33,32 @@ public class TestController {
private final ConfigPhysicalToolService configPhysicalToolService; private final ConfigPhysicalToolService configPhysicalToolService;
private final RedisVectorStore redisVectorStore;
private final AskTemplateQuestionLibraryService askTemplateQuestionLibraryService;
/**
 * Loads every question-library entry into the Redis vector store.
 * The description and each question of an entry are stored as separate documents,
 * all tagged with the metadata type "1" and medicalId "222".
 */
@GetMapping("testRedisVectorStore")
public void testRedisVectorStore() {
    for (AskTemplateQuestionLibrary library : askTemplateQuestionLibraryService.list()) {
        // The description goes in first, as its own document.
        Document descriptionDocument =
                new Document(library.getDescription(), Map.of("type", "1", "medicalId", "222"));
        redisVectorStore.add(List.of(descriptionDocument));

        // Then one document, and one add call, per question.
        library.getQuestion().forEach(questionText -> {
            Document questionDocument = new Document(questionText, Map.of("type", "1", "medicalId", "222"));
            redisVectorStore.add(List.of(questionDocument));
        });
    }
}
/**
 * Runs a similarity search for the given question against the Redis vector store.
 *
 * @param question text to search for
 * @return up to five hits, each a map with the document "content" and "id"
 */
@GetMapping("testQuestion")
public List<Map<String, String>> testQuestion(String question) {
    SearchRequest searchRequest = SearchRequest.query(question).withTopK(5);
    List<Map<String, String>> hits = new ArrayList<>();
    for (Document hit : redisVectorStore.similaritySearch(searchRequest)) {
        hits.add(Map.of("content", hit.getContent(), "id", hit.getId()));
    }
    return hits;
}
@GetMapping("testExpireTime") @GetMapping("testExpireTime")
public String testExpireTime() { public String testExpireTime() {
return "OK"; return "OK";
@ -36,18 +71,19 @@ public class TestController {
/** /**
* *
* @param key ID, *
* @param key ID,
* @param token ,UUID * @param token ,UUID
* @return * @return
*/ */
@GetMapping("queryRoomId") @GetMapping("queryRoomId")
public String queryRoomId(String key,String token){ public String queryRoomId(String key, String token) {
Map<String, Object> param = new HashMap<>(); Map<String, Object> param = new HashMap<>();
param.put("token",token); param.put("token", token);
param.put("key",key); param.put("key", key);
String s = HttpUtil.get("https://digital-human.jd.com/getRoomId", param); String s = HttpUtil.get("https://digital-human.jd.com/getRoomId", param);
JSONObject entries = JSONUtil.parseObj(s); JSONObject entries = JSONUtil.parseObj(s);
if (200 != entries.getInt("code")){ if (200 != entries.getInt("code")) {
throw new BusinessException(entries.getStr("data")); throw new BusinessException(entries.getStr("data"));
} }
return entries.getStr("data"); return entries.getStr("data");
@ -56,7 +92,8 @@ public class TestController {
/** /**
* *
* @param text *
* @param text
* @param roomId ID * @param roomId ID
*/ */
@GetMapping("shuZiRenSend") @GetMapping("shuZiRenSend")
@ -102,5 +139,4 @@ public class TestController {
} }
} }

@ -12,6 +12,8 @@ server:
# 是否分配的直接内存 # 是否分配的直接内存
direct-buffers: true direct-buffers: true
spring: spring:
main:
allow-bean-definition-overriding: true
servlet: servlet:
multipart: multipart:
max-file-size: 100MB max-file-size: 100MB
@ -39,6 +41,12 @@ spring:
log-slow-sql: true # 是否开启 慢SQL 记录默认false log-slow-sql: true # 是否开启 慢SQL 记录默认false
slow-sql-millis: 5000 # 慢 SQL 的标准,默认 3000单位毫秒 slow-sql-millis: 5000 # 慢 SQL 的标准,默认 3000单位毫秒
merge-sql: false # 合并多个连接池的监控数据默认false merge-sql: false # 合并多个连接池的监控数据默认false
ai:
vectorstore:
redis:
uri: redis://:123456@192.168.10.137:6380
index: 1
prefix: 'vp:vector:'
mybatis-plus: mybatis-plus:
mapper-locations: classpath*:mapper/**/*.xml mapper-locations: classpath*:mapper/**/*.xml

Loading…
Cancel
Save