package com.supervision.pdfqaserver.service.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.dao.Neo4jRepository; import com.supervision.pdfqaserver.domain.DocumentTruncation; import com.supervision.pdfqaserver.domain.Intention; import com.supervision.pdfqaserver.dto.AnswerDetailDTO; import com.supervision.pdfqaserver.dto.DomainMetadataDTO; import com.supervision.pdfqaserver.dto.neo4j.RelationObject; import com.supervision.pdfqaserver.service.*; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import static com.supervision.pdfqaserver.cache.PromptCache.*; @Slf4j @Service @RequiredArgsConstructor public class ChatServiceImpl implements ChatService { private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList"; private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList"; private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList"; private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text"; private static final String PROMPT_PARAM_QUERY = "query"; private static final String CYPHER_QUERIES = "cypherQueries"; private final OllamaChatModel ollamaChatModel; private final AiCallService aiCallService; private final DocumentTruncationService documentTruncationService; private final TripleToCypherExecutor tripleToCypherExecutor; @Override public Flux knowledgeQA(String userQuery) { log.info("用户查询: {}", userQuery); // 生成cypher语句 String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery,null); log.info("生成CYPHER语句的消息:{}", cypher); if (StrUtil.isEmpty(cypher)){ return Flux.just("查无结果").concatWith(Flux.just("[END]")); } // 执行cypher语句 List> graphResult = tripleToCypherExecutor.executeCypher(cypher); if (CollUtil.isEmpty(graphResult)){ return Flux.just("查无结果").concatWith(Flux.just("[END]")); } //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery)); log.info("生成回答的提示词:{}", generateAnswerMessage); return aiCallService.stream(new Prompt(generateAnswerMessage)) .map(response -> response.getResult().getOutput().getText()) .concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(null)).toString())) .concatWith(Flux.just("[END]")); } private List convertToAnswerDetails(List relationObjects) { if (CollUtil.isEmpty(relationObjects)) { return new ArrayList<>(); } List answerDetailDTOList = relationObjects.stream().map(AnswerDetailDTO::new).collect(Collectors.toList()); if (CollUtil.isNotEmpty(answerDetailDTOList)){ List truncateIds = answerDetailDTOList.stream().map(AnswerDetailDTO::getTruncateId).distinct().toList(); if (CollUtil.isEmpty(truncateIds)){ return answerDetailDTOList; } List documentTruncations = documentTruncationService.listByIds(truncateIds); Map contentMap = documentTruncations.stream().collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent)); for (AnswerDetailDTO answerDetailDTO : answerDetailDTOList) { answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId())); } } return answerDetailDTOList; } }