diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java index eab7e3c..1c01375 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java @@ -2,9 +2,11 @@ package com.supervision.pdfqaserver.service.impl; import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.dao.Neo4jRepository; +import com.supervision.pdfqaserver.domain.ChineseEnglishWords; import com.supervision.pdfqaserver.domain.DomainMetadata; import com.supervision.pdfqaserver.dto.neo4j.RelationObject; import com.supervision.pdfqaserver.service.ChatService; +import com.supervision.pdfqaserver.service.ChineseEnglishWordsService; import com.supervision.pdfqaserver.service.DomainMetadataService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -16,8 +18,10 @@ import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER; import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER; @@ -33,16 +37,27 @@ public class ChatServiceImpl implements ChatService { private final Neo4jRepository neo4jRepository; private final OllamaChatModel ollamaChatModel; private final DomainMetadataService domainMetadataService; + private final ChineseEnglishWordsService chineseEnglishWordsService; @Override public Flux knowledgeQA(String userQuery) { - String systemPrompt = domainMetadataService.list().stream() - .map(DomainMetadata::toString) - .reduce("", (acc, metadata) -> acc + metadata + "\n"); - + //拼装领域元数据 + Map chineseEnglishWordsMap = chineseEnglishWordsService.list().stream() + .collect(Collectors.toMap(ChineseEnglishWords::getChineseWord, ChineseEnglishWords::getEnglishWord)); + List> domainMappings = domainMetadataService.list().stream().map(domainMetadata -> { + Map mapping = new HashMap<>(); + mapping.put("source", domainMetadata.getSourceType()); + mapping.put("sourceType", chineseEnglishWordsMap.get(domainMetadata.getSourceType())); + mapping.put("relation", domainMetadata.getRelation()); + mapping.put("relationType", chineseEnglishWordsMap.get(domainMetadata.getRelation())); + mapping.put("target", domainMetadata.getTargetType()); + mapping.put("targetType", chineseEnglishWordsMap.get(domainMetadata.getTargetType())); + return mapping; + }).toList(); + log.info("domainMappings: {}", domainMappings); //生成CYPHER SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER)); - Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_DOMAIN_METADATA, systemPrompt, PROMPT_PARAM_USER_QUERY, userQuery)); + Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_DOMAIN_METADATA, domainMappings, PROMPT_PARAM_USER_QUERY, userQuery)); ChatResponse textToCypherResponse = ollamaChatModel.call(new Prompt(textToCypherMessage)); String queryCypher = "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode"; log.info(textToCypherResponse.getResult().getOutput().getText()); @@ -55,7 +70,7 @@ public class ChatServiceImpl implements ChatService { //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); - Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_TRIPLE_METADATA, systemPrompt, PROMPT_PARAM_USER_QUERY, userQuery)); + Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_TRIPLE_METADATA, relationObjects, PROMPT_PARAM_USER_QUERY, userQuery)); return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()); } }