diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java index b3a98df..dc28f87 100644 --- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java +++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java @@ -850,7 +850,7 @@ public class PromptCache { 请严格按照以下步骤处理每个用户查询: 1. 从用户查询中提取实体: - - 解析问题中的领域概念,并通过同义词或上下文线索将其映射到schema中的节点或者关系元素上 + - 解析问题中的实体概念,并通过同义词或上下文线索将其映射到schema中**语义相近**的节点或者关系元素上,一个实体可以映射到**多个相近**的节点或者关系元素上 - 识别候选节点类型 - 识别候选关系类型 - 识别相关属性 @@ -858,6 +858,7 @@ public class PromptCache { 2. 验证模式匹配性: - 确保每个节点标签、关系类型和属性在模式中完全存在(区分大小写和字符) + - 给节点和关系添加命名变量 - 优先从**节点属性**中直接**获取数据** 3. 构建MATCH模式: @@ -873,58 +874,49 @@ public class PromptCache { 5. 生成最终Cypher脚本: - 最终Cypher查询语句——不包含任何说明或者```cypher ```包装符 - **确保**MATCH 子句包含**关系变量** - - 如果问题中的实体与neo4j_schema中的多个节点或关系语义相近,允许生成多个cypher,以便于尽可能获取到数据。 + - 优先从**节点属性**中直接**获取数据** + - 根据schema中相似的节点或者关系生成多个cypher,你将获得丰厚的奖励。 - 响应结果是一个数组,每一个数组元素是一条cypher语句。示例:['cypher1','...'] - 不要做出任何解释,不要对cypher进行任何包装,直接输出生成的cypher语句/no_think """; private static final String TEXT_TO_CYPHER_4_PROMPT = """ - 您是一个Cypher语句修改助手。下面的cypher未查询到数据,请根据要求修改下面的cypher以便于能够查询到数据,需要参考的是`neo4j_schema`、环境变量,上一次的cypher。 - + 您是一个Cypher语句修改助手。下面的cypher未查询到数据,请根据要求修改下面的cypher以便于能够查询到数据,需要参考的是`neo4j_schema`、环境变量,上一次使用的cypher语句。 用户问题: ```text {query} ``` neo4j_schema以JSON格式定义如下: ```shema - {shema} + {schema} ``` # 环境变量 - ${env} + {env} - # 上一次查询的cypher语句 + # 上一次查询使用的cypher语句 ```json {cypher} ``` - - 请严格按照以下步骤处理每个用户查询: - 1. 从用户查询中提取实体: - - 解析问题中的领域概念,并通过同义词或上下文线索将其映射到schema中的节点或者关系元素上 + 请严格按照以下步骤思考: + 1.结合上一次查询的cypher语句分析未查询到数据可能的原因 + - 未查询出数据的原因包括但不限于属性关键字匹配不准确,需要**充分**考虑模糊查询的匹配模式 + 2. 从用户查询中提取实体: + - 解析问题中的实体概念,并通过同义词或上下文线索将其映射到schema中的节点或者关系元素上 - 识别候选节点类型 - 识别候选关系类型 - 识别相关属性 - 识别约束条件(比较操作、标志位、时间过滤器、共享实体引用等) - - 结合上一次查询的cypher语句分析未查询到数据可能的原因。 - - 2. 验证模式匹配性: - - 确保每个节点标签、关系类型和属性在模式中完全存在(区分大小写和字符) - - 优先从**节点属性**中直接**获取数据** - - 3. 构建MATCH模式: - - 仅使用经过模式验证的节点标签和关系类型,同时对节点、关系**添加变量名** - - **始终为关系分配显式变量**(例如`-[r:REL_TYPE]->`) - - 当查询暗示两个模式指向同一节点时,重复使用同一变量 - - 在映射模式中表达简单等值谓词,其他过滤条件移至WHERE子句 - - 结合分析未查询到数据的原因修改cypher - - 4. RETURN子句策略: + + 2. 构建MATCH模式: + - 把MATCH子句中过滤条件修改为通过WHERE过滤的方式 + + 3. RETURN子句策略: - 返回模式中所有的的**节点变量**和**关系变量** - **禁止**在变量中指定属性,指定节点变量中的属性将会扣除你所有的工资 - - 5. 生成最终Cypher脚本: + + 4. 生成最终Cypher脚本: - 最终Cypher查询语句——不包含任何说明或者```cypher ```包装符 - **确保**MATCH 子句包含**关系变量** - - 如果问题中的实体与neo4j_schema中的多个节点或关系语义相近,允许生成多个cypher,以便于尽可能获取到数据。 - 响应结果是一个数组,每一个数组元素是一条cypher语句。示例:['cypher1','...'] - 不要做出任何解释,不要对cypher进行任何包装,直接输出生成的cypher语句/no_think """; diff --git a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java index 77c8c19..3dae615 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java @@ -3,9 +3,9 @@ package com.supervision.pdfqaserver.dto; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONArray; import cn.hutool.json.JSONObject; - import java.util.ArrayList; import java.util.List; +import java.util.Map; /** * CypherSchemaDTO @@ -118,7 +118,24 @@ public class CypherSchemaDTO { json = new JSONObject(); relJson.set(rela, json); } - json.set("_endpoints", new JSONArray(new String[]{sourceType, targetType})); + JSONArray endpoints = json.getJSONArray("_endpoints"); + if (null == endpoints){ + endpoints = new JSONArray(); + endpoints.add(Map.of("sourceType", sourceType, "targetType", targetType)); + json.set("_endpoints", endpoints); + }else { + boolean absent = false; + for (Object endpoint : endpoints) { + Map nodes = (Map) endpoint; + if (sourceType.equals(nodes.get("sourceType"))|| sourceType.equals(nodes.get("targetType"))){ + absent = true; + break; + } + } + if (absent){ + endpoints.add(Map.of("sourceType", sourceType, "targetType", targetType)); + } + } for (TruncationERAttributeDTO attribute : relation.getAttributes()) { if ("truncationId".equals(attribute.getAttribute())){ continue; @@ -127,7 +144,6 @@ public class CypherSchemaDTO { entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute()) ); if (none) { - json.set(attribute.getAttribute(), attribute.getDataType()); } } diff --git a/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java b/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java index 1edd9b8..e84671c 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java @@ -1,7 +1,9 @@ package com.supervision.pdfqaserver.dto; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import lombok.Data; +import java.util.List; @Data public class TextTerm { @@ -18,7 +20,10 @@ public class TextTerm { private float[] embedding; - public String getLabelValue() { + public String getLabelValue(List keyWords) { + if (CollUtil.isNotEmpty(keyWords) && keyWords.contains(word)) { + return word; + } if (StrUtil.equalsAny(label,"n","nl","nr","ns","nsf","nz")){ return word; } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java index b08e0aa..5b98413 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java @@ -1,5 +1,6 @@ package com.supervision.pdfqaserver.service.impl; +import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONObject; @@ -18,6 +19,7 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -58,14 +60,14 @@ public class ChatServiceImpl implements ChatService { // 执行cypher语句 List> graphResult = tripleToCypherExecutor.executeCypher(cypher); + */ + List> graphResult = compareRetriever.retrieval(userQuery); if (CollUtil.isEmpty(graphResult)){ return Flux.just("查无结果").concatWith(Flux.just("[END]")); - }*/ - List> graphResult = compareRetriever.retrieval(userQuery); - + } //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); - Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery)); + Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(clearGraphElements(graphResult)), PROMPT_PARAM_QUERY, userQuery)); log.info("生成回答的提示词:{}", generateAnswerMessage); return aiCallService.stream(new Prompt(generateAnswerMessage)) .map(response -> response.getResult().getOutput().getText()) @@ -167,4 +169,34 @@ public class ChatServiceImpl implements ChatService { } return distinct; } + + /** + * 清理图谱元素中的无效数据 + * @param graphElements 图谱元素列表 + * @return + */ + public List> clearGraphElements(List> graphElements) { + if (CollUtil.isEmpty(graphElements)) { + return graphElements; + } + List> result = new ArrayList<>(graphElements.size()); + for (Map originalMap : graphElements) { + Map newMap = new HashMap<>(); + for (Map.Entry entry : originalMap.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + if (value instanceof NodeDTO nodeDTO){ + NodeDTO newNodeDTO = BeanUtil.copyProperties(nodeDTO, NodeDTO.class); + newNodeDTO.clearGraphElement(); // 清理图谱元素 + newMap.put(key, newNodeDTO); + } else if (value instanceof RelationshipValueDTO relationshipValueDTO) { + RelationshipValueDTO newRelationshipValueDTO = BeanUtil.copyProperties(relationshipValueDTO, RelationshipValueDTO.class); + newRelationshipValueDTO.clearGraphElement(); // 清理图谱元素 + newMap.put(key, newRelationshipValueDTO); + } + } + result.add(newMap); + } + return result; + } } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java b/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java index cdfa05c..099826b 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java @@ -62,7 +62,7 @@ public class DataCompareRetriever implements Retriever { boolean allEmpty = cypherData.values().stream().noneMatch(CollUtil::isNotEmpty); if (!allEmpty){ cypherData.values().stream().filter(CollUtil::isNotEmpty).forEach(result::addAll); - return clearGraphElements(result); + return result; } } if (CollUtil.isEmpty(result)){ @@ -70,8 +70,8 @@ public class DataCompareRetriever implements Retriever { prompt = PromptCache.promptMap.get(TEXT_TO_CYPHER_4); format = StrUtil.format(prompt, Map.of("query", query, "schema", schemaDTO.format(), - "env", "- 当前时间是:" + DateUtil.now()),"cypher",js.toString()); - log.info("retrieval: 生成的cypher语句:{}", format); + "env", "- 当前时间是:" + DateUtil.now(),"cypher",js.toString())); + log.info("retrieval: 生成cypher的语句:{}", format); call = aiCallService.call(format); log.info("retrieval: AI调用返回结果:{}", call); @@ -81,33 +81,11 @@ public class DataCompareRetriever implements Retriever { boolean allEmpty2 = cypherData.values().stream().noneMatch(CollUtil::isNotEmpty); if (!allEmpty2){ cypherData.values().stream().filter(CollUtil::isNotEmpty).forEach(result::addAll); - return clearGraphElements(result); + return result; } } } - return clearGraphElements(result); - } - - /** - * 清理图谱元素中的无效数据 - * @param graphElements 图谱元素列表 - * @return - */ - private List> clearGraphElements(List> graphElements) { - if (CollUtil.isEmpty(graphElements)){ - return graphElements; - } - for (Map element : graphElements) { - for (Object value : element.values()) { - if (value instanceof RelationshipValueDTO relationshipValueDTO) { - relationshipValueDTO.clearGraphElement(); - } - if (value instanceof com.supervision.pdfqaserver.dto.neo4j.NodeDTO nodeDTO) { - nodeDTO.clearGraphElement(); - } - } - } - return graphElements; + return result; } } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java index 3d54bf6..dd7b886 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.Stream; import static com.supervision.pdfqaserver.cache.PromptCache.*; @Slf4j @@ -272,15 +273,16 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { log.info("queryRelationSchema: 分词结果:{}", terms); log.info("queryRelationSchema: 开始进行文本标签向量匹配..."); List matchedText = new ArrayList<>(); + List keywords = mergeNodeAndRelationLabel(); for (TextTerm term : terms) { - if (StrUtil.isEmpty(term.getLabelValue())){ + if (StrUtil.isEmpty(term.getLabelValue(keywords))){ log.info("queryRelationSchema: 分词结果`{}`不是关键标签,跳过...", term.getWord()); continue; } - Embedding embedding = aiCallService.embedding(term.getLabelValue()); + Embedding embedding = aiCallService.embedding(term.getLabelValue(keywords)); term.setEmbedding(embedding.getOutput()); List textVectorDTOS = nodeRelationVectorService.matchSimilarByCosine(embedding.getOutput(), 0.9, List.of("N","R"),3); // 继续过滤 - log.info("retrieval: 文本:{}匹配到的文本向量:{}", term.getWord() ,textVectorDTOS.stream().map(NodeRelationVector::getContent).collect(Collectors.joining(" "))); + log.info("retrieval: 文本:`{}`匹配到的文本向量:`{}`", term.getWord() ,textVectorDTOS.stream().map(NodeRelationVector::getContent).collect(Collectors.joining(" "))); matchedText.addAll(textVectorDTOS); } if (CollUtil.isEmpty(matchedText)){ @@ -306,7 +308,7 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { // 对查询到的关系进行重排序 List> pairs = new ArrayList<>(); TimeInterval timeInterval = new TimeInterval(); - String join = terms.stream().map(TextTerm::getLabelValue).filter(StrUtil::isNotEmpty).collect(Collectors.joining()); + String join = terms.stream().map(t->t.getLabelValue(keywords)).filter(StrUtil::isNotEmpty).collect(Collectors.joining()); Embedding embedding = aiCallService.embedding(join); for (RelationExtractionDTO relation : merged) { String content = relation.getSourceType() + " " + relation.getRelation() + " " + relation.getTargetType(); @@ -319,7 +321,7 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { } log.info("queryRelationSchema: 关系排序耗时:{}ms", timeInterval.intervalMs()); - merged = pairs.stream().sorted((p1, p2) -> Double.compare(p2.getKey(), p1.getKey())).limit(5).map(Pair::getValue).toList(); + merged = pairs.stream().sorted((p1, p2) -> Double.compare(p2.getKey(), p1.getKey())).limit(4).map(Pair::getValue).toList(); List entityExtractionDTOS = new ArrayList<>(); for (RelationExtractionDTO relationExtractionDTO : merged) { EntityExtractionDTO sourceNode = cypherSchemaDTO.getNode(relationExtractionDTO.getSourceType()); @@ -392,4 +394,14 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { } } } + + private List mergeNodeAndRelationLabel() { + loadCypherSchemaIfAbsent(); + if (CollUtil.isEmpty(cypherSchemaDTO.getRelations())) { + log.warn("图谱schema数据为空,无法合并节点和关系标签"); + return new ArrayList<>(); + } + return cypherSchemaDTO.getRelations().stream() + .flatMap(r -> Stream.of(r.getSourceType(), r.getRelation(), r.getTargetType())).distinct().collect(Collectors.toList()); + } }