package com.supervision.pdfqaserver.service.impl; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.dao.Neo4jRepository; import com.supervision.pdfqaserver.domain.ChineseEnglishWords; import com.supervision.pdfqaserver.domain.DomainMetadata; import com.supervision.pdfqaserver.dto.neo4j.RelationObject; import com.supervision.pdfqaserver.service.ChatService; import com.supervision.pdfqaserver.service.ChineseEnglishWordsService; import com.supervision.pdfqaserver.service.DomainMetadataService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER; import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER; @Slf4j @Service @RequiredArgsConstructor public class ChatServiceImpl implements ChatService { private static final String PROMPT_PARAM_DOMAIN_METADATA = "domainMetadata"; private static final String PROMPT_PARAM_TRIPLE_METADATA = "tripleMetaData"; private static final String PROMPT_PARAM_USER_QUERY = "userQuery"; private final Neo4jRepository neo4jRepository; private final OllamaChatModel ollamaChatModel; private final DomainMetadataService domainMetadataService; private final ChineseEnglishWordsService chineseEnglishWordsService; @Override public Flux knowledgeQA(String userQuery) { //拼装领域元数据 Map chineseEnglishWordsMap = chineseEnglishWordsService.list().stream() .collect(Collectors.toMap(ChineseEnglishWords::getChineseWord, ChineseEnglishWords::getEnglishWord)); List> domainMappings = domainMetadataService.list().stream().map(domainMetadata -> { Map mapping = new HashMap<>(); mapping.put("source", domainMetadata.getSourceType()); mapping.put("sourceType", chineseEnglishWordsMap.get(domainMetadata.getSourceType())); mapping.put("relation", domainMetadata.getRelation()); mapping.put("relationType", chineseEnglishWordsMap.get(domainMetadata.getRelation())); mapping.put("target", domainMetadata.getTargetType()); mapping.put("targetType", chineseEnglishWordsMap.get(domainMetadata.getTargetType())); return mapping; }).toList(); log.info("domainMappings: {}", domainMappings); //生成CYPHER SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER)); Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_DOMAIN_METADATA, domainMappings, PROMPT_PARAM_USER_QUERY, userQuery)); ChatResponse textToCypherResponse = ollamaChatModel.call(new Prompt(textToCypherMessage)); String queryCypher = "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode"; log.info(textToCypherResponse.getResult().getOutput().getText()); // String queryCypher = textToCypherResponse.getResult().getOutput().getText(); List relationObjects = neo4jRepository.execute(queryCypher, null); if (relationObjects.isEmpty()) { return Flux.just("没有找到相关数据"); } log.info("relationObjects: {}", relationObjects); //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_TRIPLE_METADATA, relationObjects, PROMPT_PARAM_USER_QUERY, userQuery)); return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()); } }