package com.supervision.pdfqaserver.service.impl; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.dao.Neo4jRepository; import com.supervision.pdfqaserver.domain.ChineseEnglishWords; import com.supervision.pdfqaserver.domain.DomainMetadata; import com.supervision.pdfqaserver.dto.neo4j.RelationObject; import com.supervision.pdfqaserver.service.ChatService; import com.supervision.pdfqaserver.service.ChineseEnglishWordsService; import com.supervision.pdfqaserver.service.DomainMetadataService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER; import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER; @Slf4j @Service @RequiredArgsConstructor public class ChatServiceImpl implements ChatService { private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList"; private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList"; private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList"; private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text"; private static final String PROMPT_PARAM_QUERY = "query"; private static final String CYPHER_QUERIES = "cypherQueries"; private final Neo4jRepository neo4jRepository; private final OllamaChatModel ollamaChatModel; private final DomainMetadataService domainMetadataService; private final ChineseEnglishWordsService chineseEnglishWordsService; @Override public Flux knowledgeQA(String userQuery) { //拼装领域元数据 Map chineseEnglishWordsMap = chineseEnglishWordsService.list().stream() .collect(Collectors.toMap(ChineseEnglishWords::getChineseWord, ChineseEnglishWords::getEnglishWord)); //分别得到sourceType,relation,targetType的group by后的集合 List sourceTypeList = domainMetadataService.list().stream().map(DomainMetadata::getSourceType).distinct().toList(); List relationList = domainMetadataService.list().stream().map(DomainMetadata::getRelation).distinct().toList(); List targetTypeList = domainMetadataService.list().stream().map(DomainMetadata::getTargetType).distinct().toList(); //将三个集合分别结合chineseEnglishWordsMap的key转化为value集合 List sourceTypeListEnList = sourceTypeList.stream().map(chineseEnglishWordsMap::get).toList(); List relationListEnList = relationList.stream().map(chineseEnglishWordsMap::get).toList(); List targetTypeListEnList = targetTypeList.stream().map(chineseEnglishWordsMap::get).toList(); //将三个集合分别转换为英文逗号分隔的字符串 String sourceTypeListEn = String.join(",", sourceTypeListEnList); String relationListEn = String.join(",", relationListEnList); String targetTypeListEn = String.join(",", targetTypeListEnList); log.info("sourceTypeListEn: {}, relationListEn: {}, targetTypeListEn: {}", sourceTypeListEn, relationListEn, targetTypeListEn); //LLM生成CYPHER SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER)); Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn, PROMPT_PARAM_QUERY, userQuery)); String cypherJsonStr = ollamaChatModel.call(textToCypherMessage.getText()); log.info(cypherJsonStr); List cypherQueries; try { JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr); cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES) .toList(String.class); } catch (Exception e) { log.error("解析CYPHER JSON字符串失败: {}", e.getMessage()); return Flux.just("查无结果"); } log.info("转换后的Cypher语句:{}", cypherQueries.toString()); //执行CYPHER查询并汇总结果 List relationObjects = new ArrayList<>(); if (!cypherQueries.isEmpty()) { for (String cypher : cypherQueries) { relationObjects.addAll(neo4jRepository.execute(cypher, null)); } } if (relationObjects.isEmpty()) { return Flux.just("查无结果"); } log.info("三元组数据: {}", relationObjects); //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery)); return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()); } }