# -*- coding: utf-8 -*-
# @Time    : 2019/12/4 13:12
# @Author  : zhoujun
import copy
import logging

from paddle.io import Dataset
from data_loader.modules import *

# NOTE(review): `cv2` and `np` are used below but never imported explicitly in
# this file; presumably they are re-exported by the star import above —
# confirm against data_loader.modules.

_logger = logging.getLogger(__name__)


class BaseDataSet(Dataset):
    """Base dataset: loads records, reads each image and runs pre-processing.

    Subclasses implement :meth:`load_data`; this class handles image reading,
    the configured pre-processing steps, the optional image transform and the
    final key filtering.
    """

    # Upper bound on how many samples __getitem__ tries before giving up.
    _MAX_LOAD_ATTEMPTS = 100

    def __init__(
        self,
        data_path: str,
        img_mode,
        pre_processes,
        filter_keys,
        ignore_tags,
        transform=None,
        target_transform=None,
    ):
        """Load the record list and build the pre-processing pipeline.

        :param data_path: passed unchanged to :meth:`load_data`.
        :param img_mode: "RGB", "BGR" or "GRAY". "BRG" is still accepted as a
            legacy spelling of "BGR". Only "RGB" triggers a colour conversion
            and only "GRAY" switches to a single-channel read.
        :param pre_processes: iterable of ``{"type": ..., "args": ...}`` specs,
            or None for no pre-processing.
        :param filter_keys: keys removed from every returned sample; a falsy
            value disables filtering.
        :param ignore_tags: stored on ``self`` before :meth:`load_data` runs
            (presumably consumed by subclasses — verify against them).
        :param transform: optional callable applied to ``data["Crop_img"]``.
        :param target_transform: stored on ``self``; not used by this class.
        """
        assert img_mode in (
            "RGB",
            "BGR",
            "BRG",
            "GRAY",
        ), "unsupported img_mode: {}".format(img_mode)
        # Must be set before load_data(), which may read it.
        self.ignore_tags = ignore_tags
        self.data_list = self.load_data(data_path)
        item_keys = ["img_path", "img_name", "text_polys", "texts", "ignore_tags"]
        # Only the first record is inspected, as a cheap schema sanity check.
        for item in item_keys:
            assert (
                item in self.data_list[0]
            ), "data_list from load_data must contains {}".format(item_keys)
        self.img_mode = img_mode
        self.filter_keys = filter_keys
        self.transform = transform
        self.target_transform = target_transform
        self._init_pre_processes(pre_processes)

    def _init_pre_processes(self, pre_processes):
        """Instantiate every configured pre-processing step into ``self.aug``.

        Each spec's ``"type"`` names a callable visible from this module. A
        dict ``"args"`` is expanded as keyword arguments, any other value is
        passed as a single positional argument, and a missing ``"args"`` means
        "no arguments".
        """
        self.aug = []
        if pre_processes is None:
            return
        for spec in pre_processes:
            args = spec["args"] if "args" in spec else {}
            # NOTE(security): eval() executes whatever expression the config
            # holds in "type". Only load configs from trusted sources.
            factory = eval(spec["type"])
            if isinstance(args, dict):
                step = factory(**args)
            else:
                step = factory(args)
            self.aug.append(step)

    def load_data(self, data_path: str) -> list:
        """Load the dataset index into a list of records.

        :param data_path: folder or file that stores the data.
        :return: a list of dicts, each containing 'img_path', 'img_name',
            'text_polys', 'texts' and 'ignore_tags'.
        """
        raise NotImplementedError

    def apply_pre_processes(self, data):
        """Run ``data`` through every step in ``self.aug``, in order."""
        for aug in self.aug:
            data = aug(data)
        return data

    def _load_item(self, index):
        """Build the sample for ``index``; raises on any loading failure."""
        # Deep copy so pre-processing never mutates the cached record.
        data = copy.deepcopy(self.data_list[index])
        read_flag = 0 if self.img_mode == "GRAY" else 1
        im = cv2.imread(data["img_path"], read_flag)
        if im is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(
                "cv2.imread could not read image: {}".format(data["img_path"])
            )
        if self.img_mode == "RGB":
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        data["Crop_img"] = im
        data["shape"] = [im.shape[0], im.shape[1]]
        data = self.apply_pre_processes(data)
        if self.transform:
            data["Crop_img"] = self.transform(data["Crop_img"])
        data["text_polys"] = data["text_polys"].tolist()
        if self.filter_keys:
            return {k: v for k, v in data.items() if k not in self.filter_keys}
        return data

    def __getitem__(self, index):
        """Return the processed sample for ``index``.

        If building the sample raises any ``Exception`` (unreadable image,
        failing pre-processing step, out-of-range index, ...), the failure is
        logged and a randomly chosen sample is tried instead, up to
        ``_MAX_LOAD_ATTEMPTS`` attempts in total.

        :raises RuntimeError: if every attempt fails; the last underlying
            error is chained as the cause.
        """
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            try:
                return self._load_item(index)
            except Exception as err:  # best-effort: skip broken samples
                _logger.warning(
                    "Failed to load sample %s (%s: %s); trying a random sample",
                    index,
                    type(err).__name__,
                    err,
                )
                last_error = err
                index = np.random.randint(len(self))
        raise RuntimeError(
            "could not load any sample after {} attempts".format(
                self._MAX_LOAD_ATTEMPTS
            )
        ) from last_error

    def __len__(self):
        """Return the number of records produced by :meth:`load_data`."""
        return len(self.data_list)