From e181c00c402d2e03e3d5a06ef41d0d3d9f25b755 Mon Sep 17 00:00:00 2001
From: xueqingkun <xueqingkun@126.com>
Date: Thu, 19 Jun 2025 15:45:29 +0800
Subject: [PATCH] =?UTF-8?q?=E9=97=AE=E7=AD=94=E5=8A=9F=E8=83=BD=E4=BC=98?=
 =?UTF-8?q?=E5=8C=96-=E5=88=9D=E5=A7=8B=E5=8C=96=E8=A1=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../pdfqaserver/cache/PromptCache.java        | 48 ++++++++-----------
 .../pdfqaserver/dto/CypherSchemaDTO.java      | 22 +++++++--
 .../supervision/pdfqaserver/dto/TextTerm.java |  7 ++-
 .../service/impl/ChatServiceImpl.java         | 40 ++++++++++++++--
 .../service/impl/DataCompareRetriever.java    | 32 ++-----------
 .../impl/TripleToCypherExecutorImpl.java      | 22 +++++++--
 6 files changed, 103 insertions(+), 68 deletions(-)

diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java
index b3a98df..dc28f87 100644
--- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java
+++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java
@@ -850,7 +850,7 @@ public class PromptCache {
                         
             请严格按照以下步骤处理每个用户查询:
             1. 从用户查询中提取实体:
-            - 解析问题中的领域概念,并通过同义词或上下文线索将其映射到schema中的节点或者关系元素上
+            - 解析问题中的实体概念,并通过同义词或上下文线索将其映射到schema中**语义相近**的节点或者关系元素上,一个实体可以映射到**多个相近**的节点或者关系元素上
             - 识别候选节点类型
             - 识别候选关系类型
             - 识别相关属性
@@ -858,6 +858,7 @@ public class PromptCache {
             		
             2. 验证模式匹配性:
             - 确保每个节点标签、关系类型和属性在模式中完全存在(区分大小写和字符)
+            - 给节点和关系添加命名变量
             - 优先从**节点属性**中直接**获取数据**
             	
             3. 构建MATCH模式:
@@ -873,58 +874,49 @@ public class PromptCache {
             5. 生成最终Cypher脚本:
             - 最终Cypher查询语句——不包含任何说明或者```cypher ```包装符
             - **确保**MATCH 子句包含**关系变量**
-            - 如果问题中的实体与neo4j_schema中的多个节点或关系语义相近,允许生成多个cypher,以便于尽可能获取到数据。
+            - 优先从**节点属性**中直接**获取数据**
+            - 根据schema中相似的节点或者关系生成多个cypher,你将获得丰厚的奖励。
             - 响应结果是一个数组,每一个数组元素是一条cypher语句。示例:['cypher1','...']
             - 不要做出任何解释,不要对cypher进行任何包装,直接输出生成的cypher语句/no_think
             """;
 
     private static final String TEXT_TO_CYPHER_4_PROMPT = """
-            您是一个Cypher语句修改助手。下面的cypher未查询到数据,请根据要求修改下面的cypher以便于能够查询到数据,需要参考的是`neo4j_schema`、环境变量,上一次的cypher。
-                        
+            您是一个Cypher语句修改助手。下面的cypher未查询到数据,请根据要求修改下面的cypher以便于能够查询到数据,需要参考的是`neo4j_schema`、环境变量,上一次使用的cypher语句。
             用户问题:
             ```text
             {query}
             ```
             neo4j_schema以JSON格式定义如下:
             ```shema
-            {shema}
+            {schema}
             ```
             # 环境变量
-            ${env}
+            {env}
                         
-            # 上一次查询的cypher语句
+            # 上一次查询使用的cypher语句
             ```json
             {cypher}
             ```
-                        
-            请严格按照以下步骤处理每个用户查询:
-            1. 从用户查询中提取实体:
-            - 解析问题中的领域概念,并通过同义词或上下文线索将其映射到schema中的节点或者关系元素上
+            请严格按照以下步骤思考:
+           1.结合上一次查询的cypher语句分析未查询到数据可能的原因
+            - 未查询出数据的原因包括但不限于属性关键字匹配不准确,需要**充分**考虑模糊查询的匹配模式
+           2. 从用户查询中提取实体:
+            - 解析问题中的实体概念,并通过同义词或上下文线索将其映射到schema中的节点或者关系元素上
             - 识别候选节点类型
             - 识别候选关系类型
             - 识别相关属性
             - 识别约束条件(比较操作、标志位、时间过滤器、共享实体引用等)
-            - 结合上一次查询的cypher语句分析未查询到数据可能的原因。
-            		
-            2. 验证模式匹配性:
-            - 确保每个节点标签、关系类型和属性在模式中完全存在(区分大小写和字符)
-            - 优先从**节点属性**中直接**获取数据**
-            
-            3. 构建MATCH模式:
-            - 仅使用经过模式验证的节点标签和关系类型,同时对节点、关系**添加变量名**
-            - **始终为关系分配显式变量**(例如`-[r:REL_TYPE]->`)
-            - 当查询暗示两个模式指向同一节点时,重复使用同一变量
-            - 在映射模式中表达简单等值谓词,其他过滤条件移至WHERE子句
-            - 结合分析未查询到数据的原因修改cypher
-            		
-            4. RETURN子句策略:
+           
+           2. 构建MATCH模式:
+            - 把MATCH子句中过滤条件修改为通过WHERE过滤的方式
+           
+           3. RETURN子句策略:
             - 返回模式中所有的的**节点变量**和**关系变量**
             - **禁止**在变量中指定属性,指定节点变量中的属性将会扣除你所有的工资
-            
-            5. 生成最终Cypher脚本:
+           
+           4. 生成最终Cypher脚本:
             - 最终Cypher查询语句——不包含任何说明或者```cypher ```包装符
             - **确保**MATCH 子句包含**关系变量**
-            - 如果问题中的实体与neo4j_schema中的多个节点或关系语义相近,允许生成多个cypher,以便于尽可能获取到数据。
             - 响应结果是一个数组,每一个数组元素是一条cypher语句。示例:['cypher1','...']
             - 不要做出任何解释,不要对cypher进行任何包装,直接输出生成的cypher语句/no_think
             """;
diff --git a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java
index 77c8c19..3dae615 100644
--- a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java
+++ b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java
@@ -3,9 +3,9 @@ package com.supervision.pdfqaserver.dto;
 import cn.hutool.core.util.StrUtil;
 import cn.hutool.json.JSONArray;
 import cn.hutool.json.JSONObject;
-
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 
 /**
  * CypherSchemaDTO
@@ -118,7 +118,24 @@ public class CypherSchemaDTO {
                 json = new JSONObject();
                 relJson.set(rela, json);
             }
-            json.set("_endpoints", new JSONArray(new String[]{sourceType, targetType}));
+            JSONArray endpoints = json.getJSONArray("_endpoints");
+            if (null == endpoints){
+                endpoints = new JSONArray();
+                endpoints.add(Map.of("sourceType", sourceType, "targetType", targetType));
+                json.set("_endpoints", endpoints);
+            }else {
+                boolean absent = false;
+                for (Object endpoint : endpoints) {
+                    Map<String,Object> nodes = (Map<String, Object>) endpoint;
+                    if (sourceType.equals(nodes.get("sourceType"))|| sourceType.equals(nodes.get("targetType"))){
+                        absent = true;
+                        break;
+                    }
+                }
+                if (absent){
+                    endpoints.add(Map.of("sourceType", sourceType, "targetType", targetType));
+                }
+            }
             for (TruncationERAttributeDTO attribute : relation.getAttributes()) {
                 if ("truncationId".equals(attribute.getAttribute())){
                     continue;
@@ -127,7 +144,6 @@ public class CypherSchemaDTO {
                         entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute())
                 );
                 if (none) {
-
                     json.set(attribute.getAttribute(), attribute.getDataType());
                 }
             }
diff --git a/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java b/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java
index 1edd9b8..e84671c 100644
--- a/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java
+++ b/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java
@@ -1,7 +1,9 @@
 package com.supervision.pdfqaserver.dto;
 
+import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.util.StrUtil;
 import lombok.Data;
+import java.util.List;
 
 @Data
 public class TextTerm {
@@ -18,7 +20,10 @@ public class TextTerm {
 
     private float[] embedding;
 
-    public String getLabelValue() {
+    public String getLabelValue(List<String> keyWords) {
+        if (CollUtil.isNotEmpty(keyWords) && keyWords.contains(word)) {
+            return word;
+        }
         if (StrUtil.equalsAny(label,"n","nl","nr","ns","nsf","nz")){
             return word;
         }
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java
index b08e0aa..5b98413 100644
--- a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java
@@ -1,5 +1,6 @@
 package com.supervision.pdfqaserver.service.impl;
 
+import cn.hutool.core.bean.BeanUtil;
 import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.hutool.json.JSONObject;
@@ -18,6 +19,7 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate;
 import org.springframework.stereotype.Service;
 import reactor.core.publisher.Flux;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -58,14 +60,14 @@ public class ChatServiceImpl implements ChatService {
 
         // 执行cypher语句
         List<Map<String, Object>> graphResult = tripleToCypherExecutor.executeCypher(cypher);
+        */
+        List<Map<String, Object>> graphResult = compareRetriever.retrieval(userQuery);
         if (CollUtil.isEmpty(graphResult)){
             return Flux.just("查无结果").concatWith(Flux.just("[END]"));
-        }*/
-        List<Map<String, Object>> graphResult = compareRetriever.retrieval(userQuery);
-
+        }
         //生成回答
         SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
-        Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery));
+        Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(clearGraphElements(graphResult)), PROMPT_PARAM_QUERY, userQuery));
         log.info("生成回答的提示词:{}", generateAnswerMessage);
         return aiCallService.stream(new Prompt(generateAnswerMessage))
                 .map(response -> response.getResult().getOutput().getText())
@@ -167,4 +169,34 @@ public class ChatServiceImpl implements ChatService {
         }
         return distinct;
     }
+
+    /**
+     * 清理图谱元素中的无效数据
+     * @param graphElements 图谱元素列表
+     * @return
+     */
+    public List<Map<String, Object>> clearGraphElements(List<Map<String, Object>> graphElements) {
+        if (CollUtil.isEmpty(graphElements)) {
+            return graphElements;
+        }
+        List<Map<String, Object>> result = new ArrayList<>(graphElements.size());
+        for (Map<String, Object> originalMap : graphElements) {
+            Map<String, Object> newMap = new HashMap<>();
+            for (Map.Entry<String, Object> entry : originalMap.entrySet()) {
+                String key = entry.getKey();
+                Object value = entry.getValue();
+                if (value instanceof NodeDTO nodeDTO){
+                    NodeDTO newNodeDTO = BeanUtil.copyProperties(nodeDTO, NodeDTO.class);
+                    newNodeDTO.clearGraphElement(); // 清理图谱元素
+                    newMap.put(key, newNodeDTO);
+                } else if (value instanceof RelationshipValueDTO relationshipValueDTO) {
+                    RelationshipValueDTO newRelationshipValueDTO = BeanUtil.copyProperties(relationshipValueDTO, RelationshipValueDTO.class);
+                    newRelationshipValueDTO.clearGraphElement(); // 清理图谱元素
+                    newMap.put(key, newRelationshipValueDTO);
+                }
+            }
+            result.add(newMap);
+        }
+        return result;
+    }
 }
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java b/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java
index cdfa05c..099826b 100644
--- a/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java
@@ -62,7 +62,7 @@ public class DataCompareRetriever implements Retriever {
             boolean allEmpty = cypherData.values().stream().noneMatch(CollUtil::isNotEmpty);
             if (!allEmpty){
                 cypherData.values().stream().filter(CollUtil::isNotEmpty).forEach(result::addAll);
-                return clearGraphElements(result);
+                return result;
             }
         }
         if (CollUtil.isEmpty(result)){
@@ -70,8 +70,8 @@ public class DataCompareRetriever implements Retriever {
             prompt = PromptCache.promptMap.get(TEXT_TO_CYPHER_4);
             format = StrUtil.format(prompt,
                     Map.of("query", query, "schema", schemaDTO.format(),
-                            "env", "- 当前时间是:" + DateUtil.now()),"cypher",js.toString());
-            log.info("retrieval: 生成的cypher语句:{}", format);
+                            "env", "- 当前时间是:" + DateUtil.now(),"cypher",js.toString()));
+            log.info("retrieval: 生成cypher的语句:{}", format);
             call = aiCallService.call(format);
             log.info("retrieval: AI调用返回结果:{}", call);
 
@@ -81,33 +81,11 @@ public class DataCompareRetriever implements Retriever {
                 boolean allEmpty2 = cypherData.values().stream().noneMatch(CollUtil::isNotEmpty);
                 if (!allEmpty2){
                     cypherData.values().stream().filter(CollUtil::isNotEmpty).forEach(result::addAll);
-                    return clearGraphElements(result);
+                    return result;
                 }
             }
         }
 
-        return clearGraphElements(result);
-    }
-
-    /**
-     * 清理图谱元素中的无效数据
-     * @param graphElements 图谱元素列表
-     * @return
-     */
-    private List<Map<String, Object>> clearGraphElements(List<Map<String, Object>> graphElements) {
-        if (CollUtil.isEmpty(graphElements)){
-            return graphElements;
-        }
-        for (Map<String, Object> element : graphElements) {
-            for (Object value : element.values()) {
-                if (value instanceof RelationshipValueDTO relationshipValueDTO) {
-                    relationshipValueDTO.clearGraphElement();
-                }
-                if (value instanceof com.supervision.pdfqaserver.dto.neo4j.NodeDTO nodeDTO) {
-                    nodeDTO.clearGraphElement();
-                }
-            }
-        }
-        return graphElements;
+        return result;
     }
 }
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java
index 3d54bf6..dd7b886 100644
--- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java
@@ -26,6 +26,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import static com.supervision.pdfqaserver.cache.PromptCache.*;
 
 @Slf4j
@@ -272,15 +273,16 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
         log.info("queryRelationSchema: 分词结果:{}", terms);
         log.info("queryRelationSchema: 开始进行文本标签向量匹配...");
         List<NodeRelationVector> matchedText = new ArrayList<>();
+        List<String> keywords = mergeNodeAndRelationLabel();
         for (TextTerm term : terms) {
-            if (StrUtil.isEmpty(term.getLabelValue())){
+            if (StrUtil.isEmpty(term.getLabelValue(keywords))){
                 log.info("queryRelationSchema: 分词结果`{}`不是关键标签,跳过...", term.getWord());
                 continue;
             }
-            Embedding embedding = aiCallService.embedding(term.getLabelValue());
+            Embedding embedding = aiCallService.embedding(term.getLabelValue(keywords));
             term.setEmbedding(embedding.getOutput());
             List<NodeRelationVector> textVectorDTOS = nodeRelationVectorService.matchSimilarByCosine(embedding.getOutput(), 0.9, List.of("N","R"),3); // 继续过滤
-            log.info("retrieval: 文本:{}匹配到的文本向量:{}", term.getWord() ,textVectorDTOS.stream().map(NodeRelationVector::getContent).collect(Collectors.joining(" ")));
+            log.info("retrieval: 文本:`{}`匹配到的文本向量:`{}`", term.getWord() ,textVectorDTOS.stream().map(NodeRelationVector::getContent).collect(Collectors.joining(" ")));
             matchedText.addAll(textVectorDTOS);
         }
         if (CollUtil.isEmpty(matchedText)){
@@ -306,7 +308,7 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
         // 对查询到的关系进行重排序
         List<Pair<Double, RelationExtractionDTO>> pairs = new ArrayList<>();
         TimeInterval timeInterval = new TimeInterval();
-        String join = terms.stream().map(TextTerm::getLabelValue).filter(StrUtil::isNotEmpty).collect(Collectors.joining());
+        String join = terms.stream().map(t->t.getLabelValue(keywords)).filter(StrUtil::isNotEmpty).collect(Collectors.joining());
         Embedding embedding = aiCallService.embedding(join);
         for (RelationExtractionDTO relation : merged) {
             String content = relation.getSourceType() + " " + relation.getRelation() + " " + relation.getTargetType();
@@ -319,7 +321,7 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
         }
         log.info("queryRelationSchema: 关系排序耗时:{}ms", timeInterval.intervalMs());
 
-        merged = pairs.stream().sorted((p1, p2) -> Double.compare(p2.getKey(), p1.getKey())).limit(5).map(Pair::getValue).toList();
+        merged = pairs.stream().sorted((p1, p2) -> Double.compare(p2.getKey(), p1.getKey())).limit(4).map(Pair::getValue).toList();
         List<EntityExtractionDTO> entityExtractionDTOS = new ArrayList<>();
         for (RelationExtractionDTO relationExtractionDTO : merged) {
             EntityExtractionDTO sourceNode = cypherSchemaDTO.getNode(relationExtractionDTO.getSourceType());
@@ -392,4 +394,14 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
             }
         }
     }
+
+    private List<String> mergeNodeAndRelationLabel() {
+        loadCypherSchemaIfAbsent();
+        if (CollUtil.isEmpty(cypherSchemaDTO.getRelations())) {
+            log.warn("图谱schema数据为空,无法合并节点和关系标签");
+            return new ArrayList<>();
+        }
+        return  cypherSchemaDTO.getRelations().stream()
+                .flatMap(r -> Stream.of(r.getSourceType(), r.getRelation(), r.getTargetType())).distinct().collect(Collectors.toList());
+    }
 }