RAG代码提交

dev_1.0.0^2
liu 8 months ago
parent 4fb5164193
commit 616678229f

@ -35,11 +35,6 @@
<!-- <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>-->
<!-- </dependency>-->
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-elasticsearch-store</artifactId>

@ -1,16 +0,0 @@
package som.supervision.knowsub.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
/**
 * Configuration properties bound from the {@code vector.redis} prefix.
 *
 * <p>{@code RedisVectorStoreConfig} reads these values when it builds the
 * {@code RedisVectorStore} bean. Accessors are generated by Lombok's {@code @Data}.
 */
@Data
@ConfigurationProperties(prefix = "vector.redis")
public class RedisVectorProperties {
// Connection URI, passed to the vector store builder via withURI(...).
private String uri;
// Index name, passed to the vector store builder via withIndexName(...).
private String indexName;
// Key prefix, passed to the vector store builder via withPrefix(...).
private String prefix;
}

@ -1,35 +0,0 @@
package som.supervision.knowsub.config;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.Assert;
/**
 * Spring configuration that wires the HTTP embedding client and the Redis vector store.
 *
 * <p>Both beans are conditional: the embedding client requires {@code embedding.url} and
 * the vector store requires {@code vector.redis.uri}.
 */
@Configuration
@EnableConfigurationProperties({RedisVectorProperties.class, EmbeddingProperties.class})
public class RedisVectorStoreConfig {

    /**
     * Creates the embedding client that calls the configured embedding endpoint.
     *
     * @param embeddingProperties bound embedding properties; its URL must contain text
     * @return a client targeting {@code embeddingProperties.getUrl()}
     * @throws IllegalArgumentException if the URL is null, empty or blank
     */
    @Bean
    @ConditionalOnProperty(prefix = "embedding", name = "url")
    public VectorEmbeddingClient vectorEmbeddingClient(EmbeddingProperties embeddingProperties) {
        // hasText rather than notNull: a property that is present but blank would
        // otherwise slip through and only fail later on the first HTTP call.
        Assert.hasText(embeddingProperties.getUrl(), "配置文件embedding:url未找到");
        return new VectorEmbeddingClient(embeddingProperties.getUrl());
    }

    /**
     * Creates the Redis vector store backed by the given embedding client.
     *
     * @param vectorEmbeddingClient client used by the store to embed documents and queries
     * @param redisVectorProperties bound {@code vector.redis} properties
     * @return the configured {@code RedisVectorStore}
     * @throws IllegalArgumentException if the Redis URI is null, empty or blank
     */
    @Bean
    @ConditionalOnProperty(prefix = "vector.redis", name = "uri")
    public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient, RedisVectorProperties redisVectorProperties) {
        Assert.hasText(redisVectorProperties.getUri(), "配置文件vector.redis.uri未找到");
        var builder = RedisVectorStore.RedisVectorStoreConfig.builder()
                .withURI(redisVectorProperties.getUri())
                // Metadata fields that search filters may use. Original author's warning:
                // the field's data type must be a string, otherwise queries find nothing.
                .withMetadataFields(
                        RedisVectorStore.MetadataField.tag("fileName"));

        // prefix and indexName are optional in the properties class; only override the
        // builder when a value was actually configured instead of pushing null into it.
        String prefix = redisVectorProperties.getPrefix();
        if (prefix != null && !prefix.isBlank()) {
            builder.withPrefix(prefix);
        }
        String indexName = redisVectorProperties.getIndexName();
        if (indexName != null && !indexName.isBlank()) {
            builder.withIndexName(indexName);
        }
        return new RedisVectorStore(builder.build(), vectorEmbeddingClient);
    }
}

@ -1,57 +0,0 @@
package som.supervision.knowsub.config;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Embedding client that obtains vectors from an external HTTP embedding service.
 *
 * <p>Each input text is POSTed to {@code embeddingUrl} as the JSON object
 * {@code {"text": ...}}; the response is parsed into {@link EmbeddingData} and its
 * {@code embeddings} field is used as the vector.
 */
@Slf4j
public class VectorEmbeddingClient extends AbstractEmbeddingClient {

    /** Endpoint every embedding request is POSTed to. */
    private final String embeddingUrl;

    /**
     * @param embeddingUrl URL of the embedding service endpoint
     */
    public VectorEmbeddingClient(String embeddingUrl) {
        this.embeddingUrl = embeddingUrl;
    }

    /**
     * Embeds the content of a single document.
     *
     * @param document document whose content is embedded
     * @return the embedding vector for the document content
     * @throws IllegalStateException if the embedding service returns no vector
     */
    @Override
    public List<Double> embed(Document document) {
        EmbeddingRequest request = new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY);
        // call(...) returns exactly one result per instruction, so index 0 always exists here.
        return this.call(request).getResults().get(0).getOutput();
    }

    /**
     * Embeds every instruction of the request, one HTTP call per text, in order.
     *
     * @param request request carrying the texts to embed; must contain at least one
     * @return a response whose embeddings are indexed in instruction order, starting at 0
     * @throws IllegalArgumentException if the request has no instructions
     * @throws IllegalStateException if the service response contains no embedding vector
     */
    @Override
    public EmbeddingResponse call(EmbeddingRequest request) {
        Assert.notEmpty(request.getInstructions(), "At least one text is required!");
        List<Embedding> embeddings = new ArrayList<>(request.getInstructions().size());
        for (String inputContent : request.getInstructions()) {
            // Convert the input text into vector data through the embedding service.
            String responseBody = HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent)));
            EmbeddingData data = JSONUtil.toBean(responseBody, EmbeddingData.class);
            // Fail fast: a missing vector would otherwise travel on as a null embedding
            // and only surface later, far away from the actual cause.
            if (data == null || data.embeddings == null || data.embeddings.isEmpty()) {
                throw new IllegalStateException(
                        "Embedding service at " + embeddingUrl + " returned no embeddings");
            }
            // The current list size is the zero-based position of this instruction.
            embeddings.add(new Embedding(data.embeddings, embeddings.size()));
        }
        return new EmbeddingResponse(List.copyOf(embeddings));
    }

    /** Shape of the embedding service's JSON response as consumed by this client. */
    @Data
    private static class EmbeddingData {
        private List<Double> embeddings;
    }
}

@ -24,9 +24,4 @@ public class KnowledgeEtlController {
knowledgeEtlService.knowledgeEtl(files);
}
/**
 * POST endpoint {@code redisKnowledgeEtl}: hands the uploaded files to
 * {@code knowledgeEtlService.redisKnowledgeEtl} and returns no body.
 *
 * @param files uploaded files bound from the {@code files} request parameter
 */
@Operation(summary = "Redis对知识进行ETL")
@PostMapping("redisKnowledgeEtl")
public void redisKnowledgeEtl(@RequestParam("files") MultipartFile[] files) {
knowledgeEtlService.redisKnowledgeEtl(files);
}
}

@ -9,5 +9,4 @@ public interface KnowledgeEtlService {
/**
 * Runs the knowledge ETL for the given uploaded files.
 * NOTE(review): presumably targets the Elasticsearch vector store, given the
 * Redis-specific sibling method; confirm against the implementation.
 *
 * @param files uploaded files to process
 */
void knowledgeEtl(MultipartFile[] files);
/**
 * Runs the knowledge ETL for the given uploaded files, storing the result in the
 * Redis vector store.
 *
 * @param files uploaded files to process
 */
void redisKnowledgeEtl(MultipartFile[] files);
}

@ -6,13 +6,11 @@ import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.core.io.InputStreamResource;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import som.supervision.knowsub.service.KnowledgeEtlService;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@ -24,8 +22,6 @@ public class KnowledgeEtlServiceImpl implements KnowledgeEtlService {
private final ElasticsearchVectorStore elasticsearchVectorStore;
private final RedisVectorStore redisVectorStore;
/**
* <a href="https://zhuanlan.zhihu.com/p/703705663"/>
*
@ -67,33 +63,4 @@ public class KnowledgeEtlServiceImpl implements KnowledgeEtlService {
log.info("文件处理结束");
}
/**
 * Extracts, chunks and stores each uploaded file in the Redis vector store.
 *
 * <p>Per file: the content is read with Tika, split into chunks with a
 * {@code TokenTextSplitter}, every chunk is tagged with the original file name under the
 * {@code fileName} metadata key, and the chunks are written to {@code redisVectorStore}.
 * Processing is best-effort: a failure on one file is logged and the loop moves on.
 *
 * @param files uploaded files to process
 */
@Override
public void redisKnowledgeEtl(MultipartFile[] files) {
    int processed = 0;
    for (MultipartFile file : files) {
        String fileName = file.getOriginalFilename();
        // Step 1: extract the file content with Tika.
        log.info("{} 进行内容切分", fileName);
        // try-with-resources: the upload stream is closed whether the file succeeds or fails.
        try (InputStream inputStream = file.getInputStream()) {
            TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(inputStream));
            List<Document> documents = tikaDocumentReader.read();
            log.info("{} 切分完成,开始进行chunk分割", fileName);
            // Step 2: split the extracted documents into chunks.
            TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(500, 250, 10, 1000, true);
            List<Document> chunks = tokenTextSplitter.apply(documents);
            for (Document chunk : chunks) {
                // fileName is the metadata field the vector store is configured to filter on.
                chunk.getMetadata().put("fileName", fileName);
            }
            log.info("{} 切分完成,开始进行保存到向量库中", fileName);
            // Step 3: persist the chunks into the vector store.
            redisVectorStore.accept(chunks);
            log.info("{} 保存完成", fileName);
        } catch (Exception e) {
            // Deliberate best-effort: one bad file must not abort the whole batch.
            log.error("{}文件处理失败", fileName, e);
        }
        processed++;
        // This file is already done, so the remaining count excludes it.
        log.info("处理第{}个文件,剩余:{}个", processed, files.length - processed);
    }
    log.info("文件处理结束");
}
}

@ -45,12 +45,6 @@
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>

@ -1,16 +0,0 @@
package com.supervision.knowsub.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
/**
 * Configuration properties bound from the {@code vector.redis} prefix.
 *
 * <p>{@code RedisVectorStoreConfig} reads these values when it builds the
 * {@code RedisVectorStore} bean. Accessors are generated by Lombok's {@code @Data}.
 */
@Data
@ConfigurationProperties(prefix = "vector.redis")
public class RedisVectorProperties {
// Connection URI, passed to the vector store builder via withURI(...).
private String uri;
// Index name, passed to the vector store builder via withIndexName(...).
private String indexName;
// Key prefix, passed to the vector store builder via withPrefix(...).
private String prefix;
}

@ -1,35 +0,0 @@
package com.supervision.knowsub.config;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.Assert;
/**
 * Spring configuration that wires the HTTP embedding client and the Redis vector store.
 *
 * <p>Both beans are conditional: the embedding client requires {@code embedding.url} and
 * the vector store requires {@code vector.redis.uri}.
 */
@Configuration
@EnableConfigurationProperties({RedisVectorProperties.class, EmbeddingProperties.class})
public class RedisVectorStoreConfig {

    /**
     * Creates the embedding client that calls the configured embedding endpoint.
     *
     * @param embeddingProperties bound embedding properties; its URL must contain text
     * @return a client targeting {@code embeddingProperties.getUrl()}
     * @throws IllegalArgumentException if the URL is null, empty or blank
     */
    @Bean
    @ConditionalOnProperty(prefix = "embedding", name = "url")
    public VectorEmbeddingClient vectorEmbeddingClient(EmbeddingProperties embeddingProperties) {
        // hasText rather than notNull: a property that is present but blank would
        // otherwise slip through and only fail later on the first HTTP call.
        Assert.hasText(embeddingProperties.getUrl(), "配置文件embedding:url未找到");
        return new VectorEmbeddingClient(embeddingProperties.getUrl());
    }

    /**
     * Creates the Redis vector store backed by the given embedding client.
     *
     * @param vectorEmbeddingClient client used by the store to embed documents and queries
     * @param redisVectorProperties bound {@code vector.redis} properties
     * @return the configured {@code RedisVectorStore}
     * @throws IllegalArgumentException if the Redis URI is null, empty or blank
     */
    @Bean
    @ConditionalOnProperty(prefix = "vector.redis", name = "uri")
    public RedisVectorStore redisVectorStore(VectorEmbeddingClient vectorEmbeddingClient, RedisVectorProperties redisVectorProperties) {
        Assert.hasText(redisVectorProperties.getUri(), "配置文件vector.redis.uri未找到");
        var builder = RedisVectorStore.RedisVectorStoreConfig.builder()
                .withURI(redisVectorProperties.getUri())
                // Metadata fields that search filters may use. Original author's warning:
                // the field's data type must be a string, otherwise queries find nothing.
                .withMetadataFields(
                        RedisVectorStore.MetadataField.tag("fileName"));

        // prefix and indexName are optional in the properties class; only override the
        // builder when a value was actually configured instead of pushing null into it.
        String prefix = redisVectorProperties.getPrefix();
        if (prefix != null && !prefix.isBlank()) {
            builder.withPrefix(prefix);
        }
        String indexName = redisVectorProperties.getIndexName();
        if (indexName != null && !indexName.isBlank()) {
            builder.withIndexName(indexName);
        }
        return new RedisVectorStore(builder.build(), vectorEmbeddingClient);
    }
}

@ -1,57 +0,0 @@
package com.supervision.knowsub.config;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Embedding client that obtains vectors from an external HTTP embedding service.
 *
 * <p>Each input text is POSTed to {@code embeddingUrl} as the JSON object
 * {@code {"text": ...}}; the response is parsed into {@link EmbeddingData} and its
 * {@code embeddings} field is used as the vector.
 */
@Slf4j
public class VectorEmbeddingClient extends AbstractEmbeddingClient {

    /** Endpoint every embedding request is POSTed to. */
    private final String embeddingUrl;

    /**
     * @param embeddingUrl URL of the embedding service endpoint
     */
    public VectorEmbeddingClient(String embeddingUrl) {
        this.embeddingUrl = embeddingUrl;
    }

    /**
     * Embeds the content of a single document.
     *
     * @param document document whose content is embedded
     * @return the embedding vector for the document content
     * @throws IllegalStateException if the embedding service returns no vector
     */
    @Override
    public List<Double> embed(Document document) {
        EmbeddingRequest request = new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY);
        // call(...) returns exactly one result per instruction, so index 0 always exists here.
        return this.call(request).getResults().get(0).getOutput();
    }

    /**
     * Embeds every instruction of the request, one HTTP call per text, in order.
     *
     * @param request request carrying the texts to embed; must contain at least one
     * @return a response whose embeddings are indexed in instruction order, starting at 0
     * @throws IllegalArgumentException if the request has no instructions
     * @throws IllegalStateException if the service response contains no embedding vector
     */
    @Override
    public EmbeddingResponse call(EmbeddingRequest request) {
        Assert.notEmpty(request.getInstructions(), "At least one text is required!");
        List<Embedding> embeddings = new ArrayList<>(request.getInstructions().size());
        for (String inputContent : request.getInstructions()) {
            // Convert the input text into vector data through the embedding service.
            String responseBody = HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent)));
            EmbeddingData data = JSONUtil.toBean(responseBody, EmbeddingData.class);
            // Fail fast: a missing vector would otherwise travel on as a null embedding
            // and only surface later, far away from the actual cause.
            if (data == null || data.embeddings == null || data.embeddings.isEmpty()) {
                throw new IllegalStateException(
                        "Embedding service at " + embeddingUrl + " returned no embeddings");
            }
            // The current list size is the zero-based position of this instruction.
            embeddings.add(new Embedding(data.embeddings, embeddings.size()));
        }
        return new EmbeddingResponse(List.copyOf(embeddings));
    }

    /** Shape of the embedding service's JSON response as consumed by this client. */
    @Data
    private static class EmbeddingData {
        private List<Double> embeddings;
    }
}

@ -23,9 +23,4 @@ public class RagController {
return ragService.esAsk(question);
}
/**
 * GET endpoint {@code redisAsk}: answers a question using the Redis-backed RAG flow.
 *
 * <p>The service result is returned to the caller; previously it was computed and then
 * discarded, so the endpoint always produced an empty response.
 *
 * @param question the user's question
 * @return the answer together with the names of the source files it was drawn from
 */
@Operation(summary = "问答")
@GetMapping("redisAsk")
public RagResVO redisAsk(String question) {
    return ragService.redisAsk(question);
}
}

@ -38,19 +38,4 @@ public class TestController {
writer.flush();
}
/**
 * Batch test helper: asks every question found in a local Excel sheet through
 * {@code ragService.redisAsk} and writes the answer and source file names back
 * into the same workbook.
 *
 * <p>NOTE(review): the workbook path is a hard-coded absolute path on a developer
 * machine, and neither the reader nor the writer is closed here; confirm this
 * endpoint is meant for local experiments only.
 */
@GetMapping("redisTest")
public void redisTest() {
    ExcelReader excelReader = ExcelUtil.getReader("/Users/flevance/Desktop/深圳人社POC/question.xlsx", "Sheet2");
    // Presumably column index 3, starting from row index 1 (skipping a header row).
    List<Object> questions = excelReader.readColumn(3, 1);
    ExcelWriter excelWriter = excelReader.getWriter();
    int total = questions.size();
    // Sheet row receiving the current result; also the 1-based count of handled questions.
    int row = 1;
    for (Object question : questions) {
        RagResVO result = ragService.redisAsk(question.toString());
        excelWriter.writeCellValue(5, row, result.getAnswer());
        excelWriter.writeCellValue(6, row, JSONUtil.toJsonStr(result.getFileName()));
        log.info("第{}条数据写入成功,剩余{}条", row, total - row);
        row++;
    }
    excelWriter.flush();
}
}

@ -6,5 +6,4 @@ public interface RagService {
/**
 * Answers a question using documents retrieved from the Elasticsearch vector store.
 *
 * @param question the user's question
 * @return the answer together with the names of the source files it was drawn from
 */
RagResVO esAsk(String question);
/**
 * Answers a question using documents retrieved from the Redis vector store.
 *
 * @param question the user's question
 * @return the answer together with the names of the source files it was drawn from
 */
RagResVO redisAsk(String question);
}

@ -13,7 +13,6 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.stereotype.Service;
@ -30,8 +29,6 @@ public class RagServiceImpl implements RagService {
private final ElasticsearchVectorStore elasticsearchVectorStore;
private final RedisVectorStore redisVectorStore;
// private final OllamaChatClient chatClient ;
private final OllamaChatModel ollamaChatModel;
@ -71,33 +68,36 @@ public class RagServiceImpl implements RagService {
""";
/**
 * Prompt template sent to the chat model as a single user message.
 *
 * <p>Callers fill it with {@code StrUtil.format} and a map holding the keys
 * {@code context} (retrieved document text) and {@code question} (the user's question).
 * NOTE(review): the template contains both double-brace and single-brace placeholder
 * forms; presumably only the single-brace form is substituted by that call. Confirm.
 */
public static final String langChainChatPrompt = """
<> "根据已知信息无法回答该问题"使 </>
<>{{ context }}</>
<>{{ question }}</>
"根据已知信息无法回答该问题"
!!!
使!
<>{context}</>
<>{question}</>
""";
@Override
public RagResVO esAsk(String question) {
log.info("检索相关文档");
List<Document> similarDocuments = elasticsearchVectorStore.similaritySearch(SearchRequest.query(question).withTopK(10).withSimilarityThreshold(0.5));
List<Document> similarDocuments = elasticsearchVectorStore.similaritySearch(SearchRequest.query(question).withTopK(10));
Set<String> fileNameList = new HashSet<>();
for (Document similarDocument : similarDocuments) {
fileNameList.add(String.valueOf(similarDocument.getMetadata().get("fileName")));
}
log.info("找到:{}条相关文档", similarDocuments.size());
// 构建系统消息
// String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
// SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt1);
// Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", relevantDocument));
// // 构建用户消息
// UserMessage userMessage = new UserMessage(question);
// Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
// 构建系统消息
String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
String format = StrUtil.format(langChainChatPrompt, Map.of("context", relevantDocument, "question", question));
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt1);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", relevantDocument));
// 构建用户消息
UserMessage userMessage = new UserMessage(question);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
Prompt prompt = new Prompt(new UserMessage(format));
// 构建系统消息
// String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
// String format = StrUtil.format(langChainChatPrompt, Map.of("context", relevantDocument, "question", question));
//
// Prompt prompt = new Prompt(new UserMessage(format));
log.info("开始询问GPT问题");
ChatResponse call = ollamaChatModel.call(prompt);
@ -110,38 +110,7 @@ public class RagServiceImpl implements RagService {
return ragResVO;
}
/**
 * Answers a question with retrieval-augmented generation over the Redis vector store.
 *
 * <p>Retrieves up to 10 documents with a similarity threshold of 0.5, joins their
 * content into the prompt context, sends the filled {@code langChainChatPrompt} to the
 * Ollama chat model as one user message, and returns the model output together with
 * the {@code fileName} metadata of the retrieved documents.
 *
 * @param question the user's question
 * @return the model's answer and the set of source file names
 */
@Override
public RagResVO redisAsk(String question) {
    log.info("检索相关文档");
    SearchRequest searchRequest = SearchRequest.query(question).withTopK(10).withSimilarityThreshold(0.5);
    List<Document> hits = redisVectorStore.similaritySearch(searchRequest);
    // Distinct source file names; a missing fileName entry becomes the string "null".
    Set<String> sourceFiles = hits.stream()
            .map(hit -> String.valueOf(hit.getMetadata().get("fileName")))
            .collect(Collectors.toCollection(HashSet::new));
    log.info("找到:{}条相关文档", hits.size());

    // Build the single user message from the retrieved context and the question.
    String context = hits.stream()
            .map(Document::getContent)
            .collect(Collectors.joining("\n"));
    String userText = StrUtil.format(langChainChatPrompt, Map.of("context", context, "question", question));
    Prompt prompt = new Prompt(new UserMessage(userText));

    log.info("开始询问GPT问题");
    ChatResponse response = ollamaChatModel.call(prompt);
    log.info("AI responded.");

    RagResVO result = new RagResVO();
    result.setAnswer(response.getResult().getOutput().getContent());
    result.setFileName(sourceFiles);
    return result;
}
}

Loading…
Cancel
Save