package com.supervision.pdfqaserver.service.impl;

import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.ChineseEnglishWords;
import com.supervision.pdfqaserver.domain.DomainMetadata;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.service.ChatService;
import com.supervision.pdfqaserver.service.ChineseEnglishWordsService;
import com.supervision.pdfqaserver.service.DomainMetadataService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER;
import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER;

/**
 * Knowledge-graph question answering backed by Neo4j and an Ollama chat model.
 *
 * <p>{@link #knowledgeQA(String)} builds a Chinese/English description of the domain metadata,
 * asks the model for a Cypher query, runs a Cypher query against Neo4j and streams an answer
 * generated by the model from the returned relations.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {
    /** Prompt variable that receives the list of domain relation mappings. */
    private static final String PROMPT_PARAM_DOMAIN_METADATA = "domainMetadata";
    /** Prompt variable that receives the relations fetched from Neo4j. */
    private static final String PROMPT_PARAM_TRIPLE_METADATA = "tripleMetaData";
    /** Prompt variable that receives the user's question. */
    private static final String PROMPT_PARAM_USER_QUERY = "userQuery";

    /**
     * Cypher executed for every question: all relations leaving a {@code 公司} node.
     * NOTE(review): the model-generated Cypher is currently only logged, not executed.
     */
    private static final String FIXED_QUERY_CYPHER = "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode";

    private final Neo4jRepository neo4jRepository;

    private final OllamaChatModel ollamaChatModel;

    private final DomainMetadataService domainMetadataService;
    private final ChineseEnglishWordsService chineseEnglishWordsService;

    /**
     * Answers a user question using the knowledge graph.
     *
     * @param userQuery the user's question (must not be null; it is placed into {@code Map.of})
     * @return a stream of answer text fragments, or a single "no data" message when the graph
     *         query returns no relations
     * @throws NullPointerException if a required prompt template is missing from {@link PromptCache}
     */
    @Override
    public Flux<String> knowledgeQA(String userQuery) {
        // Assemble the domain metadata (Chinese names plus their English translations).
        List<Map<String, String>> domainMappings = buildDomainMappings();
        log.info("domainMappings: {}", domainMappings);

        // Ask the model for a Cypher query.
        String generatedCypher = generateCypher(domainMappings, userQuery);
        log.info("{}", generatedCypher);
        // TODO: switch to generatedCypher once it is validated. Reject write clauses such as
        // CREATE/DELETE/SET, and strip any markdown fencing the model may add (confirm the actual
        // output format). Until then the fixed query is used.
        String queryCypher = FIXED_QUERY_CYPHER;

        List<RelationObject> relationObjects = neo4jRepository.execute(queryCypher, null);
        if (relationObjects == null || relationObjects.isEmpty()) {
            return Flux.just("没有找到相关数据");
        }
        log.info("relationObjects: {}", relationObjects);

        // Generate the answer from the retrieved relations.
        SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(
                Objects.requireNonNull(PromptCache.promptMap.get(GENERATE_ANSWER),
                        "prompt template not found: " + GENERATE_ANSWER));
        Message generateAnswerMessage = generateAnswerTemplate.createMessage(
                Map.of(PROMPT_PARAM_TRIPLE_METADATA, relationObjects, PROMPT_PARAM_USER_QUERY, userQuery));
        // Chunks without a generation or text (e.g. a trailing metadata-only chunk) are skipped
        // instead of failing the stream with a NullPointerException.
        return ollamaChatModel.stream(new Prompt(generateAnswerMessage))
                .mapNotNull(ChatServiceImpl::extractText);
    }

    /**
     * Builds one mapping per domain-metadata row with keys {@code source}, {@code sourceType},
     * {@code relation}, {@code relationType}, {@code target} and {@code targetType}. The
     * {@code *Type} values are the English translations, or null when no translation is known.
     */
    private List<Map<String, String>> buildDomainMappings() {
        // Rows with a null word are skipped and duplicate Chinese words keep the first
        // translation. Without this, Collectors.toMap throws on nulls and duplicates, which
        // fails every request.
        Map<String, String> chineseEnglishWordsMap = chineseEnglishWordsService.list().stream()
                .filter(words -> words.getChineseWord() != null && words.getEnglishWord() != null)
                .collect(Collectors.toMap(ChineseEnglishWords::getChineseWord,
                        ChineseEnglishWords::getEnglishWord,
                        (first, duplicate) -> first));
        return domainMetadataService.list().stream().map(domainMetadata -> {
            Map<String, String> mapping = new HashMap<>();
            mapping.put("source", domainMetadata.getSourceType());
            mapping.put("sourceType", chineseEnglishWordsMap.get(domainMetadata.getSourceType()));
            mapping.put("relation", domainMetadata.getRelation());
            mapping.put("relationType", chineseEnglishWordsMap.get(domainMetadata.getRelation()));
            mapping.put("target", domainMetadata.getTargetType());
            mapping.put("targetType", chineseEnglishWordsMap.get(domainMetadata.getTargetType()));
            return mapping;
        }).toList();
    }

    /**
     * Asks the chat model to translate the user question into Cypher.
     *
     * @return the model's raw text output, or null if the response carried no text
     */
    private String generateCypher(List<Map<String, String>> domainMappings, String userQuery) {
        SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(
                Objects.requireNonNull(PromptCache.promptMap.get(TEXT_TO_CYPHER),
                        "prompt template not found: " + TEXT_TO_CYPHER));
        Message textToCypherMessage = textToCypherTemplate.createMessage(
                Map.of(PROMPT_PARAM_DOMAIN_METADATA, domainMappings, PROMPT_PARAM_USER_QUERY, userQuery));
        return extractText(ollamaChatModel.call(new Prompt(textToCypherMessage)));
    }

    /**
     * Null-safely extracts the text of the first generation of a chat response.
     *
     * @return the generated text, or null when the response, its result or its output is absent
     */
    private static String extractText(ChatResponse response) {
        if (response == null) {
            return null;
        }
        var result = response.getResult();
        if (result == null || result.getOutput() == null) {
            return null;
        }
        return result.getOutput().getText();
    }
}