import os 
import os.path as osp
from paddlevideo.utils.config import get_config
from paddle.inference import Config, create_predictor
from tools.utils import build_inference_helper

class PP_TSMv2_predict(object):

    """Holds the parameters commonly used by the PP-TSMv2 inference model."""

    def __init__(self, use_gpu=True, ir_optim=True,
                 disable_glog=False, save_name=None, enable_mkldnn=False,
                 precision="fp32", gpu_mem=8000, cpu_threads=None):

        self.use_gpu = use_gpu                                 # whether to run on the GPU
        self.cpu_threads = cpu_threads                         # number of CPU math-library threads
        self.ir_optim = ir_optim                               # whether to enable IR optimization
        self.disable_glog = disable_glog                       # whether to suppress glog output
        self.gpu_mem = gpu_mem                                 # initial GPU memory pool size (MB)
        self.enable_mkldnn = enable_mkldnn                     # whether to enable MKL-DNN on CPU
        self.precision = precision                             # MKL-DNN precision ("fp32" or "fp16")
        self.save_name = save_name                             # name under which the exported inference model is saved



    def parse_file_paths(self, input_path: str) -> list:

        """
            Collect the input files for the model.
            input_path: a single video file, or a directory of videos
        """
        if osp.isfile(input_path):
            files = [
                input_path,
            ]
        else:
            files = os.listdir(input_path)
            files = [
                file for file in files
                if (file.endswith(".avi") or file.endswith(".mp4"))
            ]
            files = [osp.join(input_path, file) for file in files]
        return files
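
    # A minimal sketch of how parse_file_paths resolves its input; the paths
    # below are hypothetical, for illustration only:
    #
    #   runner = PP_TSMv2_predict()
    #   runner.parse_file_paths("videos/clip.avi")  # -> ["videos/clip.avi"]
    #   runner.parse_file_paths("videos/")          # -> ["videos/a.avi", "videos/b.mp4", ...]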


    def create_paddle_predictor(self, model_f, pretr_p, cfg):
        """
            Create the inference engine.
            model_f: path to the exported inference model file
            pretr_p: path to the trained parameter file
            cfg: parsed model configuration
        """
        config = Config(model_f, pretr_p)
        if self.use_gpu:
            config.enable_use_gpu(self.gpu_mem, 0)
        else:
            config.disable_gpu()
            if self.cpu_threads:
                config.set_cpu_math_library_num_threads(self.cpu_threads)
            if self.enable_mkldnn:
                # cache a limited number of input shapes to bound MKL-DNN memory use
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if self.precision == "fp16":
                    # Paddle exposes reduced CPU precision via bfloat16
                    config.enable_mkldnn_bfloat16()

        config.switch_ir_optim(self.ir_optim)
        
        config.enable_memory_optim()
        config.switch_use_feed_fetch_ops(False)

        if self.disable_glog:
            config.disable_glog_info()

        predictor = create_predictor(config)

        return config, predictor
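
    # A minimal sketch of building a CPU-only predictor with MKL-DNN enabled;
    # the config/model/params paths are hypothetical, for illustration only:
    #
    #   runner = PP_TSMv2_predict(use_gpu=False, enable_mkldnn=True, cpu_threads=4)
    #   cfg = get_config("configs/pptsm_lcnet.yaml", overrides=None, show=False)
    #   _, predictor = runner.create_paddle_predictor(
    #       "inference/ppTSMv2.pdmodel",
    #       "inference/ppTSMv2.pdiparams",
    #       cfg)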

    def create_inference_model(self, config, model_f, params_f):
        """
            Create the inference helper and the inference engine.
            config: path to the model configuration file (YAML)
            model_f: path to the exported inference model file
            params_f: path to the exported parameter file
        """
        cfg = get_config(config, overrides=None, show=False)
        InferenceHelper = build_inference_helper(cfg.INFERENCE)
        _, predictor = self.create_paddle_predictor(model_f, params_f, cfg)

        return InferenceHelper, predictor
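
    # Note: build_inference_helper constructs the helper named in the
    # INFERENCE section of the YAML config, so the returned InferenceHelper
    # applies the preprocessing (frame sampling, scaling, cropping,
    # normalization) configured there for the exported model.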


    def predict(self, input_f, batch_size, predictor, InferenceHelper):

        """
            Run inference over the input data and return the prediction.
            input_f: path to the data to be inferred (a video file or a directory)
            batch_size: number of samples consumed per inference step, default = 1
            predictor: the inference engine
            InferenceHelper: the inference helper (pre- and post-processing)
        """
        result = {}

        # get input_tensor and output_tensor
        input_names = predictor.get_input_names()
        output_names = predictor.get_output_names()
        input_tensor_list = []
        output_tensor_list = []
        for item in input_names:
            input_tensor_list.append(predictor.get_input_handle(item))
        for item in output_names:
            output_tensor_list.append(predictor.get_output_handle(item))

        files = self.parse_file_paths(input_f)

        batch_num = batch_size
        for st_idx in range(0, len(files), batch_num):
            ed_idx = min(st_idx + batch_num, len(files))

            # preprocess the input batch
            batched_inputs = InferenceHelper.preprocess_batch(files[st_idx:ed_idx])
            for i in range(len(input_tensor_list)):
                input_tensor_list[i].copy_from_cpu(batched_inputs[i])

            # run the inference engine
            predictor.run()

            batched_outputs = []
            for j in range(len(output_tensor_list)):
                batched_outputs.append(output_tensor_list[j].copy_to_cpu())

            # postprocess the raw outputs into prediction results
            res = InferenceHelper.postprocess(batched_outputs, False, True)

            # keep the prediction for the first sample of the batch; when
            # several batches are processed, later batches overwrite earlier ones
            result["video_id"] = res[0]["video_id"]
            result["topk_class"] = res[0]["topk_class"].tolist()[0]
            result["topk_scores"] = res[0]["topk_scores"].tolist()[0]

        return result
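
    # Shape of the returned dict (a sketch; the values shown are illustrative):
    #
    #   {"video_id": "test01_3",            # sample identifier from postprocess
    #    "topk_class": [5, 1, 0, 3, 2],     # top-k predicted class indices
    #    "topk_scores": [0.91, 0.04, ...]}  # top-k class probabilities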
            


# def main():
#     config = 'D:/download/PaddleVideo1/output/output/pptsm_lcnet_k400_16frames_uniform.yaml'  # path to the config file
#     input_file = 'C:/Users/Administrator/Pictures/video_seg_re_hand/test01_3.avi'             # path to the data to be inferred
#     model_file = 'D:/download/PaddleVideo1/output/output/ppTSMv2.pdmodel'                     # path to the inference model
#     params_file = 'D:/download/PaddleVideo1/output/output/ppTSMv2.pdiparams'                  # path to the model parameters
#     batch_size = 1                                                                            # inference batch size
#     runner = PP_TSMv2_predict()
#     infer, predictor = runner.create_inference_model(config, model_file, params_file)
#     runner.predict(input_file, batch_size, predictor, infer)                                  # run inference

# if __name__ == "__main__":
#     main()