from export_model import trim_config,get_input_spec
from predict import parse_file_paths
import os 
import os.path as osp
import paddle
from paddlevideo.utils import get_config
from paddlevideo.modeling.builder import build_model
from paddle.jit import to_static
from paddle.inference import Config, create_predictor
from utils import build_inference_helper
import time
import warnings
warnings.filterwarnings("ignore")


class PP_TSMv2(object):
    """Export a trained PP-TSMv2 checkpoint and run Paddle Inference on it.

    The instance only stores inference options; each method receives the
    config / model paths it needs explicitly.

    Args:
        use_gpu: Run inference on GPU device 0 when True, otherwise on CPU.
        batch_size: Default batch size, used by ``predict`` when it is called
            with ``batch_size=None``.
        ir_optim: Forwarded to ``Config.switch_ir_optim``.
        disable_glog: Silence Paddle's glog output when True.
        save_name: File stem used by ``exportmodel``; when None the model
            name from the config is used.
        enable_mklddn: Enable MKL-DNN on the CPU path. The misspelled keyword
            is kept for backward compatibility; it is stored as
            ``self.enable_mkldnn``.
        precision: ``"fp32"`` or ``"fp16"``. Only consulted on the CPU +
            MKL-DNN path, where ``"fp16"`` enables bfloat16.
        gpu_mem: Initial GPU memory pool size handed to
            ``Config.enable_use_gpu`` (MB per Paddle's API).
        cpu_threads: CPU math-library thread count; ignored when falsy.
        time_test_file: Stored on the instance but not used by this class.
    """

    def __init__(self, use_gpu=True, batch_size=1, ir_optim=True,
                 disable_glog=False, save_name=None, enable_mklddn=False,
                 precision="fp32", gpu_mem=8000, cpu_threads=None,
                 time_test_file=False):
        self.use_gpu = use_gpu
        self.cpu_threads = cpu_threads
        self.batch_size = batch_size
        self.ir_optim = ir_optim
        self.disable_glog = disable_glog
        self.gpu_mem = gpu_mem
        self.enable_mkldnn = enable_mklddn
        self.precision = precision
        self.save_name = save_name
        self.time_test_file = time_test_file

    def create_paddle_predictor(self, model_f, pretr_p, cfg):
        """Build a Paddle Inference predictor for an exported model.

        Args:
            model_f: Path of the exported model structure file.
            pretr_p: Path of the exported parameters file.
            cfg: Parsed config; accepted for interface compatibility but
                not read here.

        Returns:
            Tuple ``(config, predictor)`` of the ``paddle.inference.Config``
            that was used and the predictor created from it.
        """
        config = Config(model_f, pretr_p)
        if self.use_gpu:
            config.enable_use_gpu(self.gpu_mem, 0)
        else:
            config.disable_gpu()
            if self.cpu_threads:
                config.set_cpu_math_library_num_threads(self.cpu_threads)
            if self.enable_mkldnn:
                # Bound MKL-DNN's per-input-shape cache to 10 entries.
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                # NOTE: "fp16" is mapped to MKL-DNN bfloat16 on CPU; the GPU
                # path above ignores `precision` entirely.
                if self.precision == "fp16":
                    config.enable_mkldnn_bfloat16()

        config.switch_ir_optim(self.ir_optim)
        config.enable_memory_optim()
        # Feed/fetch ops are disabled so tensors are exchanged through the
        # input/output handles (copy_from_cpu / copy_to_cpu) in `predict`.
        config.switch_use_feed_fetch_ops(False)

        if self.disable_glog:
            config.disable_glog_info()

        predictor = create_predictor(config)
        return config, predictor

    def exportmodel(self, config, pretr_p, output_p):
        """Convert a trained dygraph checkpoint into a static inference model.

        Args:
            config: Path to the PaddleVideo YAML config.
            pretr_p: Path to the trained ``.pdparams`` checkpoint.
            output_p: Output directory; created if it does not exist.

        Returns:
            The model after ``to_static`` conversion.

        Raises:
            AssertionError: If ``pretr_p`` is not an existing file.
        """
        # Validate the checkpoint path before parsing the config and building
        # the network, so a wrong path fails immediately.
        assert osp.isfile(pretr_p), \
            f"pretrained params ({pretr_p}) is not a file path."

        cfg, model_name = trim_config(get_config(config, overrides=None, show=False))

        print(f"Building model({model_name})...")
        model = build_model(cfg.MODEL)

        os.makedirs(output_p, exist_ok=True)

        print(f"Loading params from ({pretr_p})...")
        model.set_dict(paddle.load(pretr_p))
        model.eval()

        # Re-parameterize layers that expose `rep()`. A missing `is_repped`
        # attribute is treated as "not yet re-parameterized" instead of
        # raising AttributeError.
        for layer in model.sublayers():
            if hasattr(layer, "rep") and not getattr(layer, "is_repped", False):
                layer.rep()

        input_spec = get_input_spec(cfg.INFERENCE, model_name)
        model = to_static(model, input_spec=input_spec)
        save_stem = model_name if self.save_name is None else self.save_name
        paddle.jit.save(model, osp.join(output_p, save_stem))
        print(f"model ({save_stem}) has been already saved in ({output_p}).")
        return model

    def predict(self, config, input_f, batch_size, model_f, params_f):
        """Run batched inference over the inputs found at ``input_f``.

        Args:
            config: Path to the PaddleVideo YAML config.
            input_f: Input location handed to ``parse_file_paths``.
            batch_size: Number of files per predictor run. ``None`` falls
                back to ``self.batch_size``. Must be a positive integer.
            model_f: Exported model structure file.
            params_f: Exported parameters file.

        Returns:
            None. Each batch's outputs are only passed to the inference
            helper's ``postprocess``.

        Raises:
            ValueError: If the effective batch size is smaller than 1.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # Fail before the (expensive) predictor is built; a negative step
        # would otherwise make the batch loop silently do nothing.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}")

        cfg = get_config(config, overrides=None, show=False)
        model_name = cfg.model_name
        print(f"Inference model({model_name})...")
        inference_helper = build_inference_helper(cfg.INFERENCE)

        _, predictor = self.create_paddle_predictor(model_f, params_f, cfg)

        # Resolve tensor handles once; they are reused for every batch.
        input_handles = [predictor.get_input_handle(name)
                         for name in predictor.get_input_names()]
        output_handles = [predictor.get_output_handle(name)
                          for name in predictor.get_output_names()]

        files = parse_file_paths(input_f)
        for start in range(0, len(files), batch_size):
            batched_inputs = inference_helper.preprocess_batch(
                files[start:start + batch_size])
            for idx, handle in enumerate(input_handles):
                handle.copy_from_cpu(batched_inputs[idx])
            predictor.run()

            batched_outputs = [handle.copy_to_cpu() for handle in output_handles]
            # NOTE(review): second positional arg is presumably the helper's
            # "print output" switch — confirm against the helper's signature.
            inference_helper.postprocess(batched_outputs, True)
            



def main():
    """Export the trained PP-TSMv2 checkpoint, then run inference with it.

    All paths below are machine-specific and must be edited before use.
    """
    # Training/inference config file.
    config = '/home/xznsh/data/PaddleVideo/configs/recognition/pptsm/v2/pptsm_lcnet_k400_16frames_uniform.yaml'
    # Dataset to run inference on.
    input_file = '/home/xznsh/data/PaddleVideo/data/dataset/video_seg_re_hand'
    # Trained model parameters produced by training.
    pretrain_params = '/home/xznsh/data/PaddleVideo/output/ppTSMv2/ppTSMv2_best.pdparams'
    # Directory that receives the exported inference model.
    output_path = '/home/xznsh/data/PaddleVideo/inference/infer1'
    # File stem of the exported model. It is passed as `save_name` so the
    # exported files always match the paths predict() loads, regardless of
    # the model name stored in the config.
    model_stem = 'ppTSMv2'
    model_file = osp.join(output_path, model_stem + '.pdmodel')
    params_file = osp.join(output_path, model_stem + '.pdiparams')
    batch_size = 1

    runner = PP_TSMv2(batch_size=batch_size, save_name=model_stem)
    # Export the inference model; the save completes before exportmodel
    # returns, so prediction can start immediately.
    runner.exportmodel(config, pretrain_params, output_path)
    runner.predict(config, input_file, batch_size, model_file, params_file)
    

if __name__ == "__main__":
    # Script entry point: export the inference model, then run prediction.
    main()