package com.supervision.pdfqaserver.service.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.dao.Neo4jRepository; import com.supervision.pdfqaserver.domain.DocumentTruncation; import com.supervision.pdfqaserver.domain.Intention; import com.supervision.pdfqaserver.dto.AnswerDetailDTO; import com.supervision.pdfqaserver.dto.DomainMetadataDTO; import com.supervision.pdfqaserver.dto.neo4j.RelationObject; import com.supervision.pdfqaserver.service.*; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import static com.supervision.pdfqaserver.cache.PromptCache.*; @Slf4j @Service @RequiredArgsConstructor public class ChatServiceImpl implements ChatService { private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList"; private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList"; private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList"; private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text"; private static final String PROMPT_PARAM_QUERY = "query"; private static final String CYPHER_QUERIES = "cypherQueries"; private final Neo4jRepository neo4jRepository; private final OllamaChatModel ollamaChatModel; private final DomainMetadataService domainMetadataService; private final AiCallService aiCallService; private final DocumentTruncationService documentTruncationService; private final IntentionService intentionService; @Override public Flux knowledgeQA(String userQuery) { List intentions = intentionService.listAllPassed(); List relations = classifyIntents(userQuery, intentions); if (CollUtil.isEmpty(relations)){ log.info("没有匹配到意图,返回查无结果"); return Flux.just("查无结果").concatWith(Flux.just("[END]")); } List domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList()); if (CollUtil.isEmpty(domainMetadataDTOS)){ log.info("没有匹配到领域元数据,返回查无结果"); return Flux.just("查无结果").concatWith(Flux.just("[END]")); } //将三个集合分别转换为英文逗号分隔的字符串 String sourceTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getSourceType).distinct().collect(Collectors.joining(",")); String relationListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getRelation).distinct().collect(Collectors.joining(",")); String targetTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getTargetType).distinct().collect(Collectors.joining(",")); //LLM生成CYPHER SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER)); Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn, PROMPT_PARAM_QUERY, userQuery)); log.info("生成CYPHER语句的消息:{}", textToCypherMessage); String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(0.3).build())).getResult().getOutput().getText(); log.info(cypherJsonStr); log.info(cypherJsonStr.replaceAll("(?is)]*>(.*?)", "").trim()); cypherJsonStr = cypherJsonStr.replaceAll("(?is)]*>(.*?)", "").trim(); List cypherQueries; try { JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr); cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES) .toList(String.class); } catch (Exception e) { log.error("解析CYPHER JSON字符串失败: {}", e.getMessage()); return Flux.just("查无结果").concatWith(Flux.just("[END]")); } log.info("转换后的Cypher语句:{}", cypherQueries.toString()); //执行CYPHER查询并汇总结果 List relationObjects = new ArrayList<>(); if (!cypherQueries.isEmpty()) { for (String cypher : cypherQueries) { relationObjects.addAll(neo4jRepository.execute(cypher, null)); } } if (relationObjects.isEmpty()) { log.info("cypher没有查询到结果,返回查无结果"); return Flux.just("查无结果").concatWith(Flux.just("[END]")); } log.info("三元组数据: {}", relationObjects); //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery)); log.info("生成回答的提示词:{}", generateAnswerMessage); return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()) .concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(relationObjects)).toString())) .concatWith(Flux.just("[END]")); } /** * * 分类查询意图 * @param query 问题 * @param intentions 意图列表 * @return */ private List classifyIntents(String query, List intentions) { if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) { return new ArrayList<>(); } String prompt = promptMap.get(CLASSIFY_QUERY_INTENT); List result = new ArrayList<>(); log.info("开始分类意图,query: {}, intentions size: {}", query, intentions.size()); List> intentionSplit = CollUtil.split(intentions, 150); for (List intentionList : intentionSplit) { log.info("分类意图,query: {}, intentions size: {}", query, intentionList.size()); String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining()); Map params = Map.of("query", query, "intents", intents); String format = StrUtil.format(prompt, params); String call = aiCallService.call(format); if (StrUtil.isEmpty(call)) { return new ArrayList<>(); } List digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList(); if (CollUtil.isEmpty(digests)) { continue; } List collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList()); if (CollUtil.isNotEmpty(collect)) { result.addAll(collect); } } return result; } private List convertToAnswerDetails(List relationObjects) { if (CollUtil.isEmpty(relationObjects)) { return new ArrayList<>(); } List answerDetailDTOList = relationObjects.stream().map(AnswerDetailDTO::new).collect(Collectors.toList()); if (CollUtil.isNotEmpty(answerDetailDTOList)){ List truncateIds = answerDetailDTOList.stream().map(AnswerDetailDTO::getTruncateId).distinct().toList(); if (CollUtil.isEmpty(truncateIds)){ return answerDetailDTOList; } List documentTruncations = documentTruncationService.listByIds(truncateIds); Map contentMap = documentTruncations.stream().collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent)); for (AnswerDetailDTO answerDetailDTO : answerDetailDTOList) { answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId())); } } return answerDetailDTOList; } }