From eaf043aa075a9fe962503250e2b9a21d7b1ffea9 Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Fri, 30 May 2025 16:36:56 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=8A=9F=E8=83=BDbug?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20=E4=BC=98=E5=8C=96=E9=97=AE=E7=AD=94?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pdfqaserver/cache/PromptCache.java | 49 ++++- .../pdfqaserver/dao/Neo4jRepository.java | 131 +++++++++++- .../pdfqaserver/dto/CypherSchemaDTO.java | 117 +++++++++++ .../pdfqaserver/dto/neo4j/NodeDTO.java | 30 +++ .../pdfqaserver/dto/neo4j/PathDTO.java | 41 ++++ .../dto/neo4j/RelationshipValueDTO.java | 43 ++++ .../pdfqaserver/service/AiCallService.java | 6 + .../service/TripleToCypherExecutor.java | 28 ++- .../service/impl/ChatServiceImpl.java | 103 ++-------- .../impl/KnowledgeGraphServiceImpl.java | 1 + .../impl/TripleToCypherExecutorImpl.java | 189 ++++++++++++++++-- 11 files changed, 618 insertions(+), 120 deletions(-) create mode 100644 src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java create mode 100644 src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeDTO.java create mode 100644 src/main/java/com/supervision/pdfqaserver/dto/neo4j/PathDTO.java create mode 100644 src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipValueDTO.java diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java index 588418e..d355d8b 100644 --- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java +++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java @@ -86,6 +86,8 @@ public class PromptCache { */ public static final String CLASSIFY_QUERY_INTENT = "CLASSIFY_QUERY_INTENT"; + public static final String TEXT_TO_CYPHER_2 = "TEXT_TO_CYPHER_2"; + public static final Map promptMap = new HashMap<>(); static { @@ -107,6 +109,7 @@ public class PromptCache { promptMap.put(EXTRACT_INTENT_METADATA, EXTRACT_INTENT_METADATA_PROMPT); promptMap.put(EXTRACT_ERE_BASE_INTENT, EXTRACT_ERE_BASE_INTENT_PROMPT); promptMap.put(CLASSIFY_QUERY_INTENT, CLASSIFY_QUERY_INTENT_PROMPT); + promptMap.put(TEXT_TO_CYPHER_2, TEXT_TO_CYPHER_2_PROMPT); } @@ -657,13 +660,9 @@ public class PromptCache { 3. 不需要解释,不需要说明,仅返回以下两种结果: 匹配成功: - ```json {"IntentTypeList": ["...", "..."]} - ``` - 匹配失败: - ```json - {} - ``` +x {} 3.每个意图标签必须独立表述,禁止使用“...和...”等连接词合并两个意图。 ./no_think @@ -685,8 +684,8 @@ public class PromptCache { 1. 分析文本内容,识别与意图标签相关的实体和关系 2. 每一个意图只能匹配一个结果 3. 每个实体/关系应包含: - - type(类型) - - attributes(相关属性列表) + - type(类型,类型应该是被高度抽象的,不要直接用原文实体名称) + - attributes(相关属性类型列表,类型应该是被高度抽象的,不要直接用原文实体名称) 4. 输出纯JSON格式,不要使用```json ```等任何Markdown标记包装 5. 使用以下示例格式: @@ -694,15 +693,15 @@ public class PromptCache { { "source": { "type": "实体类型1", - "attributes": ["属性1", "属性2"] + "attributes": ["属性类型1", "属性类型2",....] }, "relation": { "type": "关系类型", - "attributes": [] + "attributes": ["属性类型3"...] }, "target": { "type": "实体类型2", - "attributes": ["属性3"] + "attributes": ["属性类型4"...] }, "intent": "匹配的意图标签" }, @@ -814,4 +813,34 @@ public class PromptCache { # 当前用户问题: {query} """; + + + private static final String TEXT_TO_CYPHER_2_PROMPT = """ + "You are a Cypher‑generating assistant. " + "Your sole reference for generating Cypher scripts is the `neo4j_schema` variable.\\n\\n" + "User question:\\n{question}\\n\\n" + "The schema is defined below in JSON format:\\n" + "{schema}\\n\\n" + "Follow these exact steps for every user query:\\n\\n" + "1. Extract Entities from User Query:\\n" + "- Parse the question for domain concepts and use synonyms or contextual cues to map them to schema elements.\\n" + "- Identify candidate **node types**.\\n" + "- Identify candidate **relationship types**.\\n" + "- Identify relevant **properties**.\\n" + "- Identify **constraints or conditions** (comparisons, flags, temporal filters, shared‑entity references, etc.).\\n\\n" + "2. Validate Against the Schema:\\n" + "- Ensure every node label, relationship type, and property exists in the schema **exactly** (case‑ and character‑sensitive).\\n" + "- If any required element is missing, respond exactly:\\n" + ' \\"I could not generate a Cypher script; the required information is not part of the Neo4j schema.\\"\\n\\n' + "3. Construct the MATCH Pattern:\\n" + "- Use only schema‑validated node labels and relationship types.\\n" + "- Reuse a single variable whenever the query implies that two patterns refer to the same node.\\n" + "- Express simple equality predicates in map patterns and move all other filters to a **WHERE** clause.\\n\\n" + "4. Return Clause Strategy:\\n" + "- RETURN every node and relationship mentioned, unless the user explicitly requests specific properties.\\n\\n" + "- The truncationId、name attribute of a node is very important, and each node needs to return truncationId、name .\\n\\n" + "5. Final Cypher Script Generation:\\n" + "- Respond with **only** the final Cypher query—no commentary or extra text.\\n" + "- Use OPTIONAL MATCH only if explicitly required by the user and supported by the schema.\\n" + """; } diff --git a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java index 9b27f04..1ccd977 100644 --- a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java +++ b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java @@ -1,12 +1,16 @@ package com.supervision.pdfqaserver.dao; +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.util.StrUtil; +import com.supervision.pdfqaserver.dto.EntityExtractionDTO; +import com.supervision.pdfqaserver.dto.RelationExtractionDTO; +import com.supervision.pdfqaserver.dto.TruncationERAttributeDTO; import com.supervision.pdfqaserver.dto.neo4j.NodeData; import com.supervision.pdfqaserver.dto.neo4j.RelationObject; import com.supervision.pdfqaserver.dto.neo4j.RelationshipData; import lombok.RequiredArgsConstructor; -import org.neo4j.driver.Driver; -import org.neo4j.driver.Result; -import org.neo4j.driver.Session; +import org.neo4j.driver.*; +import org.neo4j.driver.Record; import org.neo4j.driver.types.Node; import org.neo4j.driver.types.Relationship; import org.springframework.stereotype.Repository; @@ -59,6 +63,20 @@ public class Neo4jRepository { } + /** + * 执行原生 Cypher 语句 + * @param cypher 原生 Cypher 语句 + * @param params 参数 + * @return List + */ + public List executeCypherNative(String cypher, Map params) { + try (Session session = driver.session()) { + Result run = session.run(cypher, params == null ? Collections.emptyMap() : params); + return run.list(); + } + } + + /** * 创建或更新实体节点 * 根据唯一键(uniqueKey)来判断节点是否存在 @@ -115,6 +133,113 @@ public class Neo4jRepository { } } + /** + * 获取节点的schema + * @return + */ + public List getNodeSchema(){ + + String query = """ + CALL db.schema.nodeTypeProperties() + YIELD nodeType, propertyName, propertyTypes + RETURN nodeType, propertyName, propertyTypes + """; + try (Session session = driver.session()) { + + List extractionDTOS = new ArrayList<>(); + Result result = session.run(query); + for (Record record : result.list()) { + String nodeType = record.get("nodeType").asString(); + if (StrUtil.isEmpty(nodeType)){ + continue; + } + String propertyName = record.get("propertyName").asString(); + List propertyTypes = record.get("propertyTypes").asList(Value::asString); + + // 创建属性DTO + TruncationERAttributeDTO attributeDTO = new TruncationERAttributeDTO(propertyName, null, CollUtil.getFirst(propertyTypes)); + + // 检查是否已存在该节点类型 + EntityExtractionDTO existingEntity = extractionDTOS.stream() + .filter(e -> StrUtil.equals(e.getEntityEn(), nodeType)) + .findFirst().orElse(null); + + if (existingEntity != null) { + // 如果已存在,添加属性 + existingEntity.getAttributes().add(attributeDTO); + } else { + // 如果不存在,创建新的实体DTO + List truncationERAttributeDTOS = new ArrayList<>(); + truncationERAttributeDTOS.add(attributeDTO); + EntityExtractionDTO entityExtractionDTO = new EntityExtractionDTO(null,nodeType, null,truncationERAttributeDTOS); + extractionDTOS.add(entityExtractionDTO); + } + } + return extractionDTOS; + } + } + + /** + * 获取关系的schema + * @return + */ + public List getRelationSchema(){ + String queryProper = """ + CALL db.schema.relTypeProperties() + YIELD relType, propertyName, propertyTypes + RETURN relType, propertyName, propertyTypes + """; + Map>> relationProperties = new HashMap<>(); + try (Session session = driver.session()) { + Result result = session.run(queryProper); + for (Record record : result.list()) { + String relType = record.get("relType").asString(); + if (StrUtil.isEmpty(relType)){ + continue; + } + String propertyName = record.get("propertyName").asString(); + List propertyTypes = record.get("propertyTypes").asList(Value::asString); + + List> properties = relationProperties.computeIfAbsent(relType, k -> new ArrayList<>()); + boolean noneMatch = properties.stream().noneMatch( + prop -> StrUtil.equals(prop.get("propertyName"), propertyName) + ); + if (noneMatch){ + Map propMap = new HashMap<>(); + propMap.put("propertyName", propertyName); + propMap.put("propertyTypes", CollUtil.getFirst(propertyTypes)); + properties.add(propMap); + } + } + + List relationExtractionDTOS = new ArrayList<>(); + String queryEndpoints = """ + MATCH (s)-[r:`{rtype}`]->(t) + WITH labels(s)[0] AS src, labels(t)[0] AS tgt + RETURN src, tgt + """; + for (Map.Entry>> entry : relationProperties.entrySet()) { + String relType = entry.getKey(); + List> properties = entry.getValue(); + Result run = session.run(queryEndpoints, parameters("rtype", relType)); + for (Record record : run.list()) { + String sourceType = record.get("src").asString(); + String targetType = record.get("tgt").asString(); + List attributeDTOS = properties.stream().map( + prop -> new TruncationERAttributeDTO(prop.get("propertyName"), null, prop.get("propertyTypes")) + ).collect(Collectors.toList()); + RelationExtractionDTO relationExtractionDTO = new RelationExtractionDTO(null,null,sourceType, + relType, + null, + targetType, + attributeDTOS); + relationExtractionDTOS.add(relationExtractionDTO); + } + } + return relationExtractionDTOS; + } + } + private NodeData mapNode(Node node) { return new NodeData( diff --git a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java new file mode 100644 index 0000000..f9d603d --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java @@ -0,0 +1,117 @@ +package com.supervision.pdfqaserver.dto; + +import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONArray; +import cn.hutool.json.JSONObject; + +import java.util.ArrayList; +import java.util.List; + +/** + * CypherSchemaDTO + */ +public class CypherSchemaDTO { + + private List nodes = new ArrayList<>(); + + private List relations = new ArrayList<>(); + + public CypherSchemaDTO(List nodes, List relations) { + this.nodes = nodes; + this.relations = relations; + } + + /** + * 根据头节点、尾节点、关系获取关系抽取DTO + * @param sourceType 源节点类型 + * @param relation 关系 + * @param targetType 尾节点类型 + * @return + */ + public RelationExtractionDTO getRelation(String sourceType, String relation,String targetType) { + for (RelationExtractionDTO relationDTO : relations) { + if (StrUtil.equals(relationDTO.getSourceType(), sourceType) && + StrUtil.equals(relationDTO.getRelation(), relation) && + StrUtil.equals(relationDTO.getTargetType(), targetType)) { + return relationDTO; + } + } + return null; + } + + /** + * 根据源节点类型或目标节点类型获取关系抽取DTO列表 + * @param sourceOrTargetType + * @return + */ + public List getRelations(String sourceOrTargetType) { + List result = new ArrayList<>(); + for (RelationExtractionDTO relationDTO : relations) { + if (StrUtil.equals(relationDTO.getSourceType(), sourceOrTargetType) || + StrUtil.equals(relationDTO.getTargetType(), sourceOrTargetType)) { + result.add(relationDTO); + } + } + return result; + } + + /** + * 根据实体名获取关系抽取DTO列表 + * @param entity + * @return + */ + public EntityExtractionDTO getNode(String entity) { + for (EntityExtractionDTO node : nodes) { + if (StrUtil.equals(node.getEntity(), entity)) { + return node; + } + } + return null; + } + + public String format(){ + JSONObject nodeJson = new JSONObject(); + for (EntityExtractionDTO node : nodes) { + String entity = node.getEntity(); + List attributes = node.getAttributes(); + JSONObject nodeAttr = nodeJson.getJSONObject(entity); + if (nodeAttr == null) { + nodeAttr = new JSONObject(); + nodeJson.set(entity, nodeAttr); + } + for (TruncationERAttributeDTO attribute : attributes) { + boolean none = nodeAttr.entrySet().stream().noneMatch( + entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute())); + if (none){ + nodeAttr.set(attribute.getAttribute(), attribute.getDataType()); + } + } + + } + JSONObject relJson = new JSONObject(); + for (RelationExtractionDTO relation : relations) { + String sourceType = relation.getSourceType(); + String targetType = relation.getTargetType(); + String rela = relation.getRelation(); + JSONObject json = relJson.getJSONObject(rela); + if (null == json) { + json = new JSONObject(); + relJson.set(rela, json); + } + json.set("_endpoints", new JSONArray(new String[]{sourceType, targetType})); + for (TruncationERAttributeDTO attribute : relation.getAttributes()) { + boolean none = json.entrySet().stream().noneMatch( + entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute()) + ); + if (none) { + json.set(attribute.getAttribute(), attribute.getDataType()); + } + } + } + JSONObject object = new JSONObject() + .set("nodetypes", nodeJson) + .set("relationshiptypes", relJson); + return object.toString(); + + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeDTO.java new file mode 100644 index 0000000..922c832 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeDTO.java @@ -0,0 +1,30 @@ +package com.supervision.pdfqaserver.dto.neo4j; + +import lombok.Data; +import org.neo4j.driver.internal.InternalNode; + +import java.util.Collection; +import java.util.Map; + +@Data +public class NodeDTO { + + private long id; + + private String elementId; + + private Map properties; + + private Collection labels; + + + public NodeDTO() { + } + + public NodeDTO(InternalNode internalNode) { + this.id = internalNode.id(); + this.elementId = internalNode.elementId(); + this.properties = internalNode.asMap(); + this.labels = internalNode.labels(); + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/PathDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/PathDTO.java new file mode 100644 index 0000000..57bb4e2 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/PathDTO.java @@ -0,0 +1,41 @@ +package com.supervision.pdfqaserver.dto.neo4j; + +import lombok.Data; +import org.neo4j.driver.internal.InternalNode; +import org.neo4j.driver.internal.InternalRelationship; +import org.neo4j.driver.types.Node; +import org.neo4j.driver.types.Path; +import org.neo4j.driver.types.Relationship; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +@Data +public class PathDTO { + + private List nodes; + + private List relationships; + + public PathDTO() { + } + + public PathDTO(Path path) { + Iterator nodeIterator = path.nodes().iterator(); + List nodes = new ArrayList<>(); + while (nodeIterator.hasNext()){ + Node next = nodeIterator.next(); + nodes.add(new NodeDTO((InternalNode) next)); + } + this.nodes = nodes; + + + Iterator iterator = path.relationships().iterator(); + List relationships = new ArrayList<>(); + while (iterator.hasNext()){ + relationships.add(new RelationshipValueDTO((InternalRelationship) iterator.next())); + } + this.relationships = relationships; + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipValueDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipValueDTO.java new file mode 100644 index 0000000..5ab7b46 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipValueDTO.java @@ -0,0 +1,43 @@ +package com.supervision.pdfqaserver.dto.neo4j; + +import lombok.Data; +import org.neo4j.driver.internal.InternalRelationship; + +import java.util.Map; + +@Data +public class RelationshipValueDTO { + + + private long start; + + private String startElementId; + + private long end; + + private String endElementId; + + private String type; + + private long id; + + private String elementId; + + private Map properties; + + + public RelationshipValueDTO() { + } + + public RelationshipValueDTO(InternalRelationship relationship) { + this.start = (int) relationship.startNodeId(); + this.startElementId = relationship.startNodeElementId(); + this.end = relationship.endNodeId(); + this.endElementId = relationship.endNodeElementId(); + this.type = relationship.type(); + this.id = relationship.id(); + this.elementId = relationship.elementId(); + this.properties = relationship.asMap(); + + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java b/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java index 6dccb0d..b82861f 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java +++ b/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java @@ -1,5 +1,9 @@ package com.supervision.pdfqaserver.service; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import reactor.core.publisher.Flux; + /** * @description: AI调用服务 */ @@ -7,4 +11,6 @@ public interface AiCallService { String call(String prompt); + + Flux stream(Prompt prompt); } diff --git a/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java b/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java index b07725f..efb1586 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java +++ b/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java @@ -1,6 +1,10 @@ package com.supervision.pdfqaserver.service; +import com.supervision.pdfqaserver.dto.CypherSchemaDTO; +import com.supervision.pdfqaserver.dto.DomainMetadataDTO; import com.supervision.pdfqaserver.dto.EREDTO; +import java.util.List; +import java.util.Map; /** * 三元组转换为Cypher语句的执行器 @@ -15,19 +19,37 @@ public interface TripleToCypherExecutor { String generateInsertCypher(EREDTO eredto); + /** * 生成查询Cypher语句 - * @param query + * @param query 用户查询语句 + * @param domainCategoryId 领域分类ID * @return */ - String generateQueryCypher(String query); + String generateQueryCypher(String query,String domainCategoryId); /** * 执行Cypher语句 * @param cypher * @return */ - void executeCypher(String cypher); + List> executeCypher(String cypher); void saveERE(EREDTO eredto); + + + /** + * 加载图谱的schema + */ + CypherSchemaDTO loadGraphSchema(); + + + /** + * 根据领域元数据查询关联的关系图谱的schema + * @param metadataDTOS + * @return + */ + CypherSchemaDTO queryRelationSchema(List metadataDTOS); + + } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java index 9540f6e..24202eb 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java @@ -39,120 +39,43 @@ public class ChatServiceImpl implements ChatService { private static final String PROMPT_PARAM_QUERY = "query"; private static final String CYPHER_QUERIES = "cypherQueries"; - private final Neo4jRepository neo4jRepository; private final OllamaChatModel ollamaChatModel; - private final DomainMetadataService domainMetadataService; private final AiCallService aiCallService; private final DocumentTruncationService documentTruncationService; + private final TripleToCypherExecutor tripleToCypherExecutor; - private final IntentionService intentionService; @Override public Flux knowledgeQA(String userQuery) { - List intentions = intentionService.listAllPassed(); - List relations = classifyIntents(userQuery, intentions); - if (CollUtil.isEmpty(relations)){ - log.info("没有匹配到意图,返回查无结果"); + log.info("用户查询: {}", userQuery); + // 生成cypher语句 + String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery,null); + log.info("生成CYPHER语句的消息:{}", cypher); + if (StrUtil.isEmpty(cypher)){ return Flux.just("查无结果").concatWith(Flux.just("[END]")); } - List domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList()); - if (CollUtil.isEmpty(domainMetadataDTOS)){ - log.info("没有匹配到领域元数据,返回查无结果"); + // 执行cypher语句 + List> graphResult = tripleToCypherExecutor.executeCypher(cypher); + if (CollUtil.isEmpty(graphResult)){ return Flux.just("查无结果").concatWith(Flux.just("[END]")); } - //将三个集合分别转换为英文逗号分隔的字符串 - String sourceTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getSourceType).distinct().collect(Collectors.joining(",")); - String relationListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getRelation).distinct().collect(Collectors.joining(",")); - String targetTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getTargetType).distinct().collect(Collectors.joining(",")); - - //LLM生成CYPHER - SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER)); - Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, - PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, - targetTypeListEn, PROMPT_PARAM_QUERY, userQuery)); - log.info("生成CYPHER语句的消息:{}", textToCypherMessage); - String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(0.3).build())).getResult().getOutput().getText(); - - log.info(cypherJsonStr); - log.info(cypherJsonStr.replaceAll("(?is)]*>(.*?)", "").trim()); - cypherJsonStr = cypherJsonStr.replaceAll("(?is)]*>(.*?)", "").trim(); - List cypherQueries; - try { - JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr); - cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES) - .toList(String.class); - } catch (Exception e) { - log.error("解析CYPHER JSON字符串失败: {}", e.getMessage()); - return Flux.just("查无结果").concatWith(Flux.just("[END]")); - } - log.info("转换后的Cypher语句:{}", cypherQueries.toString()); - - //执行CYPHER查询并汇总结果 - List relationObjects = new ArrayList<>(); - if (!cypherQueries.isEmpty()) { - for (String cypher : cypherQueries) { - relationObjects.addAll(neo4jRepository.execute(cypher, null)); - } - } - if (relationObjects.isEmpty()) { - log.info("cypher没有查询到结果,返回查无结果"); - return Flux.just("查无结果").concatWith(Flux.just("[END]")); - } - log.info("三元组数据: {}", relationObjects); - //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); - Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery)); + Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery)); log.info("生成回答的提示词:{}", generateAnswerMessage); - return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()) - .concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(relationObjects)).toString())) + return aiCallService.stream(new Prompt(generateAnswerMessage)) + .map(response -> response.getResult().getOutput().getText()) + .concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(null)).toString())) .concatWith(Flux.just("[END]")); } - /** - * - * 分类查询意图 - * @param query 问题 - * @param intentions 意图列表 - * @return - */ - private List classifyIntents(String query, List intentions) { - if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) { - return new ArrayList<>(); - } - String prompt = promptMap.get(CLASSIFY_QUERY_INTENT); - List result = new ArrayList<>(); - log.info("开始分类意图,query: {}, intentions size: {}", query, intentions.size()); - List> intentionSplit = CollUtil.split(intentions, 150); - for (List intentionList : intentionSplit) { - log.info("分类意图,query: {}, intentions size: {}", query, intentionList.size()); - String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining()); - Map params = Map.of("query", query, - "intents", intents); - String format = StrUtil.format(prompt, params); - String call = aiCallService.call(format); - - if (StrUtil.isEmpty(call)) { - return new ArrayList<>(); - } - List digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList(); - if (CollUtil.isEmpty(digests)) { - continue; - } - List collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList()); - if (CollUtil.isNotEmpty(collect)) { - result.addAll(collect); - } - } - return result; - } private List convertToAnswerDetails(List relationObjects) { if (CollUtil.isEmpty(relationObjects)) { diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java index 7a664b1..370f35e 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java @@ -209,6 +209,7 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService { // 保存意图数据 intentSize ++; index ++; + List intentions = intentionService.batchSaveIfAbsent(intents, pdfInfo.getDomainCategoryId(), pdfId.toString()); for (Intention intention : intentions) { List metadataDTOS = domainMetadataDTOS.stream() diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java index 72da5c4..56985b0 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java @@ -5,46 +5,110 @@ import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.dao.Neo4jRepository; -import com.supervision.pdfqaserver.dto.TruncationERAttributeDTO; -import com.supervision.pdfqaserver.dto.EREDTO; -import com.supervision.pdfqaserver.dto.EntityExtractionDTO; -import com.supervision.pdfqaserver.dto.RelationExtractionDTO; +import com.supervision.pdfqaserver.domain.Intention; +import com.supervision.pdfqaserver.dto.*; +import com.supervision.pdfqaserver.dto.neo4j.NodeDTO; +import com.supervision.pdfqaserver.dto.neo4j.PathDTO; +import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO; +import com.supervision.pdfqaserver.service.AiCallService; +import com.supervision.pdfqaserver.service.DomainMetadataService; +import com.supervision.pdfqaserver.service.IntentionService; import com.supervision.pdfqaserver.service.TripleToCypherExecutor; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.ollama.OllamaChatModel; +import org.neo4j.driver.Record; +import org.neo4j.driver.internal.InternalNode; +import org.neo4j.driver.internal.InternalRelationship; import org.springframework.stereotype.Service; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; - -import static com.supervision.pdfqaserver.cache.PromptCache.ERE_TO_INSERT_CYPHER; +import static com.supervision.pdfqaserver.cache.PromptCache.*; @Slf4j @Service @RequiredArgsConstructor public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { - private final OllamaChatModel ollamaChatModel; - private final Neo4jRepository neo4jRepository; + + private final IntentionService intentionService; + private static volatile CypherSchemaDTO cypherSchemaDTO; + + private final AiCallService aiCallService; + + private final DomainMetadataService domainMetadataService; + @Override public String generateInsertCypher(EREDTO eredto) { String prompt = PromptCache.promptMap.get(ERE_TO_INSERT_CYPHER); - String call = ollamaChatModel.call(prompt); - return call; + return aiCallService.call(prompt); } @Override - public String generateQueryCypher(String query) { - return null; + public String generateQueryCypher(String query,String domainCategoryId) { + List intentions = intentionService.listAllPassed(); + List relations = classifyIntents(query, intentions); + if (CollUtil.isEmpty(relations)) { + log.info("没有找到匹配的意图,query: {}", query); + return null; + } + List domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList()); + CypherSchemaDTO schemaDTO = this.queryRelationSchema(domainMetadataDTOS); + String prompt = promptMap.get(TEXT_TO_CYPHER_2); + String format = StrUtil.format(prompt, Map.of("question", query, "schema", schemaDTO.format())); + return aiCallService.call(format); } @Override - public void executeCypher(String cypher) { + public List> executeCypher(String cypher) { + List records = neo4jRepository.executeCypherNative(cypher, null); + return mapRecords(records); + } + + + private List> mapRecords(List records) { + List> recordList = new ArrayList<>(); + for (Record record : records) { + HashMap map = new HashMap<>(); + for (String key : record.keys()) { + org.neo4j.driver.Value value = record.get(key); + String typeName = value.type().name(); + if (typeName.equals("NULL")){ + map.put(key,null); + } + + if (StrUtil.equalsAny(typeName, "BOOLEAN","STRING", "NUMBER", "INTEGER", "FLOAT")){ + // MATCH (n)-[r]-() where n.caseId= '1' RETURN n.recordId limit 10 + map.put(key,value.asObject()); + } + if (typeName.equals("PATH")){ + // MATCH p=(n)-[*2]-() where n.caseId= '1' RETURN p limit 10 + map.put(key,new PathDTO(value.asPath())); + } + + if (typeName.equals("RELATIONSHIP")){ + // MATCH (n)-[r]-() where n.caseId= '1' RETURN r limit 10 + map.put(key,new RelationshipValueDTO((InternalRelationship) value.asRelationship())); + } + if (typeName.equals("LIST OF ANY?")){ + + List list = value.asList().stream() + .map(i -> new RelationshipValueDTO((InternalRelationship) i)).toList(); + map.put(key,list); + } + if (typeName.equals("NODE")){ + // MATCH (n)-[r]-() where n.caseId= '1' RETURN r limit 10 + map.put(key,new NodeDTO((InternalNode) value.asNode())); + } + recordList.add(map); + } + } + return recordList; } @Override @@ -101,4 +165,101 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { } } } + + @Override + public CypherSchemaDTO loadGraphSchema() { + + List relationSchema = neo4jRepository.getRelationSchema(); + List entitySchema = neo4jRepository.getNodeSchema(); + return new CypherSchemaDTO(entitySchema, relationSchema); + } + + @Override + public CypherSchemaDTO queryRelationSchema(List metadataDTOS) { + if (CollUtil.isEmpty(metadataDTOS)){ + return null; + } + if (cypherSchemaDTO == null) { + synchronized (TripleToCypherExecutorImpl.class) { + if (cypherSchemaDTO == null) { + cypherSchemaDTO = this.loadGraphSchema(); + } + } + } + List merged = new ArrayList<>(); + for (DomainMetadataDTO metadataDTO : metadataDTOS) { + String relation = metadataDTO.getRelation(); + String sourceType = metadataDTO.getSourceType(); + String targetType = metadataDTO.getTargetType(); + if (StrUtil.isEmpty(relation) || StrUtil.isEmpty(sourceType) || StrUtil.isEmpty(targetType)){ + log.warn("元数据中关系、源类型或目标类型为空,无法查询关系schema: {}", metadataDTO); + continue; + } + + RelationExtractionDTO rel = cypherSchemaDTO.getRelation(sourceType, relation, targetType); + if (null == rel){ + continue; + } + List relSourceType = cypherSchemaDTO.getRelations(sourceType); + List relTargetType = cypherSchemaDTO.getRelations(targetType); + + relSourceType.add(rel); + relSourceType.addAll(relTargetType); + for (RelationExtractionDTO relationExtractionDTO : relSourceType) { + boolean none = merged.stream().noneMatch(i -> StrUtil.equals(i.getRelation(), relationExtractionDTO.getRelation()) && + StrUtil.equals(i.getSourceType(), relationExtractionDTO.getSourceType()) && + StrUtil.equals(i.getTargetType(), relationExtractionDTO.getTargetType())); + if (none){ + merged.add(relationExtractionDTO); + } + } + } + List entityExtractionDTOS = new ArrayList<>(); + for (RelationExtractionDTO relationExtractionDTO : merged) { + EntityExtractionDTO node = cypherSchemaDTO.getNode(relationExtractionDTO.getSourceType()); + if (null != node){ + boolean none = entityExtractionDTOS.stream().noneMatch( + entity -> StrUtil.equals(entity.getEntity(), node.getEntity()) + ); + if (none) { + entityExtractionDTOS.add(node); + } + } + } + return new CypherSchemaDTO( + entityExtractionDTOS, + merged + ); + } + + private List classifyIntents(String query, List intentions) { + if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) { + return new ArrayList<>(); + } + String prompt = promptMap.get(CLASSIFY_QUERY_INTENT); + List result = new ArrayList<>(); + log.info("开始分类意图,query: {}, intentions size: {}", query, intentions.size()); + List> intentionSplit = CollUtil.split(intentions, 200); + for (List intentionList : intentionSplit) { + log.info("分类意图,query: {}, intentions size: {}", query, intentionList.size()); + String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining()); + Map params = Map.of("query", query, + "intents", intents); + String format = StrUtil.format(prompt, params); + String call = aiCallService.call(format); + + if (StrUtil.isEmpty(call)) { + return new ArrayList<>(); + } + List digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList(); + if (CollUtil.isEmpty(digests)) { + continue; + } + List collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList()); + if (CollUtil.isNotEmpty(collect)) { + result.addAll(collect); + } + } + return result; + } }