From 56771401724eacca20f5e9ab94a4784df617a073 Mon Sep 17 00:00:00 2001 From: daixiaoyi Date: Wed, 7 May 2025 13:57:44 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E7=A4=BA=E8=AF=8D=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pdfqaserver/cache/PromptCache.java | 75 ++++++++++++++++-- .../config/OllamaChatModelAspect.java | 15 +--- .../pdfqaserver/dao/Neo4jRepository.java | 4 +- .../service/impl/ChatServiceImpl.java | 78 ++++++++++++------- 4 files changed, 125 insertions(+), 47 deletions(-) diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java index 5e2d485..cc843f3 100644 --- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java +++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java @@ -191,15 +191,76 @@ public class PromptCache { """; private static final String TEXT_TO_CYPHER_PROMPT = """ - 结合给你的领域元数据,分析用户输入的问题,尝试将其转换为CYPHER语句。 - 领域元数据:{domainMetadata} - 用户输入的问题:{userQuery} - """; + 你是一个专业的 Neo4j Cypher 查询语句生成器,专门用于构建针对特定结构的查询语句。 + + --- + + **【数据库结构说明】** + + - **关系类型(relationType)**: + {relationTypeList} + + - **源节点类型(sourceType)**: + {sourceTypeList} + + - **目标节点类型(targetType)**: + {targetTypeList} + + --- + + **【生成规则】** + + 1. 识别用户问题中的实体及意图,映射为 `Cypher 查询语句` + 2. 使用无条件匹配,不假设任何属性名称,不添加 `WHERE` 子句过滤。 + 3. 返回所有满足该关系的查询语句,并包含节点和关系的**所有属性**。 + 4. 仅输出 Cypher 代码块,不附加任何解释。 + 5. 如无法从结构中推断 relationType、sourceType 或 targetType,输出: + ``` + 无法根据数据库结构生成查询 + ``` + + --- + + **【示例】** + + 1. - **用户问题:** 查询“宝塔盛华商贸集团”的基本信息? + - **生成的 Cypher 查询:** + \\{ + "cypherQueries": [ + "MATCH (c:Company)-[r:HAS_LEGAL_REP]->(t) RETURN c, r, t", + "MATCH (c:Company)-[r:HAS_PHONE]->(t) RETURN c, r, t", + ..... + ] + \\} + + 2. - **用户问题:** 查询“宝塔盛华商贸集团”出具的电子银行承兑汇票金额是多少? + - **生成的 Cypher 查询:** + \\{ + "cypherQueries": [ + "MATCH (c:Company)-[r:IssueDocument]->(t:FinancialBill) RETURN c, r, t", + ..... + ] + \\} + + 【用户问题】 + {query} + + 【你的任务】 + 根据以上数据库结构和用户问题,生成正确的Cypher查询语句。 + """; private static final String GENERATE_ANSWER_PROMPT = """ - 结合给你的三元组数据和用户输入的问题,生成一个简洁的回答。 - 三元组数据:{tripleMetaData} - 用户输入的问题:{userQuery} + 一、任务说明 + 你是一个审计报告的整理助手,你需要对用户询问的问题根据产考数据进行回答。 + 二、用户输入: + {query} + 三、参考文本: + {example_text} + 四、注意事项 + 1. 数据核实:需确认是否使用参考文本中的数据。 + 2. 信息时效:可能需要确认数据的时间。 + 3. 问题范围界定:当用户提问超出审计报告分析范畴时,统一采用以下话术回复: + "您好!当前系统功能聚焦于审计报告相关内容分析,您的问题暂不在支持范围内。如需查询财务数据、票据详情或其他审计相关信息,请提供具体问题,我们将全力协助。" """; private static final String CHINESE_TO_ENGLISH_PROMPT = """ diff --git a/src/main/java/com/supervision/pdfqaserver/config/OllamaChatModelAspect.java b/src/main/java/com/supervision/pdfqaserver/config/OllamaChatModelAspect.java index c6c108d..e8724d7 100644 --- a/src/main/java/com/supervision/pdfqaserver/config/OllamaChatModelAspect.java +++ b/src/main/java/com/supervision/pdfqaserver/config/OllamaChatModelAspect.java @@ -26,20 +26,13 @@ public class OllamaChatModelAspect { @Around("execution(* org.springframework.ai.chat.model.ChatModel.call(..))") public Object aroundMethodExecution(ProceedingJoinPoint joinPoint) throws Throwable { String signature = joinPoint.getSignature().toString(); - if (StrUtil.equals(model,"qwen3:30b-a3b") && StrUtil.equals(signature, callStringMessage)) { - Object[] args = joinPoint.getArgs(); - if (args.length > 0) { - String arg = (String) args[0]; - args[0] = arg + "\n /no_think"; - } - } // 执行原方法 Object result = joinPoint.proceed(); - - if (StrUtil.equals(model,"qwen3:30b-a3b") && StrUtil.equals(signature, callStringMessage)) { - result = ((String) result).replaceAll("(?is)]*>(.*?)", "").trim(); + if (StrUtil.equals(model,"qwen3:30b-a3b") ) { + if(StrUtil.equals(signature, callStringMessage)){ + result = ((String) result).replaceAll("(?is)]*>(.*?)", "").trim(); + } } - return result; } } diff --git a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java index 3f1b268..6ecafd5 100644 --- a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java +++ b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java @@ -39,9 +39,9 @@ public class Neo4jRepository { while (result.hasNext()) { org.neo4j.driver.Record record = result.next(); // 从 Record 中取出三部分 - Node a = record.get("startNode").asNode(); + Node a = record.get("c").asNode(); Relationship r = record.get("r").asRelationship(); - Node b = record.get("endNode").asNode(); + Node b = record.get("t").asNode(); // 转成我们的 DTO NodeData start = mapNode(a); diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java index 02b0d5f..57173eb 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java @@ -1,5 +1,7 @@ package com.supervision.pdfqaserver.service.impl; +import cn.hutool.json.JSONObject; +import cn.hutool.json.JSONUtil; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.dao.Neo4jRepository; import com.supervision.pdfqaserver.domain.ChineseEnglishWords; @@ -11,14 +13,13 @@ import com.supervision.pdfqaserver.service.DomainMetadataService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; -import java.util.HashMap; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -30,14 +31,16 @@ import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER; @Service @RequiredArgsConstructor public class ChatServiceImpl implements ChatService { - private static final String PROMPT_PARAM_DOMAIN_METADATA = "domainMetadata"; - private static final String PROMPT_PARAM_TRIPLE_METADATA = "tripleMetaData"; - private static final String PROMPT_PARAM_USER_QUERY = "userQuery"; + private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList"; + private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList"; + private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList"; + private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text"; + private static final String PROMPT_PARAM_QUERY = "query"; + private static final String CYPHER_QUERIES = "cypherQueries"; private final Neo4jRepository neo4jRepository; private final OllamaChatModel ollamaChatModel; - private final DomainMetadataService domainMetadataService; private final ChineseEnglishWordsService chineseEnglishWordsService; @@ -46,33 +49,54 @@ public class ChatServiceImpl implements ChatService { //拼装领域元数据 Map chineseEnglishWordsMap = chineseEnglishWordsService.list().stream() .collect(Collectors.toMap(ChineseEnglishWords::getChineseWord, ChineseEnglishWords::getEnglishWord)); - List> domainMappings = domainMetadataService.list().stream().map(domainMetadata -> { - Map mapping = new HashMap<>(); - mapping.put("source", domainMetadata.getSourceType()); - mapping.put("sourceType", chineseEnglishWordsMap.get(domainMetadata.getSourceType())); - mapping.put("relation", domainMetadata.getRelation()); - mapping.put("relationType", chineseEnglishWordsMap.get(domainMetadata.getRelation())); - mapping.put("target", domainMetadata.getTargetType()); - mapping.put("targetType", chineseEnglishWordsMap.get(domainMetadata.getTargetType())); - return mapping; - }).toList(); - log.info("domainMappings: {}", domainMappings); - //生成CYPHER + + //分别得到sourceType,relation,targetType的group by后的集合 + List sourceTypeList = domainMetadataService.list().stream().map(DomainMetadata::getSourceType).distinct().toList(); + List relationList = domainMetadataService.list().stream().map(DomainMetadata::getRelation).distinct().toList(); + List targetTypeList = domainMetadataService.list().stream().map(DomainMetadata::getTargetType).distinct().toList(); + + //将三个集合分别结合chineseEnglishWordsMap的key转化为value集合 + List sourceTypeListEnList = sourceTypeList.stream().map(chineseEnglishWordsMap::get).toList(); + List relationListEnList = relationList.stream().map(chineseEnglishWordsMap::get).toList(); + List targetTypeListEnList = targetTypeList.stream().map(chineseEnglishWordsMap::get).toList(); + + //将三个集合分别转换为英文逗号分隔的字符串 + String sourceTypeListEn = String.join(",", sourceTypeListEnList); + String relationListEn = String.join(",", relationListEnList); + String targetTypeListEn = String.join(",", targetTypeListEnList); + log.info("sourceTypeListEn: {}, relationListEn: {}, targetTypeListEn: {}", sourceTypeListEn, relationListEn, targetTypeListEn); + + //LLM生成CYPHER SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER)); - Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_DOMAIN_METADATA, domainMappings, PROMPT_PARAM_USER_QUERY, userQuery)); - ChatResponse textToCypherResponse = ollamaChatModel.call(new Prompt(textToCypherMessage)); - String queryCypher = "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode"; - log.info(textToCypherResponse.getResult().getOutput().getText()); -// String queryCypher = textToCypherResponse.getResult().getOutput().getText(); - List relationObjects = neo4jRepository.execute(queryCypher, null); + Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn, PROMPT_PARAM_QUERY, userQuery)); + String cypherJsonStr = ollamaChatModel.call(textToCypherMessage.getText()); + log.info(cypherJsonStr); + List cypherQueries; + try { + JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr); + cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES) + .toList(String.class); + } catch (Exception e) { + log.error("解析CYPHER JSON字符串失败: {}", e.getMessage()); + return Flux.just("查无结果"); + } + log.info("转换后的Cypher语句:{}", cypherQueries.toString()); + + //执行CYPHER查询并汇总结果 + List relationObjects = new ArrayList<>(); + if (!cypherQueries.isEmpty()) { + for (String cypher : cypherQueries) { + relationObjects.addAll(neo4jRepository.execute(cypher, null)); + } + } if (relationObjects.isEmpty()) { - return Flux.just("没有找到相关数据"); + return Flux.just("查无结果"); } - log.info("relationObjects: {}", relationObjects); + log.info("三元组数据: {}", relationObjects); //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); - Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_TRIPLE_METADATA, relationObjects, PROMPT_PARAM_USER_QUERY, userQuery)); + Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery)); return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()); } }