From ee4b4adb37ef7e9feb448aaa809f6773647b86cc Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Thu, 29 May 2025 13:54:06 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=8A=9F=E8=83=BDbug?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20=E4=BC=98=E5=8C=96=E9=97=AE=E7=AD=94?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pdfqaserver/cache/PromptCache.java | 55 +++++++++ .../pdfqaserver/dto/AnswerDetailDTO.java | 94 ++++++++++++++ .../supervision/pdfqaserver/dto/EREDTO.java | 1 + .../pdfqaserver/service/IntentionService.java | 10 +- .../service/KnowledgeGraphService.java | 7 ++ .../service/impl/ChatServiceImpl.java | 116 +++++++++++++++--- .../service/impl/IntentionServiceImpl.java | 5 + .../impl/KnowledgeGraphServiceImpl.java | 21 +++- .../impl/TripleToCypherExecutorImpl.java | 1 - ...TruncationEntityExtractionServiceImpl.java | 1 + ...uncationRelationExtractionServiceImpl.java | 1 + 11 files changed, 289 insertions(+), 23 deletions(-) create mode 100644 src/main/java/com/supervision/pdfqaserver/dto/AnswerDetailDTO.java diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java index a49f0ac..588418e 100644 --- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java +++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java @@ -8,17 +8,44 @@ import java.util.Map; */ public class PromptCache { + /** + * 抽取文文本中的三元体 + */ public static final String DOERE_TEXT = "DOERE_TEXT"; + + /** + * 抽取表格中的三元体 + */ public static final String DOERE_TABLE = "DOERE_TABLE"; + + /** + * 将文本转换为Cypher查询语句 + */ public static final String TEXT_TO_CYPHER = "TEXT_TO_CYPHER"; + + /** + * 生成答案 + */ public static final String GENERATE_ANSWER = "GENERATE_ANSWER"; + /** + * 中文转英文 + */ public static final String CHINESE_TO_ENGLISH = "CHINESE_TO_ENGLISH"; + /** + * 将三元组转换为Cypher语句 + */ public static final String ERE_TO_INSERT_CYPHER = "ERE_TO_INSERT_CYPHER"; + /** + * 区分表格是否是描述性标题 + */ public static final String CLASSIFY_TABLE = "CLASSIFY_TABLE"; + /** + * 提取表格标题 + */ public static final String EXTRACT_TABLE_TITLE = "EXTRACT_TABLE_TITLE"; /** @@ -53,6 +80,12 @@ public class PromptCache { */ public static final String EXTRACT_ERE_BASE_INTENT = "EXTRACT_ERE_BASE_INTENT"; + + /** + * 区分问题中的意图分类 + */ + public static final String CLASSIFY_QUERY_INTENT = "CLASSIFY_QUERY_INTENT"; + public static final Map promptMap = new HashMap<>(); static { @@ -73,6 +106,7 @@ public class PromptCache { promptMap.put(CLASSIFY_INTENT_TRAIN, CLASSIFY_INTENT_TRAIN_PROMPT); promptMap.put(EXTRACT_INTENT_METADATA, EXTRACT_INTENT_METADATA_PROMPT); promptMap.put(EXTRACT_ERE_BASE_INTENT, EXTRACT_ERE_BASE_INTENT_PROMPT); + promptMap.put(CLASSIFY_QUERY_INTENT, CLASSIFY_QUERY_INTENT_PROMPT); } @@ -759,4 +793,25 @@ public class PromptCache { - 确保提取的值与原文一致,不进行推断或改写。 - 输出纯JSON格式,不要使用```json ```等任何Markdown标记包装./no_think """; + + + private static final String CLASSIFY_QUERY_INTENT_PROMPT = """ + 请将用户问题分类到以下意图列表中,并严格以JSON数组格式返回结果(即使只有一个意图)。 + + # 意图列表: + {intents} + + # 规则: + 1. 仅返回与问题相关的意图(若无匹配则返回空数组) + 2. 使用意图列表中的意图名称 + 3. 禁止解释原因 + 4. 输出纯JSON格式,不要使用```json ```等任何Markdown标记包装 + + # 示例: + 用户问题:"我昨天买的鞋子怎么还没发货?" + 输出:["订单查询"] + + # 当前用户问题: + {query} + """; } diff --git a/src/main/java/com/supervision/pdfqaserver/dto/AnswerDetailDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/AnswerDetailDTO.java new file mode 100644 index 0000000..ed82325 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/AnswerDetailDTO.java @@ -0,0 +1,94 @@ +package com.supervision.pdfqaserver.dto; + +import cn.hutool.core.collection.CollUtil; +import com.supervision.pdfqaserver.dto.neo4j.NodeData; +import com.supervision.pdfqaserver.dto.neo4j.RelationObject; +import com.supervision.pdfqaserver.dto.neo4j.RelationshipData; +import lombok.Data; + +@Data +public class AnswerDetailDTO { + + /** + * 文档片段id + */ + private String truncateId; + + /** + * 头节点类型 + */ + private String sourceType; + + /** + * 头节点名称 + */ + private String sourceName; + + /** + * 目标节点类型 + */ + private String targetType; + + + /** + * 目标节点名称 + */ + private String targetName; + + /** + * 关系 + */ + private String relation; + + /** + * 片段内容 + */ + private String truncateContent; + + /** + * PDF ID + */ + private String pdfId; + + /** + * PDF 名称 + */ + private String pdfName; + + public AnswerDetailDTO() { + } + + public AnswerDetailDTO(RelationObject relationObject) { + NodeData endNode = relationObject.endNode(); + NodeData startNode = relationObject.startNode(); + RelationshipData relationship = relationObject.relationship(); + if (null == startNode || null == endNode || null == relationship){ + return; + } + if (CollUtil.isNotEmpty(startNode.properties())){ + Object truncationId = startNode.properties().get("truncationId"); + if (null != truncationId){ + this.truncateId = truncationId.toString(); + } + } + if (CollUtil.isNotEmpty(endNode.labels())){ + this.sourceType = String.join(",", endNode.labels()); + } + if (CollUtil.isNotEmpty(startNode.properties())){ + if (null != startNode.properties().get("name")){ + this.sourceName = startNode.properties().get("name").toString(); + } + } + if (CollUtil.isNotEmpty(endNode.labels())){ + this.targetType = String.join(",", startNode.labels()); + } + if (CollUtil.isNotEmpty(endNode.properties())){ + if (null != startNode.properties().get("name")){ + this.targetName = endNode.properties().get("name").toString(); + } + } + if (CollUtil.isNotEmpty(relationship.properties())) { + this.relation = relationship.type(); + } + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java index 1f37e60..bf97460 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java @@ -146,6 +146,7 @@ public class EREDTO { entityExtractionDTO.setAttributes(truncationErAttributeDTOS); entities.add(entityExtractionDTO); } + eredto.setEntities(entities); return eredto; diff --git a/src/main/java/com/supervision/pdfqaserver/service/IntentionService.java b/src/main/java/com/supervision/pdfqaserver/service/IntentionService.java index 2c7862f..9cf579d 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/IntentionService.java +++ b/src/main/java/com/supervision/pdfqaserver/service/IntentionService.java @@ -1,9 +1,7 @@ package com.supervision.pdfqaserver.service; -import cn.hutool.core.util.StrUtil; import com.supervision.pdfqaserver.domain.Intention; import com.baomidou.mybatisplus.extension.service.IService; - import java.util.List; /** @@ -33,4 +31,12 @@ public interface IntentionService extends IService { Intention queryByDigestAndDomainCategoryId(String digest, String domainCategoryId); List queryByDomainCategoryId(String domainCategoryId); + + + /** + * 查询所有通过的意图 + * @return 意图列表 + */ + List listAllPassed(); + } diff --git a/src/main/java/com/supervision/pdfqaserver/service/KnowledgeGraphService.java b/src/main/java/com/supervision/pdfqaserver/service/KnowledgeGraphService.java index 25df5bf..0f95dd7 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/KnowledgeGraphService.java +++ b/src/main/java/com/supervision/pdfqaserver/service/KnowledgeGraphService.java @@ -40,6 +40,13 @@ public interface KnowledgeGraphService { void generateGraph(List eredtoList); + /** + * 生知识图谱 + * @param eredtoList + */ + + void generateGraphSimple(List eredtoList); + List truncateERE(List truncateDTOS); diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java index d481b5e..9540f6e 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java @@ -1,14 +1,17 @@ package com.supervision.pdfqaserver.service.impl; +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.dao.Neo4jRepository; -import com.supervision.pdfqaserver.domain.DomainMetadata; +import com.supervision.pdfqaserver.domain.DocumentTruncation; +import com.supervision.pdfqaserver.domain.Intention; +import com.supervision.pdfqaserver.dto.AnswerDetailDTO; +import com.supervision.pdfqaserver.dto.DomainMetadataDTO; import com.supervision.pdfqaserver.dto.neo4j.RelationObject; -import com.supervision.pdfqaserver.service.ChatService; -import com.supervision.pdfqaserver.service.ChineseEnglishWordsService; -import com.supervision.pdfqaserver.service.DomainMetadataService; +import com.supervision.pdfqaserver.service.*; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; @@ -18,13 +21,12 @@ import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; - import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; -import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER; -import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER; +import static com.supervision.pdfqaserver.cache.PromptCache.*; @Slf4j @Service @@ -40,26 +42,44 @@ public class ChatServiceImpl implements ChatService { private final Neo4jRepository neo4jRepository; private final OllamaChatModel ollamaChatModel; + private final DomainMetadataService domainMetadataService; - private final ChineseEnglishWordsService chineseEnglishWordsService; + + private final AiCallService aiCallService; + + private final DocumentTruncationService documentTruncationService; + + + private final IntentionService intentionService; @Override public Flux knowledgeQA(String userQuery) { - //分别得到sourceType,relation,targetType的group by后的集合 - List sourceTypeList = domainMetadataService.list().stream().map(DomainMetadata::getSourceType).distinct().toList(); - List relationList = domainMetadataService.list().stream().map(DomainMetadata::getRelation).distinct().toList(); - List targetTypeList = domainMetadataService.list().stream().map(DomainMetadata::getTargetType).distinct().toList(); + List intentions = intentionService.listAllPassed(); + List relations = classifyIntents(userQuery, intentions); + if (CollUtil.isEmpty(relations)){ + log.info("没有匹配到意图,返回查无结果"); + return Flux.just("查无结果").concatWith(Flux.just("[END]")); + } + + List domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList()); + if (CollUtil.isEmpty(domainMetadataDTOS)){ + log.info("没有匹配到领域元数据,返回查无结果"); + return Flux.just("查无结果").concatWith(Flux.just("[END]")); + } //将三个集合分别转换为英文逗号分隔的字符串 - String sourceTypeListEn = String.join(",", sourceTypeList); - String relationListEn = String.join(",", relationList); - String targetTypeListEn = String.join(",", targetTypeList); + String sourceTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getSourceType).distinct().collect(Collectors.joining(",")); + String relationListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getRelation).distinct().collect(Collectors.joining(",")); + String targetTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getTargetType).distinct().collect(Collectors.joining(",")); //LLM生成CYPHER SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER)); - Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn, PROMPT_PARAM_QUERY, userQuery)); + Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, + PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, + targetTypeListEn, PROMPT_PARAM_QUERY, userQuery)); log.info("生成CYPHER语句的消息:{}", textToCypherMessage); - String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(3.0).build())).getResult().getOutput().getText(); + String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(0.3).build())).getResult().getOutput().getText(); + log.info(cypherJsonStr); log.info(cypherJsonStr.replaceAll("(?is)]*>(.*?)", "").trim()); cypherJsonStr = cypherJsonStr.replaceAll("(?is)]*>(.*?)", "").trim(); @@ -82,6 +102,7 @@ public class ChatServiceImpl implements ChatService { } } if (relationObjects.isEmpty()) { + log.info("cypher没有查询到结果,返回查无结果"); return Flux.just("查无结果").concatWith(Flux.just("[END]")); } log.info("三元组数据: {}", relationObjects); @@ -90,6 +111,65 @@ public class ChatServiceImpl implements ChatService { SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery)); log.info("生成回答的提示词:{}", generateAnswerMessage); - return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()).concatWith(Flux.just("[END]")); + return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()) + .concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(relationObjects)).toString())) + .concatWith(Flux.just("[END]")); + } + + /** + * + * 分类查询意图 + * @param query 问题 + * @param intentions 意图列表 + * @return + */ + private List classifyIntents(String query, List intentions) { + if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) { + return new ArrayList<>(); + } + String prompt = promptMap.get(CLASSIFY_QUERY_INTENT); + List result = new ArrayList<>(); + log.info("开始分类意图,query: {}, intentions size: {}", query, intentions.size()); + List> intentionSplit = CollUtil.split(intentions, 150); + for (List intentionList : intentionSplit) { + log.info("分类意图,query: {}, intentions size: {}", query, intentionList.size()); + String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining()); + Map params = Map.of("query", query, + "intents", intents); + String format = StrUtil.format(prompt, params); + String call = aiCallService.call(format); + + if (StrUtil.isEmpty(call)) { + return new ArrayList<>(); + } + List digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList(); + if (CollUtil.isEmpty(digests)) { + continue; + } + List collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList()); + if (CollUtil.isNotEmpty(collect)) { + result.addAll(collect); + } + } + return result; + } + + private List convertToAnswerDetails(List relationObjects) { + if (CollUtil.isEmpty(relationObjects)) { + return new ArrayList<>(); + } + List answerDetailDTOList = relationObjects.stream().map(AnswerDetailDTO::new).collect(Collectors.toList()); + if (CollUtil.isNotEmpty(answerDetailDTOList)){ + List truncateIds = answerDetailDTOList.stream().map(AnswerDetailDTO::getTruncateId).distinct().toList(); + if (CollUtil.isEmpty(truncateIds)){ + return answerDetailDTOList; + } + List documentTruncations = documentTruncationService.listByIds(truncateIds); + Map contentMap = documentTruncations.stream().collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent)); + for (AnswerDetailDTO answerDetailDTO : answerDetailDTOList) { + answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId())); + } + } + return answerDetailDTOList; } } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/IntentionServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/IntentionServiceImpl.java index 9298f33..c65436e 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/IntentionServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/IntentionServiceImpl.java @@ -65,6 +65,11 @@ public class IntentionServiceImpl extends ServiceImpl queryByDomainCategoryId(String domainCategoryId) { return super.lambdaQuery().eq(Intention::getDomainCategoryId, domainCategoryId).list(); } + + @Override + public List listAllPassed() { + return super.lambdaQuery().eq(Intention::getGenerationType, "0").list(); + } } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java index fbfe65b..7a664b1 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java @@ -321,7 +321,7 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService { log.info("开始生成知识图谱..."); timer.start("generateGraph"); - generateGraph(eredtos); + generateGraphSimple(eredtos); log.info("生成知识图谱完成,耗时:{}秒", timer.intervalSecond("generateGraph")); } @@ -347,7 +347,6 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService { } for (RelationExtractionDTO relation : relations) { DomainMetadata domainMetadata = relation.toDomainMetadata(); - domainMetadata.setDomainType("1"); domainMetadata.setGenerationType(DomainMetaGenerationEnum.SYSTEM_AUTO_GENERATION.getCode()); domainMetadataService.saveIfNotExists(domainMetadata); } @@ -400,6 +399,24 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService { } } + @Override + public void generateGraphSimple(List eredtoList) { + log.info("开始合并实体关系抽取结果..."); + List mergedList = tripleConversionPipeline.mergeEreResults(eredtoList); + log.info("合并实体关系抽取结果完成,合并后个数:{}", mergedList.size()); + + for (EREDTO eredto : mergedList) { + if (CollUtil.isEmpty(eredto.getEntities()) && CollUtil.isEmpty(eredto.getRelations())){ + continue; + } + try { + tripleToCypherExecutor.saveERE(eredto); + } catch (Exception e) { + log.info("生成cypher语句失败,切分文档id:{}", JSONUtil.toJsonStr(eredto), e); + } + } + } + private static List getChineseEnglishWords(EREDTO eredto) { List allWords; allWords = eredto.getEntities().stream().flatMap(entity -> { diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java index b156ecd..72da5c4 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java @@ -87,7 +87,6 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { Map attributes = relation.getAttributes().stream().collect(Collectors.toMap( TruncationERAttributeDTO::getAttributeEn, TruncationERAttributeDTO::getValue )); - attributes.put("sourceType", relation.getSourceType()); attributes.put("truncationId", relation.getTruncationId()); for (Long sourceNodeId : sourceNodeIds) { for (Long targetNodeId : targetNodeIds) { diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TruncationEntityExtractionServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TruncationEntityExtractionServiceImpl.java index 9b51c60..10104f8 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/TruncationEntityExtractionServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TruncationEntityExtractionServiceImpl.java @@ -46,6 +46,7 @@ public class TruncationEntityExtractionServiceImpl extends ServiceImpl