RAG代码提交

dev_1.0.0^2
liu 8 months ago
parent 665ee4af13
commit 4fb5164193

@ -35,7 +35,10 @@
<!-- <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>-->
<!-- </dependency>-->
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>

@ -0,0 +1,16 @@
package som.supervision.knowsub.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
@Data
@ConfigurationProperties(prefix = "vector.redis")
public class RedisVectorProperties {
private String uri;
private String indexName;
private String prefix;
}

@ -0,0 +1,35 @@
package som.supervision.knowsub.config;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.Assert;
@Configuration
@EnableConfigurationProperties({RedisVectorProperties.class, EmbeddingProperties.class})
public class RedisVectorStoreConfig {
@Bean
@ConditionalOnProperty(prefix = "embedding", name = "url")
public VectorEmbeddingClient vectorEmbeddingClient(EmbeddingProperties embeddingProperties) {
Assert.notNull(embeddingProperties.getUrl(), "配置文件embedding:url未找到");
return new VectorEmbeddingClient(embeddingProperties.getUrl());
}
@Bean
@ConditionalOnProperty(prefix = "vector.redis", name = "uri")
public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient, RedisVectorProperties redisVectorProperties) {
Assert.notNull(redisVectorProperties.getUri(), "配置文件vector.redis.uri未找到");
RedisVectorStore.RedisVectorStoreConfig config = RedisVectorStore.RedisVectorStoreConfig.builder()
.withURI(redisVectorProperties.getUri())
.withPrefix(redisVectorProperties.getPrefix())
.withIndexName(redisVectorProperties.getIndexName())
// 定义搜索过滤器使用的元数据字段(!!!!!!!!千万重要,数据类型一定要用字符串,否则会导致查询不到!!!!!!!!)
.withMetadataFields(
RedisVectorStore.MetadataField.tag("fileName"))
.build();
return new RedisVectorStore(config, vectorEmbeddingClient);
}
}

@ -0,0 +1,57 @@
package som.supervision.knowsub.config;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
public class VectorEmbeddingClient extends AbstractEmbeddingClient {
private final String embeddingUrl;
public VectorEmbeddingClient(String embeddingUrl) {
this.embeddingUrl = embeddingUrl;
}
@Override
public List<Double> embed(Document document) {
List<List<Double>> list = this.call(new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY))
.getResults()
.stream()
.map(Embedding::getOutput)
.toList();
return list.iterator().next();
}
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
List<List<Double>> embeddingList = new ArrayList<>();
for (String inputContent : request.getInstructions()) {
// 这里需要吧inputContent转化为向量数据
String post = HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent)));
EmbeddingData bean = JSONUtil.toBean(post, EmbeddingData.class);
embeddingList.add(bean.embeddings);
}
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = embeddingList.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);
}
@Data
private static class EmbeddingData {
private List<Double> embeddings;
}
}

@ -23,4 +23,10 @@ public class KnowledgeEtlController {
public void knowledgeEtl(@RequestParam("files") MultipartFile[] files) {
knowledgeEtlService.knowledgeEtl(files);
}
@Operation(summary = "Redis对知识进行ETL")
@PostMapping("redisKnowledgeEtl")
public void redisKnowledgeEtl(@RequestParam("files") MultipartFile[] files) {
knowledgeEtlService.redisKnowledgeEtl(files);
}
}

@ -8,4 +8,6 @@ import java.io.IOException;
public interface KnowledgeEtlService {
void knowledgeEtl(MultipartFile[] files);
void redisKnowledgeEtl(MultipartFile[] files);
}

@ -6,6 +6,7 @@ import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.core.io.InputStreamResource;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
@ -14,6 +15,7 @@ import som.supervision.knowsub.service.KnowledgeEtlService;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
@Service
@ -22,6 +24,8 @@ public class KnowledgeEtlServiceImpl implements KnowledgeEtlService {
private final ElasticsearchVectorStore elasticsearchVectorStore;
private final RedisVectorStore redisVectorStore;
/**
* <a href="https://zhuanlan.zhihu.com/p/703705663"/>
*
@ -34,8 +38,11 @@ public class KnowledgeEtlServiceImpl implements KnowledgeEtlService {
List<Document> documents = tikaDocumentReader.read();
log.info("{} 切分完成,开始进行chunk分割", fileName);
// 然后切分为chunk
TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(200, 100, 10, 1000, true);
TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(500, 250, 10, 1000, true);
List<Document> apply = tokenTextSplitter.apply(documents);
for (Document document : apply) {
document.getMetadata().put("fileName", fileName);
}
log.info("{} 切分完成,开始进行保存到向量库中", fileName);
// 保存到向量数据库中
elasticsearchVectorStore.accept(apply);
@ -46,13 +53,47 @@ public class KnowledgeEtlServiceImpl implements KnowledgeEtlService {
@Override
public void knowledgeEtl(MultipartFile[] files) {
AtomicInteger atomicInteger = new AtomicInteger(1);
for (MultipartFile file : files) {
try {
loadFile(file.getInputStream(), file.getOriginalFilename());
} catch (Exception e) {
log.error("{}文件处理失败", file.getOriginalFilename(), e);
}
int andIncrement = atomicInteger.getAndIncrement();
log.info("处理第{}个文件,剩余:{}个", andIncrement, files.length - andIncrement + 1);
}
log.info("文件处理结束");
}
@Override
public void redisKnowledgeEtl(MultipartFile[] files) {
AtomicInteger atomicInteger = new AtomicInteger(1);
for (MultipartFile file : files) {
try {
// 首先使用tika进行文件切分操作
log.info("{} 进行内容切分", file.getOriginalFilename());
TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(file.getInputStream()));
List<Document> documents = tikaDocumentReader.read();
log.info("{} 切分完成,开始进行chunk分割", file.getOriginalFilename());
// 然后切分为chunk
TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(500, 250, 10, 1000, true);
List<Document> apply = tokenTextSplitter.apply(documents);
for (Document document : apply) {
document.getMetadata().put("fileName", file.getOriginalFilename());
}
log.info("{} 切分完成,开始进行保存到向量库中", file.getOriginalFilename());
// 保存到向量数据库中
redisVectorStore.accept(apply);
log.info("{} 保存完成", file.getOriginalFilename());
} catch (Exception e) {
log.error("{}文件处理失败", file.getOriginalFilename(), e);
}
int andIncrement = atomicInteger.getAndIncrement();
log.info("处理第{}个文件,剩余:{}个", andIncrement, files.length - andIncrement + 1);
}
log.info("文件处理结束");
}
}

@ -85,4 +85,10 @@ user:
# uris: http://192.168.10.137:9200
embedding:
url: http://192.168.10.137:8711/embeddings/
url: http://192.168.10.137:8711/embeddings/
vector:
redis:
uri: redis://:123456@192.168.10.137:6380
indexName: 'know-sub-rag-store'
prefix: 'know-sub-rag-store:'

@ -30,6 +30,11 @@
<version>4.5.13</version>
</dependency>
<dependency>
<groupId>org.apache.poi</groupId>
<artifactId>poi-ooxml</artifactId>
</dependency>
<!-- <dependency>-->
<!-- <groupId>io.springboot.ai</groupId>-->
<!-- <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>-->
@ -40,6 +45,11 @@
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-redis</artifactId>
</dependency>
<dependency>

@ -0,0 +1,16 @@
package com.supervision.knowsub.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
@Data
@ConfigurationProperties(prefix = "vector.redis")
public class RedisVectorProperties {
private String uri;
private String indexName;
private String prefix;
}

@ -0,0 +1,35 @@
package com.supervision.knowsub.config;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.Assert;
@Configuration
@EnableConfigurationProperties({RedisVectorProperties.class, EmbeddingProperties.class})
public class RedisVectorStoreConfig {
@Bean
@ConditionalOnProperty(prefix = "embedding", name = "url")
public VectorEmbeddingClient vectorEmbeddingClient(EmbeddingProperties embeddingProperties) {
Assert.notNull(embeddingProperties.getUrl(), "配置文件embedding:url未找到");
return new VectorEmbeddingClient(embeddingProperties.getUrl());
}
@Bean
@ConditionalOnProperty(prefix = "vector.redis", name = "uri")
public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient, RedisVectorProperties redisVectorProperties) {
Assert.notNull(redisVectorProperties.getUri(), "配置文件vector.redis.uri未找到");
RedisVectorStore.RedisVectorStoreConfig config = RedisVectorStore.RedisVectorStoreConfig.builder()
.withURI(redisVectorProperties.getUri())
.withPrefix(redisVectorProperties.getPrefix())
.withIndexName(redisVectorProperties.getIndexName())
// 定义搜索过滤器使用的元数据字段(!!!!!!!!千万重要,数据类型一定要用字符串,否则会导致查询不到!!!!!!!!)
.withMetadataFields(
RedisVectorStore.MetadataField.tag("fileName"))
.build();
return new RedisVectorStore(config, vectorEmbeddingClient);
}
}

@ -0,0 +1,57 @@
package com.supervision.knowsub.config;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
public class VectorEmbeddingClient extends AbstractEmbeddingClient {
private final String embeddingUrl;
public VectorEmbeddingClient(String embeddingUrl) {
this.embeddingUrl = embeddingUrl;
}
@Override
public List<Double> embed(Document document) {
List<List<Double>> list = this.call(new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY))
.getResults()
.stream()
.map(Embedding::getOutput)
.toList();
return list.iterator().next();
}
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
List<List<Double>> embeddingList = new ArrayList<>();
for (String inputContent : request.getInstructions()) {
// 这里需要吧inputContent转化为向量数据
String post = HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent)));
EmbeddingData bean = JSONUtil.toBean(post, EmbeddingData.class);
embeddingList.add(bean.embeddings);
}
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = embeddingList.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);
}
@Data
private static class EmbeddingData {
private List<Double> embeddings;
}
}

@ -1,6 +1,7 @@
package com.supervision.knowsub.controller;
import com.supervision.knowsub.service.RagService;
import com.supervision.knowsub.vo.RagResVO;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
@ -16,12 +17,15 @@ public class RagController {
private final RagService ragService;
@Operation(summary = "问答")
@GetMapping("question")
public String ask(String question) {
return ragService.ask(question);
@GetMapping("esAsk")
public RagResVO esAsk(String question) {
return ragService.esAsk(question);
}
@Operation(summary = "问答")
@GetMapping("redisAsk")
public void redisAsk(String question) {
ragService.redisAsk(question);
}
}

@ -0,0 +1,56 @@
package com.supervision.knowsub.controller;
import cn.hutool.json.JSONUtil;
import cn.hutool.poi.excel.ExcelReader;
import cn.hutool.poi.excel.ExcelUtil;
import cn.hutool.poi.excel.ExcelWriter;
import com.supervision.knowsub.service.RagService;
import com.supervision.knowsub.vo.RagResVO;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@RestController
@RequestMapping("test")
@Slf4j
@RequiredArgsConstructor
public class TestController {
private final RagService ragService;
@GetMapping("esTest")
public void esTest(){
ExcelReader reader = ExcelUtil.getReader("/Users/flevance/Desktop/深圳人社POC/question.xlsx","Sheet2");
List<Object> objects = reader.readColumn(3, 1);
ExcelWriter writer = reader.getWriter();
for (int i = 0; i < objects.size(); i++) {
RagResVO ask = ragService.esAsk(objects.get(i).toString());
writer.writeCellValue(5, i + 1, ask.getAnswer());
writer.writeCellValue(6, i + 1, JSONUtil.toJsonStr(ask.getFileName()));
log.info("第{}条数据写入成功,剩余{}条", i + 1, objects.size() - i - 1);
}
writer.flush();
}
@GetMapping("redisTest")
public void redisTest(){
ExcelReader reader = ExcelUtil.getReader("/Users/flevance/Desktop/深圳人社POC/question.xlsx","Sheet2");
List<Object> objects = reader.readColumn(3, 1);
ExcelWriter writer = reader.getWriter();
for (int i = 0; i < objects.size(); i++) {
RagResVO ask = ragService.redisAsk(objects.get(i).toString());
writer.writeCellValue(5, i + 1, ask.getAnswer());
writer.writeCellValue(6, i + 1, JSONUtil.toJsonStr(ask.getFileName()));
log.info("第{}条数据写入成功,剩余{}条", i + 1, objects.size() - i - 1);
}
writer.flush();
}
}

@ -1,6 +1,10 @@
package com.supervision.knowsub.service;
import com.supervision.knowsub.vo.RagResVO;
public interface RagService {
String ask(String question);
RagResVO esAsk(String question);
RagResVO redisAsk(String question);
}

@ -1,6 +1,8 @@
package com.supervision.knowsub.service.impl;
import cn.hutool.core.util.StrUtil;
import com.supervision.knowsub.service.RagService;
import com.supervision.knowsub.vo.RagResVO;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
@ -11,11 +13,14 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.stereotype.Service;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@ -25,6 +30,8 @@ public class RagServiceImpl implements RagService {
private final ElasticsearchVectorStore elasticsearchVectorStore;
private final RedisVectorStore redisVectorStore;
// private final OllamaChatClient chatClient ;
private final OllamaChatModel ollamaChatModel;
@ -38,7 +45,7 @@ public class RagServiceImpl implements RagService {
:
<context>{context}</context>
""";
@ -55,30 +62,86 @@ public class RagServiceImpl implements RagService {
public static final String systemPrompt1 = """
,:"请注意,具体的政策和流程可能会有所变化,因此建议您咨询当地的人力资源和社会保障部门或访问官方网站以获取最新信息。"!
"根据您提供的信息"!
,:
<context>{context}</context>
""";
public static final String langChainChatPrompt = """
<> "根据已知信息无法回答该问题"使 </>
<>{{ context }}</>
<>{{ question }}</>
""";
@Override
public RagResVO esAsk(String question) {
log.info("检索相关文档");
List<Document> similarDocuments = elasticsearchVectorStore.similaritySearch(SearchRequest.query(question).withTopK(10).withSimilarityThreshold(0.5));
Set<String> fileNameList = new HashSet<>();
for (Document similarDocument : similarDocuments) {
fileNameList.add(String.valueOf(similarDocument.getMetadata().get("fileName")));
}
log.info("找到:{}条相关文档", similarDocuments.size());
// 构建系统消息
// String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
// SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt1);
// Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", relevantDocument));
// // 构建用户消息
// UserMessage userMessage = new UserMessage(question);
// Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
// 构建系统消息
String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
String format = StrUtil.format(langChainChatPrompt, Map.of("context", relevantDocument, "question", question));
Prompt prompt = new Prompt(new UserMessage(format));
log.info("开始询问GPT问题");
ChatResponse call = ollamaChatModel.call(prompt);
log.info("AI responded.");
RagResVO ragResVO = new RagResVO();
ragResVO.setAnswer(call.getResult().getOutput().getContent());
ragResVO.setFileName(fileNameList);
return ragResVO;
}
@Override
public String ask(String question) {
public RagResVO redisAsk(String question) {
log.info("检索相关文档");
List<Document> similarDocuments = elasticsearchVectorStore.similaritySearch(SearchRequest.query(question).withTopK(10));
List<Document> similarDocuments = redisVectorStore.similaritySearch(SearchRequest.query(question).withTopK(10).withSimilarityThreshold(0.5));
Set<String> fileNameList = new HashSet<>();
for (Document similarDocument : similarDocuments) {
fileNameList.add(String.valueOf(similarDocument.getMetadata().get("fileName")));
}
log.info("找到:{}条相关文档", similarDocuments.size());
// // 构建系统消息
// String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
// SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt1);
// Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", relevantDocument));
// // 构建用户消息
// UserMessage userMessage = new UserMessage(question);
// Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
// 构建系统消息
String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt1);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", relevantDocument));
// 构建用户消息
UserMessage userMessage = new UserMessage(question);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
String format = StrUtil.format(langChainChatPrompt, Map.of("context", relevantDocument, "question", question));
Prompt prompt = new Prompt(new UserMessage(format));
log.info("开始询问GPT问题");
ChatResponse call = ollamaChatModel.call(prompt);
log.info("AI responded.");
return call.getResult().getOutput().getContent();
RagResVO ragResVO = new RagResVO();
ragResVO.setAnswer(call.getResult().getOutput().getContent());
ragResVO.setFileName(fileNameList);
return ragResVO;
}
}

@ -0,0 +1,14 @@
package com.supervision.knowsub.vo;
import lombok.Data;
import java.util.List;
import java.util.Set;
@Data
public class RagResVO {
private String answer;
private Set<String> fileName;
}

@ -98,4 +98,10 @@ user:
# uris: http://192.168.10.137:9200
embedding:
url: http://192.168.10.137:8711/embeddings/
url: http://192.168.10.137:8711/embeddings/
vector:
redis:
uri: redis://:123456@192.168.10.137:6380
indexName: 'know-sub-rag-store'
prefix: 'know-sub-rag-store:'
Loading…
Cancel
Save