diff --git a/pom.xml b/pom.xml
index 8e4f1ec..f26c4a0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -85,6 +85,11 @@
stanford-corenlp
4.5.4
+
+ org.neo4j.driver
+ neo4j-java-driver
+ 5.15.0
+
diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java
index af8111a..47973f9 100644
--- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java
+++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java
@@ -10,6 +10,8 @@ public class PromptCache {
public static final String DOERE_TEXT = "DOERE_TEXT";
public static final String DOERE_TABLE = "DOERE_TABLE";
+ public static final String TEXT_TO_CYPHER = "TEXT_TO_CYPHER";
+ public static final String GENERATE_ANSWER = "GENERATE_ANSWER";
public static final String CHINESE_TO_ENGLISH = "CHINESE_TO_ENGLISH";
@@ -24,6 +26,8 @@ public class PromptCache {
promptMap.put(DOERE_TABLE, DOERE_TABLE_PROMPT);
promptMap.put(CHINESE_TO_ENGLISH, CHINESE_TO_ENGLISH_PROMPT);
promptMap.put(ERE_TO_INSERT_CYPHER, ERE_TO_INSERT_CYPHER_PROMPT);
+ promptMap.put(TEXT_TO_CYPHER, TEXT_TO_CYPHER_PROMPT);
+ promptMap.put(GENERATE_ANSWER, GENERATE_ANSWER_PROMPT);
}
@@ -186,6 +190,17 @@ public class PromptCache {
{}
""";
+ private static final String TEXT_TO_CYPHER_PROMPT = """
+ 结合给你的领域元数据,分析用户输入的问题,尝试将其转换为CYPHER语句。
+ 领域元数据:{domainMetadata}
+ 用户输入的问题:{userQuery}
+ """;
+
+ private static final String GENERATE_ANSWER_PROMPT = """
+ 结合给你的三元组数据和用户输入的问题,生成一个简洁的回答。
+ 三元组数据:{tripleMetaData}
+ 用户输入的问题:{userQuery}
+ """;
private static final String CHINESE_TO_ENGLISH_PROMPT = """
你是一个Neo4j图数据库命名规范转换专家,请将以下中文短语转换为符合Neo4j命名规范的英文名称。要求:
diff --git a/src/main/java/com/supervision/pdfqaserver/config/Neo4jConfig.java b/src/main/java/com/supervision/pdfqaserver/config/Neo4jConfig.java
new file mode 100644
index 0000000..9a8a720
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/config/Neo4jConfig.java
@@ -0,0 +1,39 @@
+package com.supervision.pdfqaserver.config;
+
+import org.neo4j.driver.AuthTokens;
+import org.neo4j.driver.Config;
+import org.neo4j.driver.Driver;
+import org.neo4j.driver.GraphDatabase;
+import org.neo4j.driver.internal.SessionFactory;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+import java.util.concurrent.TimeUnit;
+
+@Configuration
+public class Neo4jConfig {
+
+ @Value("${neo4j.driver.uri}")
+ private String uri;
+ @Value("${neo4j.driver.user}")
+ private String user;
+ @Value("${neo4j.driver.password}")
+ private String password;
+
+ /**
+ * Driver 为线程安全的单例,可重用连接池,建议应用启动时创建并在容器中管理
+ */
+ @Bean
+ public Driver neo4jDriver() {
+ return GraphDatabase.driver(
+ uri,
+ AuthTokens.basic(user, password),
+ Config.builder()
+ .withMaxConnectionPoolSize(50)
+ .withConnectionAcquisitionTimeout(5, TimeUnit.SECONDS)
+ .build()
+ );
+ }
+
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/controller/ChatController.java b/src/main/java/com/supervision/pdfqaserver/controller/ChatController.java
index 598fdab..5391055 100644
--- a/src/main/java/com/supervision/pdfqaserver/controller/ChatController.java
+++ b/src/main/java/com/supervision/pdfqaserver/controller/ChatController.java
@@ -1,46 +1,31 @@
package com.supervision.pdfqaserver.controller;
-import cn.hutool.core.lang.Assert;
-import cn.hutool.core.util.StrUtil;
-import com.supervision.pdfqaserver.dto.R;
+import com.supervision.pdfqaserver.service.ChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.ollama.OllamaChatModel;
-import org.springframework.web.bind.annotation.*;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
+import org.springframework.http.MediaType;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.bind.annotation.RestController;
+import reactor.core.publisher.Flux;
@Slf4j
@RestController
-@RequestMapping("/ollama")
+@RequestMapping("/chat")
@RequiredArgsConstructor
-@CrossOrigin(origins = "*", maxAge = 3600)
public class ChatController {
- private final OllamaChatModel ollamaChatModel;
+ private final ChatService chatService;
/**
- * 仅供调试使用,后期移除该接口
- * @param message
- * @return
+ * 知识问答
+ *
+ * @param userQuery 用户查询
+ * @return 知识问答结果
*/
- @PostMapping("/chat")
- public R pageList(@RequestBody Map message) {
- List messages = new ArrayList<>();
- if (StrUtil.isNotEmpty(message.get("system"))){
- messages.add(new SystemMessage(message.get("system")));
- }
- if (StrUtil.isNotEmpty(message.get("user"))){
- messages.add(new UserMessage(message.get("user")));
- }
- log.info("system: {} , user: {}",message.get("system"),message.get("user"));
- Assert.notEmpty(messages, "消息不能为空");
- String response = ollamaChatModel.call(messages.toArray(new Message[0]));
- log.info("response:{}",response);
- return R.ok(response);
+ @GetMapping(value = "/knowledgeQA", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
+ public Flux knowledgeQA(@RequestParam("userQuery") String userQuery) {
+ return chatService.knowledgeQA(userQuery);
}
}
diff --git a/src/main/java/com/supervision/pdfqaserver/controller/OllamaChatController.java b/src/main/java/com/supervision/pdfqaserver/controller/OllamaChatController.java
new file mode 100644
index 0000000..b6d85ad
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/controller/OllamaChatController.java
@@ -0,0 +1,46 @@
+package com.supervision.pdfqaserver.controller;
+
+import cn.hutool.core.lang.Assert;
+import cn.hutool.core.util.StrUtil;
+import com.supervision.pdfqaserver.dto.R;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.web.bind.annotation.*;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+@Slf4j
+@RestController
+@RequestMapping("/ollama")
+@RequiredArgsConstructor
+@CrossOrigin(origins = "*", maxAge = 3600)
+public class OllamaChatController {
+
+ private final OllamaChatModel ollamaChatModel;
+
+ /**
+ * 仅供调试使用,后期移除该接口
+ * @param message
+ * @return
+ */
+ @PostMapping("/chat")
+ public R pageList(@RequestBody Map message) {
+ List messages = new ArrayList<>();
+ if (StrUtil.isNotEmpty(message.get("system"))){
+ messages.add(new SystemMessage(message.get("system")));
+ }
+ if (StrUtil.isNotEmpty(message.get("user"))){
+ messages.add(new UserMessage(message.get("user")));
+ }
+ log.info("system: {} , user: {}",message.get("system"),message.get("user"));
+ Assert.notEmpty(messages, "消息不能为空");
+ String response = ollamaChatModel.call(messages.toArray(new Message[0]));
+ log.info("response:{}",response);
+ return R.ok(response);
+ }
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java
new file mode 100644
index 0000000..7485f7a
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java
@@ -0,0 +1,72 @@
+package com.supervision.pdfqaserver.dao;
+
+import com.supervision.pdfqaserver.dto.neo4j.NodeData;
+import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
+import com.supervision.pdfqaserver.dto.neo4j.RelationshipData;
+import lombok.RequiredArgsConstructor;
+import org.neo4j.driver.Driver;
+import org.neo4j.driver.Result;
+import org.neo4j.driver.Session;
+import org.neo4j.driver.types.Node;
+import org.neo4j.driver.types.Relationship;
+import org.springframework.stereotype.Repository;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+@Repository
+@RequiredArgsConstructor
+public class Neo4jRepository {
+ private final Driver driver;
+
+ /**
+ * 执行传入的 Cypher 语句并返回 RelationObject 列表
+ *
+ * @param cypher 完整的预生成 Cypher 语句,比如:
+ * "MATCH (a)-[r:REL_TYPE]->(b) RETURN a AS startNode, r AS rel, b AS endNode"
+ * @param params 语句参数(可为空)
+ */
+ public List execute(String cypher, Map params) {
+ try (Session session = driver.session()) {
+ return session.executeRead(tx -> {
+ Result result = tx.run(cypher, params == null ? Collections.emptyMap() : params);
+ List list = new ArrayList<>();
+ while (result.hasNext()) {
+ org.neo4j.driver.Record record = result.next();
+ // 从 Record 中取出三部分
+ Node a = record.get("startNode").asNode();
+ Relationship r = record.get("r").asRelationship();
+ Node b = record.get("endNode").asNode();
+
+ // 转成我们的 DTO
+ NodeData start = mapNode(a);
+ RelationshipData rel = mapRel(r);
+ NodeData end = mapNode(b);
+
+ list.add(new RelationObject(start, rel, end));
+ }
+ return list;
+ });
+ }
+ }
+
+ private NodeData mapNode(Node node) {
+ return new NodeData(
+ node.id(),
+ StreamSupport.stream(node.labels().spliterator(), false).collect(Collectors.toList()),
+ node.asMap()
+ );
+ }
+
+ private RelationshipData mapRel(Relationship r) {
+ return new RelationshipData(
+ r.id(),
+ r.type(),
+ r.asMap()
+ );
+ }
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeData.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeData.java
new file mode 100644
index 0000000..721b1d9
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeData.java
@@ -0,0 +1,7 @@
+package com.supervision.pdfqaserver.dto.neo4j;
+
+import java.util.List;
+import java.util.Map;
+
+public record NodeData(long id, List labels, Map properties) {
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationObject.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationObject.java
new file mode 100644
index 0000000..cc3e94e
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationObject.java
@@ -0,0 +1,4 @@
+package com.supervision.pdfqaserver.dto.neo4j;
+
+public record RelationObject(NodeData startNode, RelationshipData relationship, NodeData endNode) {
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipData.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipData.java
new file mode 100644
index 0000000..6143c38
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipData.java
@@ -0,0 +1,6 @@
+package com.supervision.pdfqaserver.dto.neo4j;
+
+import java.util.Map;
+
+public record RelationshipData(long id, String type, Map properties) {
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/service/ChatService.java b/src/main/java/com/supervision/pdfqaserver/service/ChatService.java
new file mode 100644
index 0000000..c5896ba
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/service/ChatService.java
@@ -0,0 +1,14 @@
+package com.supervision.pdfqaserver.service;
+
+import reactor.core.publisher.Flux;
+
+public interface ChatService {
+
+ /**
+ * 知识问答
+ *
+ * @param userQuery 用户查询
+ * @return 知识问答结果
+ */
+ Flux knowledgeQA(String userQuery);
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java
new file mode 100644
index 0000000..eab7e3c
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java
@@ -0,0 +1,61 @@
+package com.supervision.pdfqaserver.service.impl;
+
+import com.supervision.pdfqaserver.cache.PromptCache;
+import com.supervision.pdfqaserver.dao.Neo4jRepository;
+import com.supervision.pdfqaserver.domain.DomainMetadata;
+import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
+import com.supervision.pdfqaserver.service.ChatService;
+import com.supervision.pdfqaserver.service.DomainMetadataService;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.chat.prompt.SystemPromptTemplate;
+import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.stereotype.Service;
+import reactor.core.publisher.Flux;
+
+import java.util.List;
+import java.util.Map;
+
+import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER;
+import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER;
+
+@Slf4j
+@Service
+@RequiredArgsConstructor
+public class ChatServiceImpl implements ChatService {
+ private static final String PROMPT_PARAM_DOMAIN_METADATA = "domainMetadata";
+ private static final String PROMPT_PARAM_TRIPLE_METADATA = "tripleMetaData";
+ private static final String PROMPT_PARAM_USER_QUERY = "userQuery";
+
+ private final Neo4jRepository neo4jRepository;
+ private final OllamaChatModel ollamaChatModel;
+ private final DomainMetadataService domainMetadataService;
+
+ @Override
+ public Flux knowledgeQA(String userQuery) {
+ String systemPrompt = domainMetadataService.list().stream()
+ .map(DomainMetadata::toString)
+ .reduce("", (acc, metadata) -> acc + metadata + "\n");
+
+ //生成CYPHER
+ SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
+ Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_DOMAIN_METADATA, systemPrompt, PROMPT_PARAM_USER_QUERY, userQuery));
+ ChatResponse textToCypherResponse = ollamaChatModel.call(new Prompt(textToCypherMessage));
+ String queryCypher = "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode";
+ log.info(textToCypherResponse.getResult().getOutput().getText());
+// String queryCypher = textToCypherResponse.getResult().getOutput().getText();
+ List relationObjects = neo4jRepository.execute(queryCypher, null);
+ if (relationObjects.isEmpty()) {
+ return Flux.just("没有找到相关数据");
+ }
+ log.info("relationObjects: {}", relationObjects);
+
+ //生成回答
+ SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
+ Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_TRIPLE_METADATA, systemPrompt, PROMPT_PARAM_USER_QUERY, userQuery));
+ return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText());
+ }
+}
diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml
index 03e8cf4..fe43373 100644
--- a/src/main/resources/application.yml
+++ b/src/main/resources/application.yml
@@ -21,3 +21,8 @@ spring:
top_p: 0.9
top_k: 40
temperature: 0.7
+neo4j:
+ driver:
+ uri: bolt://192.168.10.138:7687
+ user: neo4j
+ password: 123456
\ No newline at end of file