package com.supervision.pdfqaserver.service.impl;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.domain.DocumentTruncation;
import com.supervision.pdfqaserver.dto.AnswerDetailDTO;
import com.supervision.pdfqaserver.dto.neo4j.NodeDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO;
import com.supervision.pdfqaserver.service.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static com.supervision.pdfqaserver.cache.PromptCache.*;

/**
 * Knowledge-graph question answering: turns a user question into a Cypher query, runs it, and
 * streams an LLM-generated answer followed by the graph facts that answer was based on.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {

    private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList";
    private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList";
    private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList";
    private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text";
    private static final String PROMPT_PARAM_QUERY = "query";
    private static final String CYPHER_QUERIES = "cypherQueries";

    /** Chunk emitted when no Cypher query could be generated or the query matched nothing. */
    private static final String NO_RESULT_MESSAGE = "查无结果";
    /** Terminal chunk that closes every response stream returned by this service. */
    private static final String END_MARKER = "[END]";

    private final AiCallService aiCallService;
    private final DocumentTruncationService documentTruncationService;
    private final TripleToCypherExecutor tripleToCypherExecutor;

    /**
     * Answers {@code userQuery} from the knowledge graph.
     *
     * <p>The returned stream emits, in order: the LLM answer text chunks, one JSON string of the
     * form {@code {"answerDetails": [...]}} describing the graph facts used, and finally
     * {@code "[END]"}. When no Cypher query is generated, or the query returns no rows, the stream
     * is just {@code "查无结果"} followed by {@code "[END]"}.
     *
     * @param userQuery the user's natural-language question
     * @return a stream of response chunks terminated by {@code "[END]"}
     */
    @Override
    public Flux<String> knowledgeQA(String userQuery) {
        log.info("用户查询: {}", userQuery);
        // Translate the question into a Cypher query.
        String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery, null);
        log.info("生成CYPHER语句的消息:{}", cypher);
        if (StrUtil.isEmpty(cypher)) {
            return noResult();
        }
        // Run the query against the graph.
        List<Map<String, Object>> graphResult = tripleToCypherExecutor.executeCypher(cypher);
        if (CollUtil.isEmpty(graphResult)) {
            return noResult();
        }
        // Build the answer-generation prompt from the raw graph rows and the question.
        SystemPromptTemplate generateAnswerTemplate =
                new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
        Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(
                PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult),
                PROMPT_PARAM_QUERY, userQuery));
        log.info("生成回答的提示词:{}", generateAnswerMessage);
        // Answer details (including the document-truncation lookup) are computed eagerly, while
        // this method assembles the Flux, not when the stream is subscribed.
        String answerDetailsJson = new JSONObject()
                .set("answerDetails", convertToAnswerDetails(graphResult))
                .toString();
        return aiCallService.stream(new Prompt(generateAnswerMessage))
                .map(response -> response.getResult().getOutput().getText())
                .concatWith(Flux.just(answerDetailsJson))
                .concatWith(Flux.just(END_MARKER));
    }

    /** Builds the two-chunk "no result" response. */
    private static Flux<String> noResult() {
        return Flux.just(NO_RESULT_MESSAGE, END_MARKER);
    }

    /**
     * Converts Cypher result rows into deduplicated answer details, with each detail's source
     * document text filled in where its truncation id resolves to a stored truncation.
     *
     * @param graphResult rows returned by the Cypher query; each value may be a {@link NodeDTO},
     *                    a {@link RelationshipValueDTO}, or something else (ignored)
     * @return distinct answer details in first-seen order; empty when {@code graphResult} is empty
     */
    private List<AnswerDetailDTO> convertToAnswerDetails(List<Map<String, Object>> graphResult) {
        if (CollUtil.isEmpty(graphResult)) {
            return new ArrayList<>();
        }
        List<AnswerDetailDTO> answerDetails = new ArrayList<>(graphResult.size());
        for (Map<String, Object> row : graphResult) {
            answerDetails.add(toAnswerDetail(row));
        }
        List<AnswerDetailDTO> distinct = distinctAnswerDetails(answerDetails);
        fillTruncateContent(distinct);
        return distinct;
    }

    /**
     * Maps one result row to an answer detail.
     *
     * <p>If the row holds a relationship with a start id, the nodes whose ids match its start/end
     * become the source/target and the relation type is taken from that same relationship.
     * Otherwise the first node in the row becomes the source and each later node overwrites the
     * target.
     */
    private static AnswerDetailDTO toAnswerDetail(Map<String, Object> row) {
        AnswerDetailDTO detail = new AnswerDetailDTO();
        RelationshipValueDTO relationship = findFirstRelationship(row);
        Long start = relationship == null ? null : relationship.getStart();
        if (start == null) {
            // No relationship in the row: treat the nodes as a source followed by target(s).
            boolean sourceFilled = false;
            for (Object value : row.values()) {
                if (value instanceof NodeDTO node) {
                    if (!sourceFilled) {
                        fillSource(detail, node);
                        sourceFilled = true;
                    } else {
                        fillTarget(detail, node);
                    }
                }
            }
            return detail;
        }
        Long end = relationship.getEnd();
        for (Object value : row.values()) {
            if (value instanceof NodeDTO node) {
                if (start.equals(node.getId())) {
                    fillSource(detail, node);
                }
                // Objects.equals: the relationship's end id may be null.
                if (Objects.equals(end, node.getId())) {
                    fillTarget(detail, node);
                }
            }
        }
        // Use the relationship that defined source/target so the triple stays consistent.
        detail.setRelation(relationship.getType());
        return detail;
    }

    /** Returns the first {@link RelationshipValueDTO} in the row's iteration order, or null. */
    private static RelationshipValueDTO findFirstRelationship(Map<String, Object> row) {
        for (Object value : row.values()) {
            if (value instanceof RelationshipValueDTO relationship) {
                return relationship;
            }
        }
        return null;
    }

    /** Copies name, first label (assumed to be the type), and truncation id into the source fields. */
    private static void fillSource(AnswerDetailDTO detail, NodeDTO node) {
        detail.setSourceName(stringProperty(node, "name"));
        detail.setSourceType(CollUtil.getFirst(node.getLabels()));
        detail.setTruncateId(stringProperty(node, "truncationId"));
    }

    /** Copies name and first label (assumed to be the type) into the target fields. */
    private static void fillTarget(AnswerDetailDTO detail, NodeDTO node) {
        detail.setTargetName(stringProperty(node, "name"));
        detail.setTargetType(CollUtil.getFirst(node.getLabels()));
    }

    /** Reads a node property as a string; null when the properties map or the value is absent. */
    private static String stringProperty(NodeDTO node, String key) {
        Map<String, Object> properties = node.getProperties();
        return properties == null ? null : Objects.toString(properties.get(key), null);
    }

    /**
     * Removes answer details that repeat the same source, target, relation, types and truncation
     * id, keeping the first occurrence and the original order.
     */
    private static List<AnswerDetailDTO> distinctAnswerDetails(List<AnswerDetailDTO> details) {
        Set<AnswerDetailKey> seen = new HashSet<>();
        List<AnswerDetailDTO> distinct = new ArrayList<>();
        for (AnswerDetailDTO detail : details) {
            if (seen.add(AnswerDetailKey.of(detail))) {
                distinct.add(detail);
            }
        }
        return distinct;
    }

    /**
     * Looks up the document truncations referenced by {@code details} and sets each detail's
     * truncate content. Details with no truncation id, or whose id is not found, get null content.
     */
    private void fillTruncateContent(List<AnswerDetailDTO> details) {
        List<String> truncateIds = details.stream()
                .map(AnswerDetailDTO::getTruncateId)
                .filter(StrUtil::isNotEmpty)
                .distinct()
                .toList();
        if (truncateIds.isEmpty()) {
            return;
        }
        // HashMap accepts null content values, unlike Collectors.toMap, which throws NPE on them.
        Map<String, String> contentById = new HashMap<>();
        for (DocumentTruncation truncation : documentTruncationService.listByIds(truncateIds)) {
            contentById.put(truncation.getId(), truncation.getContent());
        }
        for (AnswerDetailDTO detail : details) {
            detail.setTruncateContent(contentById.get(detail.getTruncateId()));
        }
    }

    /** Identity of an answer detail for deduplication purposes. */
    private record AnswerDetailKey(String sourceName, String sourceType, String relation,
                                   String targetName, String targetType, String truncateId) {

        static AnswerDetailKey of(AnswerDetailDTO detail) {
            return new AnswerDetailKey(detail.getSourceName(), detail.getSourceType(),
                    detail.getRelation(), detail.getTargetName(), detail.getTargetType(),
                    detail.getTruncateId());
        }
    }
}