package com.supervision.pdfqaserver.service.impl; import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.domain.DocumentTruncation; import com.supervision.pdfqaserver.dto.AnswerDetailDTO; import com.supervision.pdfqaserver.dto.neo4j.NodeDTO; import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO; import com.supervision.pdfqaserver.service.*; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import static com.supervision.pdfqaserver.cache.PromptCache.*; @Slf4j @Service @RequiredArgsConstructor public class ChatServiceImpl implements ChatService { private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList"; private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList"; private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList"; private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text"; private static final String PROMPT_PARAM_QUERY = "query"; private static final String CYPHER_QUERIES = "cypherQueries"; private final AiCallService aiCallService; private final DocumentTruncationService documentTruncationService; private final TripleToCypherExecutor tripleToCypherExecutor; private final DataCompareRetriever compareRetriever; @Override public Flux knowledgeQA(String userQuery) { log.info("用户查询: {}", userQuery); List> graphResult = compareRetriever.retrieval(userQuery); if (CollUtil.isEmpty(graphResult)){ return Flux.just("查无结果").concatWith(Flux.just("[END]")); } //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(clearGraphElements(graphResult)), PROMPT_PARAM_QUERY, userQuery)); log.info("生成回答的提示词:{}", generateAnswerMessage); return aiCallService.stream(new Prompt(generateAnswerMessage)) .map(response -> response.getResult().getOutput().getText()) .concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(graphResult)).toString())) .concatWith(Flux.just("[END]")); } private List convertToAnswerDetails(List> graphResult) { if (CollUtil.isEmpty(graphResult)){ return new ArrayList<>(); } List answerDetailDTOS = new ArrayList<>(); for (Map map : graphResult) { Long start = null; Long end = null; for (Map.Entry entry : map.entrySet()) { // 先找到头节点和尾节点id if (entry.getValue() instanceof RelationshipValueDTO value){ start = value.getStart(); end = value.getEnd(); break; } } AnswerDetailDTO answerDetailDTO = new AnswerDetailDTO(); if (null == start) { // 没有关系类型 for (Map.Entry entry : map.entrySet()) { // 处理头节点 if(entry.getValue() instanceof NodeDTO nodeDTO){ Map properties = nodeDTO.getProperties(); if (StrUtil.isEmpty(answerDetailDTO.getSourceType())){ answerDetailDTO.setSourceName((String) properties.get("name")); answerDetailDTO.setSourceType(CollUtil.getFirst(nodeDTO.getLabels())); // 假设第一个标签是源类型 // 设置truncationId属性 answerDetailDTO.setTruncateId((String) properties.get("truncationId")); }else { answerDetailDTO.setTargetName((String) properties.get("name")); answerDetailDTO.setTargetType(CollUtil.getFirst(nodeDTO.getLabels())); // 假设第一个标签是目标类型 } } } answerDetailDTOS.add(answerDetailDTO); }else { // 有关系节点 for (Map.Entry entry : map.entrySet()) { // 处理头节点 if(entry.getValue() instanceof NodeDTO nodeDTO){ if (start.equals(nodeDTO.getId())){ Map properties = nodeDTO.getProperties(); answerDetailDTO.setSourceName((String) properties.get("name")); answerDetailDTO.setSourceType(CollUtil.getFirst(nodeDTO.getLabels())); // 假设第一个标签是源类型 // 设置truncationId属性 answerDetailDTO.setTruncateId((String) properties.get("truncationId")); } if (end.equals(nodeDTO.getId())){ Map properties = nodeDTO.getProperties(); answerDetailDTO.setTargetName((String) properties.get("name")); answerDetailDTO.setTargetType(CollUtil.getFirst(nodeDTO.getLabels())); // 假设第一个标签是目标类型 } } if (entry.getValue() instanceof RelationshipValueDTO value) { // 处理关系 if (start.equals(value.getStart()) || end.equals(value.getEnd())) { answerDetailDTO.setRelation(value.getType()); } } } answerDetailDTOS.add(answerDetailDTO); } } List distinct = new ArrayList<>(); if (CollUtil.isNotEmpty(answerDetailDTOS)){ //去重answerDetailDTOS for (AnswerDetailDTO answerDetailDTO : answerDetailDTOS) { boolean noned = distinct.stream().noneMatch(i -> StrUtil.equals(i.getSourceName(), answerDetailDTO.getSourceName()) && StrUtil.equals(i.getTargetName(), answerDetailDTO.getTargetName()) && StrUtil.equals(i.getRelation(), answerDetailDTO.getRelation()) && StrUtil.equals(i.getSourceType(), answerDetailDTO.getSourceType()) && StrUtil.equals(i.getTargetType(), answerDetailDTO.getTargetType()) && StrUtil.equals(i.getTruncateId(), answerDetailDTO.getTruncateId()) ); if (noned){ distinct.add(answerDetailDTO); } } List truncateIds = distinct.stream().map(AnswerDetailDTO::getTruncateId).distinct().toList(); if (CollUtil.isEmpty(truncateIds)){ return answerDetailDTOS; } List documentTruncations = documentTruncationService.listByIds(truncateIds); Map contentMap = documentTruncations.stream().collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent)); for (AnswerDetailDTO answerDetailDTO : distinct) { answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId())); } } return distinct; } /** * 清理图谱元素中的无效数据 * @param graphElements 图谱元素列表 * @return */ public List> clearGraphElements(List> graphElements) { if (CollUtil.isEmpty(graphElements)) { return graphElements; } List> result = new ArrayList<>(graphElements.size()); for (Map originalMap : graphElements) { Map newMap = new HashMap<>(); for (Map.Entry entry : originalMap.entrySet()) { String key = entry.getKey(); Object value = entry.getValue(); if (value instanceof NodeDTO nodeDTO){ NodeDTO newNodeDTO = BeanUtil.copyProperties(nodeDTO, NodeDTO.class); newNodeDTO.clearGraphElement(); // 清理图谱元素 newMap.put(key, newNodeDTO); } else if (value instanceof RelationshipValueDTO relationshipValueDTO) { RelationshipValueDTO newRelationshipValueDTO = BeanUtil.copyProperties(relationshipValueDTO, RelationshipValueDTO.class); newRelationshipValueDTO.clearGraphElement(); // 清理图谱元素 newMap.put(key, newRelationshipValueDTO); } } result.add(newMap); } return result; } }