知识问答初版、neo4j驱动,CYPHER语句执行

master
daixiaoyi 2 months ago
parent 7f5c52546a
commit 76e9d05a7b

@ -85,6 +85,11 @@
<artifactId>stanford-corenlp</artifactId> <artifactId>stanford-corenlp</artifactId>
<version>4.5.4</version> <version>4.5.4</version>
</dependency> </dependency>
<dependency>
<groupId>org.neo4j.driver</groupId>
<artifactId>neo4j-java-driver</artifactId>
<version>5.15.0</version>
</dependency>
</dependencies> </dependencies>
<dependencyManagement> <dependencyManagement>
<dependencies> <dependencies>

@ -10,6 +10,8 @@ public class PromptCache {
public static final String DOERE_TEXT = "DOERE_TEXT"; public static final String DOERE_TEXT = "DOERE_TEXT";
public static final String DOERE_TABLE = "DOERE_TABLE"; public static final String DOERE_TABLE = "DOERE_TABLE";
public static final String TEXT_TO_CYPHER = "TEXT_TO_CYPHER";
public static final String GENERATE_ANSWER = "GENERATE_ANSWER";
public static final String CHINESE_TO_ENGLISH = "CHINESE_TO_ENGLISH"; public static final String CHINESE_TO_ENGLISH = "CHINESE_TO_ENGLISH";
@ -24,6 +26,8 @@ public class PromptCache {
promptMap.put(DOERE_TABLE, DOERE_TABLE_PROMPT); promptMap.put(DOERE_TABLE, DOERE_TABLE_PROMPT);
promptMap.put(CHINESE_TO_ENGLISH, CHINESE_TO_ENGLISH_PROMPT); promptMap.put(CHINESE_TO_ENGLISH, CHINESE_TO_ENGLISH_PROMPT);
promptMap.put(ERE_TO_INSERT_CYPHER, ERE_TO_INSERT_CYPHER_PROMPT); promptMap.put(ERE_TO_INSERT_CYPHER, ERE_TO_INSERT_CYPHER_PROMPT);
promptMap.put(TEXT_TO_CYPHER, TEXT_TO_CYPHER_PROMPT);
promptMap.put(GENERATE_ANSWER, GENERATE_ANSWER_PROMPT);
} }
@ -186,6 +190,17 @@ public class PromptCache {
{} {}
"""; """;
private static final String TEXT_TO_CYPHER_PROMPT = """
CYPHER
{domainMetadata}
{userQuery}
""";
private static final String GENERATE_ANSWER_PROMPT = """
{tripleMetaData}
{userQuery}
""";
private static final String CHINESE_TO_ENGLISH_PROMPT = """ private static final String CHINESE_TO_ENGLISH_PROMPT = """
Neo4jNeo4j Neo4jNeo4j

@ -0,0 +1,39 @@
package com.supervision.pdfqaserver.config;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Config;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.internal.SessionFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.concurrent.TimeUnit;
@Configuration
public class Neo4jConfig {
@Value("${neo4j.driver.uri}")
private String uri;
@Value("${neo4j.driver.user}")
private String user;
@Value("${neo4j.driver.password}")
private String password;
/**
* Driver 线
*/
@Bean
public Driver neo4jDriver() {
return GraphDatabase.driver(
uri,
AuthTokens.basic(user, password),
Config.builder()
.withMaxConnectionPoolSize(50)
.withConnectionAcquisitionTimeout(5, TimeUnit.SECONDS)
.build()
);
}
}

@ -1,46 +1,31 @@
package com.supervision.pdfqaserver.controller; package com.supervision.pdfqaserver.controller;
import cn.hutool.core.lang.Assert; import com.supervision.pdfqaserver.service.ChatService;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.dto.R;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message; import org.springframework.http.MediaType;
import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.RestController;
import java.util.ArrayList; import reactor.core.publisher.Flux;
import java.util.List;
import java.util.Map;
@Slf4j @Slf4j
@RestController @RestController
@RequestMapping("/ollama") @RequestMapping("/chat")
@RequiredArgsConstructor @RequiredArgsConstructor
@CrossOrigin(origins = "*", maxAge = 3600)
public class ChatController { public class ChatController {
private final OllamaChatModel ollamaChatModel; private final ChatService chatService;
/** /**
* 使 *
* @param message *
* @return * @param userQuery
* @return
*/ */
@PostMapping("/chat") @GetMapping(value = "/knowledgeQA", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public R<String> pageList(@RequestBody Map<String,String> message) { public Flux<String> knowledgeQA(@RequestParam("userQuery") String userQuery) {
List<Message> messages = new ArrayList<>(); return chatService.knowledgeQA(userQuery);
if (StrUtil.isNotEmpty(message.get("system"))){
messages.add(new SystemMessage(message.get("system")));
}
if (StrUtil.isNotEmpty(message.get("user"))){
messages.add(new UserMessage(message.get("user")));
}
log.info("system: {} , user: {}",message.get("system"),message.get("user"));
Assert.notEmpty(messages, "消息不能为空");
String response = ollamaChatModel.call(messages.toArray(new Message[0]));
log.info("response:{}",response);
return R.ok(response);
} }
} }

@ -0,0 +1,46 @@
package com.supervision.pdfqaserver.controller;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.dto.R;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.web.bind.annotation.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@Slf4j
@RestController
@RequestMapping("/ollama")
@RequiredArgsConstructor
@CrossOrigin(origins = "*", maxAge = 3600)
public class OllamaChatController {
private final OllamaChatModel ollamaChatModel;
/**
* 使
* @param message
* @return
*/
@PostMapping("/chat")
public R<String> pageList(@RequestBody Map<String,String> message) {
List<Message> messages = new ArrayList<>();
if (StrUtil.isNotEmpty(message.get("system"))){
messages.add(new SystemMessage(message.get("system")));
}
if (StrUtil.isNotEmpty(message.get("user"))){
messages.add(new UserMessage(message.get("user")));
}
log.info("system: {} , user: {}",message.get("system"),message.get("user"));
Assert.notEmpty(messages, "消息不能为空");
String response = ollamaChatModel.call(messages.toArray(new Message[0]));
log.info("response:{}",response);
return R.ok(response);
}
}

@ -0,0 +1,72 @@
package com.supervision.pdfqaserver.dao;
import com.supervision.pdfqaserver.dto.neo4j.NodeData;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipData;
import lombok.RequiredArgsConstructor;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.neo4j.driver.types.Node;
import org.neo4j.driver.types.Relationship;
import org.springframework.stereotype.Repository;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
@Repository
@RequiredArgsConstructor
public class Neo4jRepository {
private final Driver driver;
/**
* Cypher RelationObject
*
* @param cypher Cypher
* "MATCH (a)-[r:REL_TYPE]->(b) RETURN a AS startNode, r AS rel, b AS endNode"
* @param params
*/
public List<RelationObject> execute(String cypher, Map<String, Object> params) {
try (Session session = driver.session()) {
return session.executeRead(tx -> {
Result result = tx.run(cypher, params == null ? Collections.emptyMap() : params);
List<RelationObject> list = new ArrayList<>();
while (result.hasNext()) {
org.neo4j.driver.Record record = result.next();
// 从 Record 中取出三部分
Node a = record.get("startNode").asNode();
Relationship r = record.get("r").asRelationship();
Node b = record.get("endNode").asNode();
// 转成我们的 DTO
NodeData start = mapNode(a);
RelationshipData rel = mapRel(r);
NodeData end = mapNode(b);
list.add(new RelationObject(start, rel, end));
}
return list;
});
}
}
private NodeData mapNode(Node node) {
return new NodeData(
node.id(),
StreamSupport.stream(node.labels().spliterator(), false).collect(Collectors.toList()),
node.asMap()
);
}
private RelationshipData mapRel(Relationship r) {
return new RelationshipData(
r.id(),
r.type(),
r.asMap()
);
}
}

@ -0,0 +1,7 @@
package com.supervision.pdfqaserver.dto.neo4j;
import java.util.List;
import java.util.Map;
public record NodeData(long id, List<String> labels, Map<String, Object> properties) {
}

@ -0,0 +1,4 @@
package com.supervision.pdfqaserver.dto.neo4j;
public record RelationObject(NodeData startNode, RelationshipData relationship, NodeData endNode) {
}

@ -0,0 +1,6 @@
package com.supervision.pdfqaserver.dto.neo4j;
import java.util.Map;
public record RelationshipData(long id, String type, Map<String, Object> properties) {
}

@ -0,0 +1,14 @@
package com.supervision.pdfqaserver.service;
import reactor.core.publisher.Flux;
public interface ChatService {
/**
*
*
* @param userQuery
* @return
*/
Flux<String> knowledgeQA(String userQuery);
}

@ -0,0 +1,61 @@
package com.supervision.pdfqaserver.service.impl;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.DomainMetadata;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.service.ChatService;
import com.supervision.pdfqaserver.service.DomainMetadataService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.Map;
import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER;
import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER;
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {
private static final String PROMPT_PARAM_DOMAIN_METADATA = "domainMetadata";
private static final String PROMPT_PARAM_TRIPLE_METADATA = "tripleMetaData";
private static final String PROMPT_PARAM_USER_QUERY = "userQuery";
private final Neo4jRepository neo4jRepository;
private final OllamaChatModel ollamaChatModel;
private final DomainMetadataService domainMetadataService;
@Override
public Flux<String> knowledgeQA(String userQuery) {
String systemPrompt = domainMetadataService.list().stream()
.map(DomainMetadata::toString)
.reduce("", (acc, metadata) -> acc + metadata + "\n");
//生成CYPHER
SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_DOMAIN_METADATA, systemPrompt, PROMPT_PARAM_USER_QUERY, userQuery));
ChatResponse textToCypherResponse = ollamaChatModel.call(new Prompt(textToCypherMessage));
String queryCypher = "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode";
log.info(textToCypherResponse.getResult().getOutput().getText());
// String queryCypher = textToCypherResponse.getResult().getOutput().getText();
List<RelationObject> relationObjects = neo4jRepository.execute(queryCypher, null);
if (relationObjects.isEmpty()) {
return Flux.just("没有找到相关数据");
}
log.info("relationObjects: {}", relationObjects);
//生成回答
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_TRIPLE_METADATA, systemPrompt, PROMPT_PARAM_USER_QUERY, userQuery));
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText());
}
}

@ -21,3 +21,8 @@ spring:
top_p: 0.9 top_p: 0.9
top_k: 40 top_k: 40
temperature: 0.7 temperature: 0.7
neo4j:
driver:
uri: bolt://192.168.10.138:7687
user: neo4j
password: 123456
Loading…
Cancel
Save