diff --git a/pom.xml b/pom.xml index 8e4f1ec..f26c4a0 100644 --- a/pom.xml +++ b/pom.xml @@ -85,6 +85,11 @@ stanford-corenlp 4.5.4 + + org.neo4j.driver + neo4j-java-driver + 5.15.0 + diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java index af8111a..47973f9 100644 --- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java +++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java @@ -10,6 +10,8 @@ public class PromptCache { public static final String DOERE_TEXT = "DOERE_TEXT"; public static final String DOERE_TABLE = "DOERE_TABLE"; + public static final String TEXT_TO_CYPHER = "TEXT_TO_CYPHER"; + public static final String GENERATE_ANSWER = "GENERATE_ANSWER"; public static final String CHINESE_TO_ENGLISH = "CHINESE_TO_ENGLISH"; @@ -24,6 +26,8 @@ public class PromptCache { promptMap.put(DOERE_TABLE, DOERE_TABLE_PROMPT); promptMap.put(CHINESE_TO_ENGLISH, CHINESE_TO_ENGLISH_PROMPT); promptMap.put(ERE_TO_INSERT_CYPHER, ERE_TO_INSERT_CYPHER_PROMPT); + promptMap.put(TEXT_TO_CYPHER, TEXT_TO_CYPHER_PROMPT); + promptMap.put(GENERATE_ANSWER, GENERATE_ANSWER_PROMPT); } @@ -186,6 +190,17 @@ public class PromptCache { {} """; + private static final String TEXT_TO_CYPHER_PROMPT = """ + 结合给你的领域元数据,分析用户输入的问题,尝试将其转换为CYPHER语句。 + 领域元数据:{domainMetadata} + 用户输入的问题:{userQuery} + """; + + private static final String GENERATE_ANSWER_PROMPT = """ + 结合给你的三元组数据和用户输入的问题,生成一个简洁的回答。 + 三元组数据:{tripleMetaData} + 用户输入的问题:{userQuery} + """; private static final String CHINESE_TO_ENGLISH_PROMPT = """ 你是一个Neo4j图数据库命名规范转换专家,请将以下中文短语转换为符合Neo4j命名规范的英文名称。要求: diff --git a/src/main/java/com/supervision/pdfqaserver/config/Neo4jConfig.java b/src/main/java/com/supervision/pdfqaserver/config/Neo4jConfig.java new file mode 100644 index 0000000..9a8a720 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/config/Neo4jConfig.java @@ -0,0 +1,39 @@ +package com.supervision.pdfqaserver.config; + +import org.neo4j.driver.AuthTokens; +import org.neo4j.driver.Config; +import org.neo4j.driver.Driver; +import org.neo4j.driver.GraphDatabase; +import org.neo4j.driver.internal.SessionFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import java.util.concurrent.TimeUnit; + +@Configuration +public class Neo4jConfig { + + @Value("${neo4j.driver.uri}") + private String uri; + @Value("${neo4j.driver.user}") + private String user; + @Value("${neo4j.driver.password}") + private String password; + + /** + * Driver 为线程安全的单例,可重用连接池,建议应用启动时创建并在容器中管理 + */ + @Bean + public Driver neo4jDriver() { + return GraphDatabase.driver( + uri, + AuthTokens.basic(user, password), + Config.builder() + .withMaxConnectionPoolSize(50) + .withConnectionAcquisitionTimeout(5, TimeUnit.SECONDS) + .build() + ); + } + +} diff --git a/src/main/java/com/supervision/pdfqaserver/controller/ChatController.java b/src/main/java/com/supervision/pdfqaserver/controller/ChatController.java index 598fdab..5391055 100644 --- a/src/main/java/com/supervision/pdfqaserver/controller/ChatController.java +++ b/src/main/java/com/supervision/pdfqaserver/controller/ChatController.java @@ -1,46 +1,31 @@ package com.supervision.pdfqaserver.controller; -import cn.hutool.core.lang.Assert; -import cn.hutool.core.util.StrUtil; -import com.supervision.pdfqaserver.dto.R; +import com.supervision.pdfqaserver.service.ChatService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.ollama.OllamaChatModel; -import org.springframework.web.bind.annotation.*; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import reactor.core.publisher.Flux; @Slf4j @RestController -@RequestMapping("/ollama") +@RequestMapping("/chat") @RequiredArgsConstructor -@CrossOrigin(origins = "*", maxAge = 3600) public class ChatController { - private final OllamaChatModel ollamaChatModel; + private final ChatService chatService; /** - * 仅供调试使用,后期移除该接口 - * @param message - * @return + * 知识问答 + * + * @param userQuery 用户查询 + * @return 知识问答结果 */ - @PostMapping("/chat") - public R pageList(@RequestBody Map message) { - List messages = new ArrayList<>(); - if (StrUtil.isNotEmpty(message.get("system"))){ - messages.add(new SystemMessage(message.get("system"))); - } - if (StrUtil.isNotEmpty(message.get("user"))){ - messages.add(new UserMessage(message.get("user"))); - } - log.info("system: {} , user: {}",message.get("system"),message.get("user")); - Assert.notEmpty(messages, "消息不能为空"); - String response = ollamaChatModel.call(messages.toArray(new Message[0])); - log.info("response:{}",response); - return R.ok(response); + @GetMapping(value = "/knowledgeQA", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux knowledgeQA(@RequestParam("userQuery") String userQuery) { + return chatService.knowledgeQA(userQuery); } } diff --git a/src/main/java/com/supervision/pdfqaserver/controller/OllamaChatController.java b/src/main/java/com/supervision/pdfqaserver/controller/OllamaChatController.java new file mode 100644 index 0000000..b6d85ad --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/controller/OllamaChatController.java @@ -0,0 +1,46 @@ +package com.supervision.pdfqaserver.controller; + +import cn.hutool.core.lang.Assert; +import cn.hutool.core.util.StrUtil; +import com.supervision.pdfqaserver.dto.R; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.web.bind.annotation.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +@Slf4j +@RestController +@RequestMapping("/ollama") +@RequiredArgsConstructor +@CrossOrigin(origins = "*", maxAge = 3600) +public class OllamaChatController { + + private final OllamaChatModel ollamaChatModel; + + /** + * 仅供调试使用,后期移除该接口 + * @param message + * @return + */ + @PostMapping("/chat") + public R pageList(@RequestBody Map message) { + List messages = new ArrayList<>(); + if (StrUtil.isNotEmpty(message.get("system"))){ + messages.add(new SystemMessage(message.get("system"))); + } + if (StrUtil.isNotEmpty(message.get("user"))){ + messages.add(new UserMessage(message.get("user"))); + } + log.info("system: {} , user: {}",message.get("system"),message.get("user")); + Assert.notEmpty(messages, "消息不能为空"); + String response = ollamaChatModel.call(messages.toArray(new Message[0])); + log.info("response:{}",response); + return R.ok(response); + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java new file mode 100644 index 0000000..7485f7a --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java @@ -0,0 +1,72 @@ +package com.supervision.pdfqaserver.dao; + +import com.supervision.pdfqaserver.dto.neo4j.NodeData; +import com.supervision.pdfqaserver.dto.neo4j.RelationObject; +import com.supervision.pdfqaserver.dto.neo4j.RelationshipData; +import lombok.RequiredArgsConstructor; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Result; +import org.neo4j.driver.Session; +import org.neo4j.driver.types.Node; +import org.neo4j.driver.types.Relationship; +import org.springframework.stereotype.Repository; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +@Repository +@RequiredArgsConstructor +public class Neo4jRepository { + private final Driver driver; + + /** + * 执行传入的 Cypher 语句并返回 RelationObject 列表 + * + * @param cypher 完整的预生成 Cypher 语句,比如: + * "MATCH (a)-[r:REL_TYPE]->(b) RETURN a AS startNode, r AS rel, b AS endNode" + * @param params 语句参数(可为空) + */ + public List execute(String cypher, Map params) { + try (Session session = driver.session()) { + return session.executeRead(tx -> { + Result result = tx.run(cypher, params == null ? Collections.emptyMap() : params); + List list = new ArrayList<>(); + while (result.hasNext()) { + org.neo4j.driver.Record record = result.next(); + // 从 Record 中取出三部分 + Node a = record.get("startNode").asNode(); + Relationship r = record.get("r").asRelationship(); + Node b = record.get("endNode").asNode(); + + // 转成我们的 DTO + NodeData start = mapNode(a); + RelationshipData rel = mapRel(r); + NodeData end = mapNode(b); + + list.add(new RelationObject(start, rel, end)); + } + return list; + }); + } + } + + private NodeData mapNode(Node node) { + return new NodeData( + node.id(), + StreamSupport.stream(node.labels().spliterator(), false).collect(Collectors.toList()), + node.asMap() + ); + } + + private RelationshipData mapRel(Relationship r) { + return new RelationshipData( + r.id(), + r.type(), + r.asMap() + ); + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeData.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeData.java new file mode 100644 index 0000000..721b1d9 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeData.java @@ -0,0 +1,7 @@ +package com.supervision.pdfqaserver.dto.neo4j; + +import java.util.List; +import java.util.Map; + +public record NodeData(long id, List labels, Map properties) { +} diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationObject.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationObject.java new file mode 100644 index 0000000..cc3e94e --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationObject.java @@ -0,0 +1,4 @@ +package com.supervision.pdfqaserver.dto.neo4j; + +public record RelationObject(NodeData startNode, RelationshipData relationship, NodeData endNode) { +} diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipData.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipData.java new file mode 100644 index 0000000..6143c38 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipData.java @@ -0,0 +1,6 @@ +package com.supervision.pdfqaserver.dto.neo4j; + +import java.util.Map; + +public record RelationshipData(long id, String type, Map properties) { +} diff --git a/src/main/java/com/supervision/pdfqaserver/service/ChatService.java b/src/main/java/com/supervision/pdfqaserver/service/ChatService.java new file mode 100644 index 0000000..c5896ba --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/service/ChatService.java @@ -0,0 +1,14 @@ +package com.supervision.pdfqaserver.service; + +import reactor.core.publisher.Flux; + +public interface ChatService { + + /** + * 知识问答 + * + * @param userQuery 用户查询 + * @return 知识问答结果 + */ + Flux knowledgeQA(String userQuery); +} diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java new file mode 100644 index 0000000..eab7e3c --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java @@ -0,0 +1,61 @@ +package com.supervision.pdfqaserver.service.impl; + +import com.supervision.pdfqaserver.cache.PromptCache; +import com.supervision.pdfqaserver.dao.Neo4jRepository; +import com.supervision.pdfqaserver.domain.DomainMetadata; +import com.supervision.pdfqaserver.dto.neo4j.RelationObject; +import com.supervision.pdfqaserver.service.ChatService; +import com.supervision.pdfqaserver.service.DomainMetadataService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Map; + +import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER; +import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER; + +@Slf4j +@Service +@RequiredArgsConstructor +public class ChatServiceImpl implements ChatService { + private static final String PROMPT_PARAM_DOMAIN_METADATA = "domainMetadata"; + private static final String PROMPT_PARAM_TRIPLE_METADATA = "tripleMetaData"; + private static final String PROMPT_PARAM_USER_QUERY = "userQuery"; + + private final Neo4jRepository neo4jRepository; + private final OllamaChatModel ollamaChatModel; + private final DomainMetadataService domainMetadataService; + + @Override + public Flux knowledgeQA(String userQuery) { + String systemPrompt = domainMetadataService.list().stream() + .map(DomainMetadata::toString) + .reduce("", (acc, metadata) -> acc + metadata + "\n"); + + //生成CYPHER + SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER)); + Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_DOMAIN_METADATA, systemPrompt, PROMPT_PARAM_USER_QUERY, userQuery)); + ChatResponse textToCypherResponse = ollamaChatModel.call(new Prompt(textToCypherMessage)); + String queryCypher = "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode"; + log.info(textToCypherResponse.getResult().getOutput().getText()); +// String queryCypher = textToCypherResponse.getResult().getOutput().getText(); + List relationObjects = neo4jRepository.execute(queryCypher, null); + if (relationObjects.isEmpty()) { + return Flux.just("没有找到相关数据"); + } + log.info("relationObjects: {}", relationObjects); + + //生成回答 + SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); + Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_TRIPLE_METADATA, systemPrompt, PROMPT_PARAM_USER_QUERY, userQuery)); + return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()); + } +} diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 03e8cf4..fe43373 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -21,3 +21,8 @@ spring: top_p: 0.9 top_k: 40 temperature: 0.7 +neo4j: + driver: + uri: bolt://192.168.10.138:7687 + user: neo4j + password: 123456 \ No newline at end of file