# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
"""Load images from paths, URLs, bytes, PIL images or arrays as BGR ndarrays."""
from io import BytesIO
from pathlib import Path
from typing import Any, Union

import cv2
import numpy as np
import requests
from PIL import Image, UnidentifiedImageError

from .utils import is_url

root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]


class LoadImage:
    """Normalize heterogeneous image inputs to a 3-channel BGR ``np.ndarray``."""

    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load *img* and return it as a 3-channel BGR array.

        Args:
            img: a local path or URL (str/Path), encoded image bytes, a PIL
                image, or an ndarray (3-channel arrays are assumed to already
                be BGR).

        Raises:
            LoadImageError: if the input type is unsupported or the image
                cannot be loaded/converted.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        origin_img_type = type(img)
        img = self.load_img(img)
        img = self.convert_img(img, origin_img_type)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode *img* into an ndarray without normalizing its channel order.

        PIL-backed inputs (paths, URLs, bytes, PIL images) come back in PIL's
        channel order (RGB/RGBA/L/LA); ndarray inputs are returned unchanged.

        Raises:
            LoadImageError: if a local file is missing, a download fails, the
                data cannot be identified as an image, or the type is
                unsupported.
        """
        if isinstance(img, (str, Path)):
            if is_url(img):
                return self._open_pil(self._download(img), img)
            self.verify_exist(img)
            return self._open_pil(img, img)

        if isinstance(img, bytes):
            return self._open_pil(BytesIO(img), "<bytes>")

        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, Image.Image):
            return self.img_to_ndarray(img)

        raise LoadImageError(f"{type(img)} is not supported!")

    def _open_pil(self, src: Any, name: Any) -> np.ndarray:
        """Open *src* with PIL, convert it to an ndarray and release the handle.

        *name* is only used to describe the source in the error message.
        """
        try:
            # The context manager closes the file PIL opened; the pixel data
            # is fully loaded by np.array() inside img_to_ndarray before exit.
            with Image.open(src) as pil_img:
                return self.img_to_ndarray(pil_img)
        except UnidentifiedImageError as e:
            # Image.open (not the array conversion) raises this error.
            raise LoadImageError(f"cannot identify image file {name}") from e

    @staticmethod
    def _download(url: Union[str, Path]) -> BytesIO:
        """Fetch *url* and return its body as an in-memory byte stream.

        ``Response.content`` is used instead of ``Response.raw`` because the
        raw stream is not decoded for gzip/deflate transfer encodings.
        """
        resp = requests.get(url, timeout=60)
        if not resp.ok:
            raise LoadImageError(
                f"Failed to download {url}: HTTP status {resp.status_code}"
            )
        return BytesIO(resp.content)

    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
        """Convert a PIL image to an ndarray in a mode convert_img understands.

        - "1" (bilevel) -> "L", so pixels become 0/255 uint8 rather than bool.
        - "P"/"PA" (palette) -> "RGB"/"RGBA": the raw array would hold palette
          indices, which convert_img would otherwise treat as gray levels.
        - "CMYK"/"YCbCr" -> "RGB": their channels would otherwise be misread
          as RGBA/RGB.
        """
        if img.mode == "1":
            img = img.convert("L")
        elif img.mode == "P":
            img = img.convert("RGBA" if "transparency" in img.info else "RGB")
        elif img.mode == "PA":
            img = img.convert("RGBA")
        elif img.mode in ("CMYK", "YCbCr"):
            img = img.convert("RGB")
        return np.array(img)

    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
        """Convert a 2-D or 3-D (1/2/3/4-channel) array to 3-channel BGR.

        Args:
            img: the array produced by load_img.
            origin_img_type: the type originally passed to __call__; 3-channel
                data from PIL-backed inputs is RGB and gets swapped to BGR,
                while 3-channel ndarray inputs are assumed to be BGR already.

        Raises:
            LoadImageError: if the rank or channel count is unsupported.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 3:
                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                return img

            if channel == 4:
                # NOTE(review): 4-channel ndarray inputs are also treated as
                # RGBA, unlike 3-channel ndarrays (assumed BGR). A BGRA array
                # from cv2.imread(..., IMREAD_UNCHANGED) ends up RGB — confirm
                # the intended convention for ndarray callers.
                return self.cvt_four_to_three(img)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """Gray + alpha → BGR, compositing transparent pixels onto white."""
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        # Inverted alpha is 255 where fully transparent: the white background.
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR, flattening transparency.

        A fully opaque image is simply reordered to BGR, since its alpha
        carries no information. Otherwise, pixels where alpha is zero are
        blacked out by the mask, and then:

        - if the masked result is entirely black, the transparent areas are
          painted white;
        - otherwise the masked image is inverted, which turns light content on
          a transparent background into dark content on white.
        """
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        if np.issubdtype(a.dtype, np.integer):
            fully_opaque = np.iinfo(a.dtype).max
        else:
            fully_opaque = 1.0
        if a.min() >= fully_opaque:
            # Previously such images fell into the inversion branch below and
            # came out color-inverted.
            return new_img

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)

        mean_color = np.mean(new_img)
        if mean_color <= 0.0:
            new_img = cv2.add(new_img, not_a)
        else:
            new_img = cv2.bitwise_not(new_img)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise LoadImageError if *file_path* does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")


class LoadImageError(Exception):
    """Raised when an image input cannot be loaded or converted to BGR."""