package com.supervision.pdfqaserver.service.impl;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.constant.LayoutTypeEnum;
import com.supervision.pdfqaserver.dto.*;
import com.supervision.pdfqaserver.service.TripleConversionPipeline;
import edu.stanford.nlp.pipeline.CoreDocument;
import edu.stanford.nlp.pipeline.CoreSentence;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Default {@link TripleConversionPipeline} implementation.
 * <p>
 * Slices document layout blocks into extraction units (one per text sentence, one per table),
 * runs entity-relation extraction (ERE) on each unit through the Ollama chat model, and merges
 * the extraction results.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class TripleConversionPipelineImpl implements TripleConversionPipeline {

    /** Entity type of the synthetic table-title node; its presence also marks an ERE result as a table. */
    private static final String TABLE_ENTITY_TYPE = "表";

    /** Relation name linking a table title node to every entity extracted from that table. */
    private static final String TABLE_CONTAINS_RELATION = "包含";

    /** Separator used when building entity / relation map keys. */
    private static final String KEY_SEPARATOR = "_";

    /** Integer ordering that tolerates nulls by placing them after all non-null values. */
    private static final Comparator<Integer> NULLS_LAST_ORDER = Comparator.nullsLast(Comparator.naturalOrder());

    /** Reading order of layout blocks: page number first, then display order within the page. */
    private static final Comparator<DocumentDTO> DOCUMENT_ORDER =
            Comparator.comparing(DocumentDTO::getPageNo, NULLS_LAST_ORDER)
                    .thenComparing(DocumentDTO::getDisplayOrder, NULLS_LAST_ORDER);

    private final OllamaChatModel ollamaChatModel;

    /**
     * Splits layout blocks into extraction units.
     * <p>
     * Blocks are first put in reading order (page number, then display order; null keys sort last).
     * Text blocks are split into one {@link TruncateDTO} per sentence with a CoreNLP
     * {@code tokenize, ssplit} pipeline; each table block becomes a single {@link TruncateDTO}.
     * Blocks with empty content and {@code null} blocks are skipped; blocks of any other layout
     * type are logged and skipped.
     *
     * @param documents layout blocks to slice; {@code null} or empty yields an empty list
     * @return the slices in reading order (a new, mutable list)
     */
    @Override
    public List<TruncateDTO> sliceDocuments(List<DocumentDTO> documents) {
        List<TruncateDTO> truncateDTOS = new ArrayList<>();
        if (CollUtil.isEmpty(documents)) {
            return truncateDTOS;
        }
        List<DocumentDTO> orderedDocuments = documents.stream()
                .filter(Objects::nonNull)
                .sorted(DOCUMENT_ORDER)
                .toList();

        // Built lazily so the CoreNLP pipeline is only created when a text block is actually present.
        StanfordCoreNLP sentenceSplitter = null;
        for (DocumentDTO documentDTO : orderedDocuments) {
            String content = documentDTO.getContent();
            if (StrUtil.isEmpty(content)) {
                continue;
            }
            Integer layoutType = documentDTO.getLayoutType();
            // Objects.equals: null-safe and never a reference comparison of boxed Integers.
            if (Objects.equals(LayoutTypeEnum.TEXT.getCode(), layoutType)) {
                if (sentenceSplitter == null) {
                    sentenceSplitter = createSentenceSplitter();
                }
                CoreDocument document = new CoreDocument(content);
                sentenceSplitter.annotate(document);
                // One slice per sentence, each carrying the originating block's metadata.
                for (CoreSentence sentence : document.sentences()) {
                    TruncateDTO truncateDTO = new TruncateDTO(documentDTO);
                    truncateDTO.setContent(sentence.text());
                    truncateDTOS.add(truncateDTO);
                }
            } else if (Objects.equals(LayoutTypeEnum.TABLE.getCode(), layoutType)) {
                // Tables are kept whole.
                truncateDTOS.add(new TruncateDTO(documentDTO));
            } else {
                log.info("sliceDocuments:错误的布局类型: {}", layoutType);
            }
        }
        return truncateDTOS;
    }

    /** Builds a CoreNLP pipeline that only tokenizes and splits sentences. */
    private static StanfordCoreNLP createSentenceSplitter() {
        Properties props = new Properties();
        props.setProperty("annotators", "tokenize, ssplit");
        return new StanfordCoreNLP(props);
    }

    /**
     * Runs entity-relation extraction on one slice, dispatching on its layout type.
     *
     * @param truncateDTO the slice to extract from
     * @return the extraction result, or {@code null} when the layout type is neither text nor table
     */
    @Override
    public EREDTO doEre(TruncateDTO truncateDTO) {
        String layoutType = truncateDTO.getLayoutType();
        if (StrUtil.equals(layoutType, String.valueOf(LayoutTypeEnum.TEXT.getCode()))) {
            return doTextEre(truncateDTO);
        }
        if (StrUtil.equals(layoutType, String.valueOf(LayoutTypeEnum.TABLE.getCode()))) {
            return doTableEre(truncateDTO);
        }
        log.info("doEre:错误的布局类型: {}", layoutType);
        return null;
    }

    /**
     * Extracts entities and relations from a text slice by filling the {@code DOERE_TEXT} prompt
     * with the slice content and parsing the model response with {@link EREDTO#fromTextJson}.
     */
    private EREDTO doTextEre(TruncateDTO truncateDTO) {
        log.info("doTextEre:开始进行文本实体关系抽取,内容:{}", truncateDTO.getContent());
        String prompt = PromptCache.promptMap.get(PromptCache.DOERE_TEXT);
        String formatted = StrUtil.format(prompt, truncateDTO.getContent());
        String response = ollamaChatModel.call(formatted);
        // TODO: malformed model responses are not handled yet.
        log.info("doTextEre响应结果:{}", response);
        return EREDTO.fromTextJson(response, truncateDTO.getId());
    }

    /**
     * Extracts entities from a table slice using the {@code DOERE_TABLE} prompt, then adds a
     * synthetic title node (type {@value #TABLE_ENTITY_TYPE}, named after the slice title) with a
     * {@value #TABLE_CONTAINS_RELATION} relation to every extracted entity.
     * <p>
     * Any relations present in the parsed response are replaced by these title relations.
     */
    private EREDTO doTableEre(TruncateDTO truncateDTO) {
        log.info("doTableEre:开始进行表格实体关系抽取,内容:{}", truncateDTO.getContent());
        String prompt = PromptCache.promptMap.get(PromptCache.DOERE_TABLE);
        String formatted = StrUtil.format(prompt, truncateDTO.getContent());
        String response = ollamaChatModel.call(formatted);
        log.info("doTableEre响应结果:{}", response);
        // TODO: malformed model responses are not handled yet (assumes fromTableJson never returns null).
        EREDTO eredto = EREDTO.fromTableJson(response, truncateDTO.getId());
        // Work on a mutable copy: the parsed entity list may be null or unmodifiable.
        List<EntityExtractionDTO> entities = eredto.getEntities() == null
                ? new ArrayList<>()
                : new ArrayList<>(eredto.getEntities());

        EntityExtractionDTO titleEntity = new EntityExtractionDTO();
        titleEntity.setEntity(TABLE_ENTITY_TYPE);
        titleEntity.setName(truncateDTO.getTitle());

        // Link the title node to each extracted entity (built before the title itself is added).
        List<RelationExtractionDTO> relations = new ArrayList<>(entities.size());
        for (EntityExtractionDTO entity : entities) {
            relations.add(new RelationExtractionDTO(truncateDTO.getId(),
                    titleEntity.getName(), titleEntity.getEntity(), TABLE_CONTAINS_RELATION,
                    entity.getName(), entity.getEntity(), entity.getAttributes()));
        }
        entities.add(titleEntity);
        eredto.setEntities(entities);
        eredto.setRelations(relations);
        return eredto;
    }

    /**
     * Merges entity-relation extraction results, chiefly by unifying the attributes of duplicate
     * entities and relations.
     * <p>
     * Results containing a table node (entity type {@value #TABLE_ENTITY_TYPE}) are passed through
     * unmerged. For the remaining results, entities are deduplicated by (type, name) and relations by
     * (source type, source, relation, target type, target); attributes of later duplicates are folded
     * into the first occurrence. The output contains, in order:
     * <ol>
     *   <li>the table results;</li>
     *   <li>one result per merged relation whose two endpoint entities were found, holding those two
     *       entities and the relation;</li>
     *   <li>a single result with every entity not used by an emitted relation (omitted when empty).</li>
     * </ol>
     * {@code null} elements of the input are ignored.
     *
     * @param eredtoList extraction results to merge; may be {@code null} or empty
     * @return the merged results (a new, mutable list)
     */
    @Override
    public List<EREDTO> mergeEreResults(List<EREDTO> eredtoList) {
        List<EREDTO> merged = new ArrayList<>();
        if (CollUtil.isEmpty(eredtoList)) {
            return merged;
        }
        // Tables are passed through untouched; everything else takes part in merging.
        Map<Boolean, List<EREDTO>> byIsTable = eredtoList.stream()
                .filter(Objects::nonNull)
                .collect(Collectors.partitioningBy(this::isTableResult));
        merged.addAll(byIsTable.get(Boolean.TRUE));

        // Linked maps keep the output in first-seen order, making it deterministic.
        Map<String, EntityExtractionDTO> entityMap = new LinkedHashMap<>();
        Map<String, RelationExtractionDTO> relationMap = new LinkedHashMap<>();
        for (EREDTO eredto : byIsTable.get(Boolean.FALSE)) {
            List<EntityExtractionDTO> entities = eredto.getEntities();
            if (CollUtil.isNotEmpty(entities)) {
                for (EntityExtractionDTO entity : entities) {
                    mergeAttribute(entityMap, entity, generateEntityMapKey(entity));
                }
            }
            List<RelationExtractionDTO> relations = eredto.getRelations();
            if (CollUtil.isNotEmpty(relations)) {
                for (RelationExtractionDTO relation : relations) {
                    mergeAttribute(relationMap, relation, generateRelationMapKey(relation));
                }
            }
        }

        // Emit each relation together with its two endpoint entities first.
        Set<String> relationEntityKeys = new HashSet<>();
        for (RelationExtractionDTO relation : relationMap.values()) {
            String sourceKey = entityMapKey(relation.getSourceType(), relation.getSource());
            EntityExtractionDTO sourceEntity = entityMap.get(sourceKey);
            if (sourceEntity == null) {
                log.warn("mergeEreResults:根据entity:{},name:{}未在entityMap中找到头节点映射关系", relation.getSourceType(), relation.getSource());
                continue;
            }
            String targetKey = entityMapKey(relation.getTargetType(), relation.getTarget());
            EntityExtractionDTO targetEntity = entityMap.get(targetKey);
            if (targetEntity == null) {
                log.warn("mergeEreResults:根据entity:{},name:{}未在entityMap中找到尾节点映射关系", relation.getTargetType(), relation.getTarget());
                continue;
            }
            EREDTO relationResult = new EREDTO();
            relationResult.setEntities(List.of(sourceEntity, targetEntity));
            relationResult.setRelations(List.of(relation));
            merged.add(relationResult);
            relationEntityKeys.add(sourceKey);
            relationEntityKeys.add(targetKey);
        }

        // Gather the entities that did not end up in any emitted relation.
        List<EntityExtractionDTO> leftoverEntities = new ArrayList<>();
        entityMap.forEach((key, entity) -> {
            if (!relationEntityKeys.contains(key)) {
                leftoverEntities.add(entity);
            }
        });
        if (!leftoverEntities.isEmpty()) {
            EREDTO leftoverResult = new EREDTO();
            leftoverResult.setEntities(leftoverEntities);
            merged.add(leftoverResult);
        }
        return merged;
    }

    /** Returns whether {@code ere} contains a table node, i.e. an entity of type {@value #TABLE_ENTITY_TYPE}. */
    private boolean isTableResult(EREDTO ere) {
        List<EntityExtractionDTO> entities = ere.getEntities();
        return CollUtil.isNotEmpty(entities)
                && entities.stream().anyMatch(e -> StrUtil.equals(e.getEntity(), TABLE_ENTITY_TYPE));
    }

    /**
     * Registers {@code relation} under {@code key}; if a relation is already registered there,
     * merges {@code relation}'s attributes into the registered one instead.
     */
    private void mergeAttribute(Map<String, RelationExtractionDTO> relationMap, RelationExtractionDTO relation, String key) {
        RelationExtractionDTO cachedRelation = relationMap.putIfAbsent(key, relation);
        if (cachedRelation != null && CollUtil.isNotEmpty(relation.getAttributes())) {
            // Always write the merged list back: the cached list may have been null and replaced.
            cachedRelation.setAttributes(mergeAttributeLists(cachedRelation.getAttributes(), relation.getAttributes()));
        }
    }

    /**
     * Registers {@code entity} under {@code key}; if an entity is already registered there,
     * merges {@code entity}'s attributes into the registered one instead.
     */
    private void mergeAttribute(Map<String, EntityExtractionDTO> entityMap, EntityExtractionDTO entity, String key) {
        EntityExtractionDTO cachedEntity = entityMap.putIfAbsent(key, entity);
        if (cachedEntity != null && CollUtil.isNotEmpty(entity.getAttributes())) {
            // Always write the merged list back: the cached list may have been null and replaced.
            cachedEntity.setAttributes(mergeAttributeLists(cachedEntity.getAttributes(), entity.getAttributes()));
        }
    }

    /**
     * Returns a new mutable list holding {@code cached} (treated as empty when {@code null}) followed by
     * every attribute of {@code incoming} whose name and value are both non-empty and whose name is
     * not already present. The first value seen for an attribute name wins.
     */
    private static List<ERAttributeDTO> mergeAttributeLists(List<ERAttributeDTO> cached, List<ERAttributeDTO> incoming) {
        List<ERAttributeDTO> result = cached == null ? new ArrayList<>() : new ArrayList<>(cached);
        Set<String> knownNames = new HashSet<>();
        for (ERAttributeDTO existing : result) {
            knownNames.add(existing.getAttribute());
        }
        for (ERAttributeDTO attribute : incoming) {
            String attributeKey = attribute.getAttribute();
            if (StrUtil.isEmpty(attributeKey) || StrUtil.isEmpty(attribute.getValue())) {
                continue;
            }
            // Set.add returns false when the name is already known, so existing attributes are kept.
            if (knownNames.add(attributeKey)) {
                result.add(attribute);
            }
        }
        return result;
    }

    /** Key identifying an entity by its type and name. */
    private String generateEntityMapKey(EntityExtractionDTO entityExtractionDTO) {
        return entityMapKey(entityExtractionDTO.getEntity(), entityExtractionDTO.getName());
    }

    /** Builds the entity key from a type and a name; shared by entity registration and relation endpoint lookup. */
    private static String entityMapKey(Object type, Object name) {
        return type + KEY_SEPARATOR + name;
    }

    /**
     * Key identifying a relation. Endpoint types are part of the key because entities are keyed by
     * (type, name): same-named endpoints of different types are different entities and their
     * relations must not be merged.
     */
    private String generateRelationMapKey(RelationExtractionDTO relationExtractionDTO) {
        return relationExtractionDTO.getSourceType() + KEY_SEPARATOR + relationExtractionDTO.getSource()
                + KEY_SEPARATOR + relationExtractionDTO.getRelation()
                + KEY_SEPARATOR + relationExtractionDTO.getTargetType() + KEY_SEPARATOR + relationExtractionDTO.getTarget();
    }
}