From 375892f7964d481fe26c539f0710d5fb4f4231a5 Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Tue, 29 Apr 2025 17:57:43 +0800 Subject: [PATCH] =?UTF-8?q?generateGraph=20=E5=8A=9F=E8=83=BD=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pdfqaserver/dao/Neo4jRepository.java | 67 +++++++++++++++++-- .../pdfqaserver/dto/EntityExtractionDTO.java | 4 +- .../service/TripleToCypherExecutor.java | 2 + .../service/impl/ChatServiceImpl.java | 2 + .../impl/KnowledgeGraphServiceImpl.java | 7 +- .../impl/TripleConversionPipelineImpl.java | 8 ++- .../impl/TripleToCypherExecutorImpl.java | 67 +++++++++++++++++++ .../PdfQaServerApplicationTests.java | 62 +++++++++++++++++ 8 files changed, 208 insertions(+), 11 deletions(-) diff --git a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java index 7485f7a..3f1b268 100644 --- a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java +++ b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java @@ -11,16 +11,17 @@ import org.neo4j.driver.types.Node; import org.neo4j.driver.types.Relationship; import org.springframework.stereotype.Repository; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import static org.neo4j.driver.Values.parameters; @Repository @RequiredArgsConstructor public class Neo4jRepository { + /** + * Neo4j 驱动 + */ private final Driver driver; /** @@ -54,6 +55,64 @@ public class Neo4jRepository { } } + + /** + * 创建或更新实体节点 + * 根据唯一键(uniqueKey)来判断节点是否存在 + * note: properties中需要包含唯一键的值 + * @param label label + * @param uniqueKey 唯一键 + * @param properties 属性 + * @return 节点ID列表 + */ + public List saveOrUpdateEntityNode(String label,String uniqueKey,Map properties) { + + try (Session session = driver.session()) { + // MERGE语句确保唯一性 + String query = String.format( + "MERGE (n:%s {%s: $uniqueValue}) " + + "SET n += $properties " + + "RETURN id(n) as id ", + label, uniqueKey); + + Result result = session.run(query, parameters("uniqueValue", properties.get(uniqueKey), "properties", properties)); + return result.list().stream().map(record -> record.get("id").asLong()).collect(Collectors.toList()); + } + } + + /** + * 创建或更新关系 + * @param sourceId 头节点ID + * @param targetId 尾节点ID + * @param relationType 关系类型 + * @param leftDirection 方向左 + * @param rightDirection 方向右 + * @param properties 关系属性 + * @return + */ + public List saveOrUpdateRelation(Long sourceId,Long targetId,String relationType, + boolean leftDirection,boolean rightDirection,Map properties) { + + try (Session session = driver.session()) { + // MERGE关系 + String query = "MATCH (a) WHERE id(a) = $sourceId " + + "MATCH (b) WHERE id(b) = $targetId " + + "MERGE (a)%s-[r:" + relationType + "]-%s(b) " + + "SET r += $properties " + + "RETURN id(r) as id"; + + String format = String.format(query, + leftDirection ? "<" : "", + rightDirection ? ">" : ""); + + Result result = session.run(format, + parameters("sourceId", sourceId, "targetId", targetId, "properties", properties)); + + return result.stream().map(record -> record.get("id").asLong()).collect(Collectors.toList()); + } + } + + private NodeData mapNode(Node node) { return new NodeData( node.id(), diff --git a/src/main/java/com/supervision/pdfqaserver/dto/EntityExtractionDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/EntityExtractionDTO.java index e976070..e202b2c 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/EntityExtractionDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/EntityExtractionDTO.java @@ -1,6 +1,8 @@ package com.supervision.pdfqaserver.dto; import lombok.Data; + +import java.util.ArrayList; import java.util.List; /** @@ -28,7 +30,7 @@ public class EntityExtractionDTO { */ private String name; - private List attributes; + private List attributes = new ArrayList<>(); public EntityExtractionDTO() { } diff --git a/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java b/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java index d48050d..b07725f 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java +++ b/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java @@ -28,4 +28,6 @@ public interface TripleToCypherExecutor { * @return */ void executeCypher(String cypher); + + void saveERE(EREDTO eredto); } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java index 1c01375..02b0d5f 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java @@ -35,7 +35,9 @@ public class ChatServiceImpl implements ChatService { private static final String PROMPT_PARAM_USER_QUERY = "userQuery"; private final Neo4jRepository neo4jRepository; + private final OllamaChatModel ollamaChatModel; + private final DomainMetadataService domainMetadataService; private final ChineseEnglishWordsService chineseEnglishWordsService; diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java index 7d3f88d..719949e 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java @@ -122,10 +122,11 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService { log.info("保存字典完成,新增字典个数:{}", allWords.size() - wordsSize); // 生成cypher语句 for (EREDTO eredto : mergedList) { + if (CollUtil.isEmpty(eredto.getEntities()) && CollUtil.isEmpty(eredto.getRelations())){ + continue; + } eredto.setEn(allWords); - String insertCypher = tripleToCypherExecutor.generateInsertCypher(eredto); - log.info("insertCypher:{}", insertCypher); - tripleToCypherExecutor.executeCypher(insertCypher); + tripleToCypherExecutor.saveERE(eredto); } } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleConversionPipelineImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleConversionPipelineImpl.java index 7690530..a9b2537 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleConversionPipelineImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleConversionPipelineImpl.java @@ -185,9 +185,11 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline { leavedEntities.add(entry.getValue()); } } - EREDTO eredto = new EREDTO(); - eredto.setEntities(leavedEntities); - merged.add(eredto); + if (CollUtil.isNotEmpty(leavedEntities)){ + EREDTO eredto = new EREDTO(); + eredto.setEntities(leavedEntities); + merged.add(eredto); + } return merged; } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java index 025fee0..fdc0894 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java @@ -1,12 +1,24 @@ package com.supervision.pdfqaserver.service.impl; +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; import com.supervision.pdfqaserver.cache.PromptCache; +import com.supervision.pdfqaserver.dao.Neo4jRepository; +import com.supervision.pdfqaserver.dto.ERAttributeDTO; import com.supervision.pdfqaserver.dto.EREDTO; +import com.supervision.pdfqaserver.dto.EntityExtractionDTO; +import com.supervision.pdfqaserver.dto.RelationExtractionDTO; import com.supervision.pdfqaserver.service.TripleToCypherExecutor; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.stereotype.Service; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import static com.supervision.pdfqaserver.cache.PromptCache.ERE_TO_INSERT_CYPHER; @Slf4j @@ -15,6 +27,8 @@ import static com.supervision.pdfqaserver.cache.PromptCache.ERE_TO_INSERT_CYPHER public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { private final OllamaChatModel ollamaChatModel; + + private final Neo4jRepository neo4jRepository; @Override public String generateInsertCypher(EREDTO eredto) { @@ -32,4 +46,57 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { public void executeCypher(String cypher) { } + + @Override + public void saveERE(EREDTO eredto) { + + List entities = eredto.getEntities(); + Map> nodeCache = new HashMap<>(); + if (CollUtil.isNotEmpty(entities)){ + // 保存节点 + for (EntityExtractionDTO entity : entities) { + if (StrUtil.isEmpty(entity.getName())){ + log.info("实体name属性为空,详情:{}", JSONUtil.toJsonStr(entity)); + continue; + } + Map attributes = entity.getAttributes().stream().collect(Collectors.toMap( + ERAttributeDTO::getAttributeEn, ERAttributeDTO::getValue + )); + attributes.put("name", entity.getName()); + log.info("保存节点{},属性:{}", entity.getEntityEn(),JSONUtil.toJsonStr(entity.getAttributes())); + List nodeIds = neo4jRepository.saveOrUpdateEntityNode(entity.getEntityEn(), "name", attributes); + nodeCache.put(StrUtil.join("_", entity.getEntity(), entity.getName()), nodeIds); + } + } + if (CollUtil.isNotEmpty(eredto.getRelations())){ + // 保存关系 + for (RelationExtractionDTO relation : eredto.getRelations()) { + String sourceNodeKey = StrUtil.join("_", relation.getSourceType(), relation.getSource()); + List sourceNodeIds = nodeCache.get(sourceNodeKey); + if (CollUtil.isEmpty(sourceNodeIds)) { + log.info("关系{}没有source节点", sourceNodeKey); + continue; + } + String targetNodeKey = StrUtil.join("_", relation.getTargetType(), relation.getTarget()); + List targetNodeIds = nodeCache.get(targetNodeKey); + if (CollUtil.isEmpty(targetNodeIds)) { + log.info("关系{}没有target节点", targetNodeKey); + continue; + } + Map attributes = relation.getAttributes().stream().collect(Collectors.toMap( + ERAttributeDTO::getAttributeEn, ERAttributeDTO::getValue + )); + for (Long sourceNodeId : sourceNodeIds) { + for (Long targetNodeId : targetNodeIds) { + if (sourceNodeId.equals(targetNodeId)) { + log.info("关系{}的source和target节点相同", sourceNodeKey); + continue; + } + log.info("保存关系{}-{}-{}的属性:{}", relation.getSourceTypeEn(), relation.getRelationEn(),relation.getTargetTypeEn(), attributes); + neo4jRepository.saveOrUpdateRelation(sourceNodeId, targetNodeId, relation.getRelationEn(), false, false, attributes); + } + } + } + } + } } diff --git a/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java b/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java index 2abdc70..cdaa1a3 100644 --- a/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java +++ b/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java @@ -3,8 +3,14 @@ package com.supervision.pdfqaserver; import com.supervision.pdfqaserver.service.KnowledgeGraphService; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +import org.neo4j.driver.*; +import org.neo4j.driver.Record; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import static org.neo4j.driver.Values.parameters; @Slf4j @SpringBootTest @@ -18,4 +24,60 @@ class PdfQaServerApplicationTests { log.info("finish..."); } + @Autowired + private Driver driver; + + /** + * 测试保存或更新节点 + */ + @Test + void testSaveOrUpdateEntityNode() { + + List longs = saveOrUpdateRelation(7838L, 7940L, "REL_TYPE", true, false, Map.of("name", "test1")); + System.out.println(longs); + } + public void saveOrUpdateEntityNode(String label, String uniqueKey, Map properties) { + + try (Session session = driver.session()) { + // MERGE语句确保唯一性 + String query = String.format( + "MERGE (n:%s {%s: $uniqueValue}) " + + "SET n += $properties " + + "RETURN id(n) as id", + label, uniqueKey); + + Result result = session.run(query, parameters("uniqueValue", properties.get(uniqueKey), "properties", properties)); + + + if (result.hasNext()) { + Record next = result.next(); + Value value = next.get("id"); + long aLong = value.asLong(); + System.out.printf("已处理 %s 节点: %s%n", label, next.keys()); + } + } + } + + private List saveOrUpdateRelation(Long sourceId, Long targetId, String relationType, + boolean leftDirection, boolean rightDirection, Map relation) { + + try (Session session = driver.session()) { + // MERGE关系 + String query = "MATCH (a) WHERE id(a) = $sourceId " + + "MATCH (b) WHERE id(b) = $targetId " + + "MERGE (a)%s-[r:" + relationType + "]-%s(b) " + + "SET r += $properties " + + "RETURN id(r) as id"; + + String format = String.format(query, + leftDirection ? "<" : "", + rightDirection ? ">" : ""); + + Result result = session.run(format, + parameters("sourceId", sourceId, "targetId", targetId, "properties", relation)); + + return result.stream().map(record -> record.get("id").asLong()).collect(Collectors.toList()); + } + } + }