You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

62 lines
3.1 KiB
Java

package com.supervision.pdfqaserver.service.impl;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.DomainMetadata;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.service.ChatService;
import com.supervision.pdfqaserver.service.DomainMetadataService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.Map;
import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER;
import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER;
/**
 * {@link ChatService} implementation that answers a user question in two model calls:
 * first it asks the chat model to turn the question into Cypher, then it queries Neo4j
 * and asks the model to phrase an answer from the relations that were found.
 *
 * <p>The {@link DomainMetadata} entries returned by {@link DomainMetadataService#list()} are
 * rendered one per line and supplied to the text-to-Cypher prompt.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {

    private static final String PROMPT_PARAM_DOMAIN_METADATA = "domainMetadata";
    private static final String PROMPT_PARAM_TRIPLE_METADATA = "tripleMetaData";
    private static final String PROMPT_PARAM_USER_QUERY = "userQuery";

    /** Fixed graph query that is executed for every question (see the NOTE in {@link #knowledgeQA}). */
    private static final String DEFAULT_QUERY_CYPHER =
            "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode";

    private final Neo4jRepository neo4jRepository;
    private final OllamaChatModel ollamaChatModel;
    private final DomainMetadataService domainMetadataService;

    /**
     * Answers {@code userQuery} from the knowledge graph.
     *
     * @param userQuery the question as typed by the user; passed verbatim into both prompts
     * @return a stream of answer text fragments produced by the chat model, or a single
     *         fixed message when the graph query yields no relations
     */
    @Override
    public Flux<String> knowledgeQA(String userQuery) {
        String domainMetadata = joinLines(domainMetadataService.list());

        // Step 1: ask the model to translate the question into Cypher.
        SystemPromptTemplate textToCypherTemplate =
                new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
        Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(
                PROMPT_PARAM_DOMAIN_METADATA, domainMetadata,
                PROMPT_PARAM_USER_QUERY, userQuery));
        ChatResponse textToCypherResponse = ollamaChatModel.call(new Prompt(textToCypherMessage));
        // Log through a constant format string so model output is never interpreted as a pattern.
        log.info("generated cypher: {}", textToCypherResponse.getResult().getOutput().getText());

        // NOTE(review): the generated Cypher is currently only logged; the fixed query below is
        // what actually runs. Switch to the model output once it is trusted enough to execute.
        List<RelationObject> relationObjects = neo4jRepository.execute(DEFAULT_QUERY_CYPHER, null);
        if (relationObjects == null || relationObjects.isEmpty()) {
            return Flux.just("没有找到相关数据");
        }
        log.info("relationObjects: {}", relationObjects);

        // Step 2: let the model phrase the answer from the relations that were retrieved.
        // The triple slot must carry the query result, not the domain metadata, otherwise
        // the graph lookup has no influence on the answer.
        SystemPromptTemplate generateAnswerTemplate =
                new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
        Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(
                PROMPT_PARAM_TRIPLE_METADATA, joinLines(relationObjects),
                PROMPT_PARAM_USER_QUERY, userQuery));
        return ollamaChatModel.stream(new Prompt(generateAnswerMessage))
                .map(response -> response.getResult().getOutput().getText());
    }

    /** Renders each item with {@code String.valueOf} and terminates every entry with a newline. */
    private static String joinLines(Iterable<?> items) {
        StringBuilder joined = new StringBuilder();
        for (Object item : items) {
            joined.append(item).append('\n');
        }
        return joined.toString();
    }
}