RAG代码提交

dev_1.0.0^2
liu 10 months ago
parent c5182c9e78
commit 665ee4af13

@ -30,10 +30,12 @@
<version>4.5.13</version> <version>4.5.13</version>
</dependency> </dependency>
<dependency> <!-- <dependency>-->
<groupId>org.springframework.ai</groupId> <!-- <groupId>io.springboot.ai</groupId>-->
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId> <!-- <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>-->
</dependency> <!-- </dependency>-->
<dependency> <dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>

@ -1,7 +1,6 @@
package som.supervision.knowsub.config; package som.supervision.knowsub.config;
import org.elasticsearch.client.RestClient; import org.elasticsearch.client.RestClient;
import org.springframework.ai.autoconfigure.vectorstore.elasticsearch.ElasticsearchVectorStoreProperties;
import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore; import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions; import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions;
@ -24,9 +23,9 @@ public class ElasticsearchVectorStoreConfig {
@Bean @Bean
@ConditionalOnProperty(prefix = "embedding", name = "url") @ConditionalOnProperty(prefix = "embedding", name = "url")
public ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properties,EmbeddingModel embeddingModel, RestClient restClient) { public ElasticsearchVectorStore vectorStore(EmbeddingModel embeddingModel, RestClient restClient) {
ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
options.setIndexName(properties.getIndexName()); options.setIndexName("know-sub-rag-store");
options.setDimensions(1024); options.setDimensions(1024);
return new ElasticsearchVectorStore(options, restClient, embeddingModel, true); return new ElasticsearchVectorStore(options, restClient, embeddingModel, true);
} }

@ -0,0 +1,26 @@
package som.supervision.knowsub.controller;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.apache.ibatis.annotations.Param;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import som.supervision.knowsub.service.KnowledgeEtlService;
import java.io.IOException;
@Tag(name = "知识ETL类")
@RestController
@RequestMapping("etl")
@RequiredArgsConstructor
public class KnowledgeEtlController {
private final KnowledgeEtlService knowledgeEtlService;
@Operation(summary = "对知识进行ETL")
@PostMapping("knowledgeEtl")
public void knowledgeEtl(@RequestParam("files") MultipartFile[] files) {
knowledgeEtlService.knowledgeEtl(files);
}
}

@ -0,0 +1,11 @@
package som.supervision.knowsub.service;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.multipart.MultipartFile;
import java.io.IOException;
public interface KnowledgeEtlService {
void knowledgeEtl(MultipartFile[] files);
}

@ -1,6 +1,5 @@
package com.supervision.knowsub.etl; package som.supervision.knowsub.service.impl;
import com.supervision.knowsub.dto.HtmlContext;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
@ -8,16 +7,18 @@ import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore; import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.InputStreamResource;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Service;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.multipart.MultipartFile;
import som.supervision.knowsub.service.KnowledgeEtlService;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.List; import java.util.List;
@Slf4j @Slf4j
@Component @Service
@RequiredArgsConstructor @RequiredArgsConstructor
public class EtlProcessor { public class KnowledgeEtlServiceImpl implements KnowledgeEtlService {
private final ElasticsearchVectorStore elasticsearchVectorStore; private final ElasticsearchVectorStore elasticsearchVectorStore;
@ -26,23 +27,32 @@ public class EtlProcessor {
* *
* @param inputStream * @param inputStream
*/ */
public void loadFile(InputStream inputStream) { private void loadFile(InputStream inputStream, String fileName) {
// 首先使用tika进行文件切分操作 // 首先使用tika进行文件切分操作
log.info("首先进行内容切分"); log.info("{} 进行内容切分", fileName);
TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(inputStream)); TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(inputStream));
List<Document> documents = tikaDocumentReader.read(); List<Document> documents = tikaDocumentReader.read();
log.info("切分完成,开始进行chunk分割"); log.info("{} 切分完成,开始进行chunk分割", fileName);
// 然后切分为chunk // 然后切分为chunk
TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(); TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(200, 100, 10, 1000, true);
List<Document> apply = tokenTextSplitter.apply(documents); List<Document> apply = tokenTextSplitter.apply(documents);
log.info("切分完成,开始进行保存到向量库中"); log.info("{} 切分完成,开始进行保存到向量库中", fileName);
// 保存到向量数据库中 // 保存到向量数据库中
elasticsearchVectorStore.accept(apply); elasticsearchVectorStore.accept(apply);
log.info("保存完成"); log.info("{} 保存完成", fileName);
} }
public void loadHtml(HtmlContext htmlContext) {
// 使用Html工具进行读取 @Override
public void knowledgeEtl(MultipartFile[] files) {
for (MultipartFile file : files) {
try {
loadFile(file.getInputStream(), file.getOriginalFilename());
} catch (Exception e) {
log.error("{}文件处理失败", file.getOriginalFilename(), e);
}
}
} }
} }

@ -14,10 +14,6 @@ server:
spring: spring:
elasticsearch: elasticsearch:
uris: http://192.168.10.137:9200 uris: http://192.168.10.137:9200
ai:
vectorstore:
elasticsearch:
index-name: know-sub-rag-store
main: main:
allow-bean-definition-overriding: true allow-bean-definition-overriding: true

@ -40,6 +40,8 @@
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId> <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-elasticsearch-store</artifactId> <artifactId>spring-ai-elasticsearch-store</artifactId>

@ -0,0 +1,34 @@
package com.supervision.knowsub.config;
import org.elasticsearch.client.RestClient;
import org.springframework.ai.autoconfigure.vectorstore.elasticsearch.ElasticsearchVectorStoreProperties;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.Assert;
@Configuration
@EnableConfigurationProperties(EmbeddingProperties.class)
public class ElasticsearchVectorStoreConfig {
@Bean
@ConditionalOnProperty(prefix = "embedding", name = "url")
public EmbeddingModel embeddingModel(EmbeddingProperties embeddingProperties) {
Assert.notNull(embeddingProperties.getUrl(), "配置文件embedding:url未找到");
return new VectorEmbeddingModel(embeddingProperties.getUrl());
}
@Bean
@ConditionalOnProperty(prefix = "embedding", name = "url")
public ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properties,EmbeddingModel embeddingModel, RestClient restClient) {
ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
options.setIndexName(properties.getIndexName());
options.setDimensions(1024);
return new ElasticsearchVectorStore(options, restClient, embeddingModel, true);
}
}

@ -0,0 +1,12 @@
package com.supervision.knowsub.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
@Data
@ConfigurationProperties(prefix = "embedding")
public class EmbeddingProperties {
private String url;
}

@ -0,0 +1,57 @@
package com.supervision.knowsub.config;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
public class VectorEmbeddingModel implements EmbeddingModel {
private final String embeddingUrl;
public VectorEmbeddingModel(String embeddingUrl) {
this.embeddingUrl = embeddingUrl;
}
@Override
public List<Double> embed(Document document) {
List<List<Double>> list = this.call(new EmbeddingRequest(List.of(document.getContent()), EmbeddingOptions.EMPTY))
.getResults()
.stream()
.map(Embedding::getOutput)
.toList();
return list.iterator().next();
}
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
List<List<Double>> embeddingList = new ArrayList<>();
for (String inputContent : request.getInstructions()) {
// 这里需要吧inputContent转化为向量数据
String post = HttpUtil.post(embeddingUrl, JSONUtil.toJsonStr(Map.of("text", inputContent)));
EmbeddingData bean = JSONUtil.toBean(post, EmbeddingData.class);
embeddingList.add(bean.embeddings);
}
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = embeddingList.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);
}
@Data
private static class EmbeddingData {
private List<Double> embeddings;
}
}

@ -1,27 +0,0 @@
package com.supervision.knowsub.controller;
import com.supervision.knowsub.dto.HtmlContext;
import com.supervision.knowsub.etl.EtlProcessor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.io.IOException;
@RestController
@RequestMapping("etl")
public class EtlController {
@Autowired
private EtlProcessor etlProcessor;
@PostMapping("testLoadFile")
public void testLoadFile(@RequestParam(name = "file") MultipartFile file) throws IOException {
etlProcessor.loadFile(file.getInputStream());
}
@PostMapping("testLoadHtml")
public void testLoadHtml(@RequestBody HtmlContext htmlContext){
etlProcessor.loadHtml(htmlContext);
}
}

@ -1,9 +0,0 @@
package com.supervision.knowsub.dto;
import lombok.Data;
@Data
public class HtmlContext {
private String htmlContext;
}

@ -1,17 +1,17 @@
package com.supervision.knowsub.service.impl; package com.supervision.knowsub.service.impl;
import com.supervision.knowsub.service.RagService; import com.supervision.knowsub.service.RagService;
import com.supervision.knowsub.util.SpringBeanUtil;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore; import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
@ -25,14 +25,18 @@ public class RagServiceImpl implements RagService {
private final ElasticsearchVectorStore elasticsearchVectorStore; private final ElasticsearchVectorStore elasticsearchVectorStore;
private static final OllamaChatClient chatClient = SpringBeanUtil.getBean(OllamaChatClient.class); // private final OllamaChatClient chatClient ;
private final OllamaChatModel ollamaChatModel;
private static final String springDemoSystemPrompt = """ private static final String springDemoSystemPrompt = """
使 使
: :
@ -40,31 +44,41 @@ public class RagServiceImpl implements RagService {
"""; """;
private static final String systemPrompt = """ private static final String systemPrompt = """
使"我不知道" 使"我不知道"
!"感谢你的提问". !"感谢你的提问".
::
::?:
: :
<context>{context}</context> <context>{context}</context>
Question: {input} """;
public static final String systemPrompt1 = """
,:"请注意,具体的政策和流程可能会有所变化,因此建议您咨询当地的人力资源和社会保障部门或访问官方网站以获取最新信息。"!
"根据您提供的信息"!
,:
<context>{context}</context>
"""; """;
@Override @Override
public String ask(String question) { public String ask(String question) {
log.info("检索相关文档"); log.info("检索相关文档");
List<Document> similarDocuments = elasticsearchVectorStore.similaritySearch(question); List<Document> similarDocuments = elasticsearchVectorStore.similaritySearch(SearchRequest.query(question).withTopK(10));
log.info("找到:{}条相关文档", similarDocuments.size()); log.info("找到:{}条相关文档", similarDocuments.size());
// 构建系统消息 // 构建系统消息
String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n")); String relevantDocument = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(springDemoSystemPrompt); SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt1);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", relevantDocument)); Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", relevantDocument));
// 构建用户消息 // 构建用户消息
UserMessage userMessage = new UserMessage(question); UserMessage userMessage = new UserMessage(question);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
log.info("Asking AI model to reply to question."); log.info("开始询问GPT问题");
ChatResponse chatResponse = chatClient.call(prompt); ChatResponse call = ollamaChatModel.call(prompt);
log.info("AI responded."); log.info("AI responded.");
return chatResponse.getResult().getOutput().getContent(); return call.getResult().getOutput().getContent();
} }
} }

@ -1,106 +1,94 @@
package com.supervision.knowsub.util; package com.supervision.knowsub.util;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.*;
@Slf4j @Slf4j
public class AiChatUtil { public class AiChatUtil {
private static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, 5, "chat", new ThreadPoolExecutor.CallerRunsPolicy()); // private static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, 5, "chat", new ThreadPoolExecutor.CallerRunsPolicy());
//
private static final OllamaChatClient chatClient = SpringBeanUtil.getBean(OllamaChatClient.class); // private static final OllamaChatClient chatClient = SpringBeanUtil.getBean(OllamaChatClient.class);
//
/** // /**
* // * 单轮对话
* // *
* @param chat // * @param chat 对话的内容
* @return jsonObject // * @return jsonObject
*/ // */
public static Optional<JSONObject> chat(String chat) { // public static Optional<JSONObject> chat(String chat) {
Prompt prompt = new Prompt(List.of(new UserMessage(chat))); // Prompt prompt = new Prompt(List.of(new UserMessage(chat)));
Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt)); // Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt));
try { // try {
return Optional.of(JSONUtil.parseObj(submit.get())); // return Optional.of(JSONUtil.parseObj(submit.get()));
} catch (ExecutionException | InterruptedException e) { // } catch (ExecutionException | InterruptedException e) {
log.error("调用大模型生成失败"); // log.error("调用大模型生成失败");
} // }
return Optional.empty(); // return Optional.empty();
} // }
//
/** // /**
* , // * 支持多轮对话,自定义消息
* // *
* @param messageList // * @param messageList 消息列表
* @return jsonObject // * @return jsonObject
*/ // */
public static Optional<JSONObject> chat(List<Message> messageList) { // public static Optional<JSONObject> chat(List<Message> messageList) {
Prompt prompt = new Prompt(messageList); // Prompt prompt = new Prompt(messageList);
Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt)); // Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt));
try { // try {
return Optional.of(JSONUtil.parseObj(submit.get())); // return Optional.of(JSONUtil.parseObj(submit.get()));
} catch (ExecutionException | InterruptedException e) { // } catch (ExecutionException | InterruptedException e) {
log.error("调用大模型生成失败"); // log.error("调用大模型生成失败");
} // }
return Optional.empty(); // return Optional.empty();
} // }
//
/** // /**
* // * 支持序列化的方式
* // *
* @param messageList // * @param messageList 消息列表
* @param clazz // * @param clazz 需要序列化的对象
* @param <T> // * @param <T> 需要序列化的对象的泛型
* @return , // * @return 对应对象类型, 不支持列表类型
*/ // */
public static <T> Optional<T> chat(List<Message> messageList, Class<T> clazz) { // public static <T> Optional<T> chat(List<Message> messageList, Class<T> clazz) {
Prompt prompt = new Prompt(messageList); // Prompt prompt = new Prompt(messageList);
Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt)); // Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt));
try { // try {
String s = submit.get(); // String s = submit.get();
return Optional.ofNullable(JSONUtil.toBean(s, clazz)); // return Optional.ofNullable(JSONUtil.toBean(s, clazz));
} catch (ExecutionException | InterruptedException e) { // } catch (ExecutionException | InterruptedException e) {
log.error("调用大模型生成失败", e); // log.error("调用大模型生成失败", e);
} // }
return Optional.empty(); // return Optional.empty();
} // }
//
/** // /**
* // * 支持序列化的方式的对话
* // *
* @param chat // * @param chat 对话的消息
* @param clazz // * @param clazz 需要序列化的对象
* @param <T> // * @param <T> 需要序列化的对象的泛型
* @return , // * @return 对应对象类型, 不支持列表类型
*/ // */
public static <T> Optional<T> chat(String chat, Class<T> clazz) { // public static <T> Optional<T> chat(String chat, Class<T> clazz) {
Prompt prompt = new Prompt(List.of(new UserMessage(chat))); // Prompt prompt = new Prompt(List.of(new UserMessage(chat)));
Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt)); // Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt));
try { // try {
String s = submit.get(); // String s = submit.get();
return Optional.ofNullable(JSONUtil.toBean(s, clazz)); // return Optional.ofNullable(JSONUtil.toBean(s, clazz));
} catch (ExecutionException | InterruptedException e) { // } catch (ExecutionException | InterruptedException e) {
log.error("调用大模型生成失败"); // log.error("调用大模型生成失败");
} // }
return Optional.empty(); // return Optional.empty();
} // }
//
private record ChatTask(OllamaChatClient chatClient, Prompt prompt) implements Callable<String> { // private record ChatTask(OllamaChatClient chatClient, Prompt prompt) implements Callable<String> {
@Override // @Override
public String call() { // public String call() {
ChatResponse call = chatClient.call(prompt); // ChatResponse call = chatClient.call(prompt);
return call.getResult().getOutput().getContent(); // return call.getResult().getOutput().getContent();
} // }
} // }
} }

@ -15,6 +15,15 @@ spring:
elasticsearch: elasticsearch:
uris: http://192.168.10.137:9200 uris: http://192.168.10.137:9200
ai: ai:
ollama:
base-url: http://192.168.10.70:11434
chat:
enabled: true
options:
model: llama3-chinese:8b
keep-alive: 1000m
temperature: 0.1
vectorstore: vectorstore:
elasticsearch: elasticsearch:
index-name: know-sub-rag-store index-name: know-sub-rag-store

Loading…
Cancel
Save