You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

132 lines
3.9 KiB
Python

# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from io import BytesIO
from pathlib import Path
from typing import Any, Union
import cv2
import numpy as np
import requests
from PIL import Image, UnidentifiedImageError
from .utils import is_url
# Absolute path (symlinks resolved) of the directory containing this module.
# NOTE(review): not referenced by the code in this module's visible portion;
# presumably used by importers — confirm before removing.
root_dir = Path(__file__).resolve().parent
# Input types accepted by LoadImage. Its ``__args__`` tuple drives the
# isinstance check in LoadImage.__call__ and is printed in that check's error
# message, so the member order here is user-visible.
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
class LoadImage:
    """Load an image from several input kinds and normalize it to 3-channel BGR.

    Accepted inputs are the members of ``InputType``: a local file path or URL
    (``str`` / ``Path``), raw encoded image bytes, a PIL image, or a NumPy
    array. Calling an instance returns a 3-channel ndarray in OpenCV's BGR
    channel order (see ``convert_img`` for how channel order is inferred).
    """

    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load ``img`` and convert it to a 3-channel BGR ndarray.

        Raises:
            LoadImageError: if ``img`` is not a supported input type, a local
                path does not exist, the data cannot be identified as an
                image, or the decoded array has an unsupported number of
                dimensions or channels. Network errors from ``requests`` and
                other ``OSError``s propagate unchanged.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        # The original type decides whether a 3-channel array is RGB (decoded
        # by PIL) or already BGR (a caller-supplied ndarray).
        origin_img_type = type(img)
        img = self.load_img(img)
        img = self.convert_img(img, origin_img_type)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode ``img`` into an ndarray, without fixing up channel order.

        Paths and URLs are read and decoded with PIL, bytes are decoded with
        PIL from memory, PIL images are converted directly, and ndarrays are
        returned as-is (not copied).

        Raises:
            LoadImageError: if a local path does not exist, the data cannot be
                identified as an image, or the type is unsupported.
        """
        if isinstance(img, (str, Path)):
            # ``str`` so that ``Path`` inputs reach ``is_url`` in the same form
            # as plain strings.
            if is_url(str(img)):
                return self._load_from_url(str(img))
            self.verify_exist(img)
            return self._decode_with_pil(img, img)

        if isinstance(img, bytes):
            return self._decode_with_pil(BytesIO(img), f"<{len(img)} bytes>")

        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, Image.Image):
            return self.img_to_ndarray(img)

        raise LoadImageError(f"{type(img)} is not supported!")

    def _load_from_url(self, url: str) -> np.ndarray:
        """Download ``url`` and decode the response body as an image."""
        # ``Response.content`` undoes any Content-Encoding (e.g. gzip); the raw
        # urllib3 stream does not, so PIL would see compressed bytes. The
        # ``with`` releases the connection once the body has been read.
        with requests.get(url, timeout=60) as resp:
            data = resp.content
        return self._decode_with_pil(BytesIO(data), url)

    def _decode_with_pil(self, source: Any, label: Any) -> np.ndarray:
        """Open ``source`` with PIL and convert it, wrapping decode failures.

        ``label`` is only used in the error message to identify the input.
        """
        # PIL raises UnidentifiedImageError from ``Image.open`` itself, so the
        # open must sit inside the ``try`` for the wrapping to take effect.
        try:
            with Image.open(source) as pil_img:
                return self.img_to_ndarray(pil_img)
        except UnidentifiedImageError as e:
            raise LoadImageError(f"cannot identify image file {label}") from e

    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
        """Convert a PIL image to an ndarray that ``convert_img`` can interpret.

        ``convert_img`` dispatches purely on channel count, so modes whose raw
        pixel data is not gray/RGB(A) are converted first:

        * ``"1"`` (bilevel) -> ``"L"``, giving 8-bit grayscale.
        * ``"P"`` (palette) -> ``"RGBA"`` if it declares transparency, else
          ``"RGB"``. Otherwise the array would hold palette indices.
        * ``"PA"`` (palette + alpha) -> ``"RGBA"``.
        * ``"CMYK"`` / ``"YCbCr"`` -> ``"RGB"``. Otherwise they would be
          misread as RGBA / RGB.
        """
        if img.mode == "1":
            img = img.convert("L")
        elif img.mode == "P":
            img = img.convert("RGBA" if "transparency" in img.info else "RGB")
        elif img.mode == "PA":
            img = img.convert("RGBA")
        elif img.mode in ("CMYK", "YCbCr"):
            img = img.convert("RGB")
        return np.array(img)

    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
        """Normalize a decoded array to 3 channels in BGR order.

        Args:
            img: array returned by ``load_img``; 2-D (grayscale) or 3-D with
                1-4 channels in the last axis.
            origin_img_type: type of the caller's original input. A 3-channel
                array that came from a path/URL, bytes or PIL image was
                decoded by PIL and is assumed RGB. A caller-supplied ndarray
                is assumed to be BGR already.

        Raises:
            LoadImageError: for any other ``ndim`` or channel count.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            if channel == 2:
                return self.cvt_two_to_three(img)
            if channel == 3:
                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                return img
            if channel == 4:
                # NOTE(review): unlike the 3-channel branch, this ignores
                # origin_img_type, so a BGRA ndarray (e.g. from cv2) is also
                # treated as RGBA and comes out with R and B swapped. Confirm
                # whether that is intended.
                return self.cvt_four_to_three(img)
            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR

        Composites the gray channel onto a white background. Pixels with zero
        alpha are masked to black, then ``255 - alpha`` is added (saturating),
        so fully transparent pixels become white. Partially transparent pixels
        get an additive approximation rather than true alpha blending.
        """
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
        img_alpha = img[..., 1]

        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR

        Reorders the color channels to BGR and masks out pixels with zero
        alpha (they become black). Then:

        * if the masked image is entirely black (mean <= 0), the transparent
          area is filled white by adding ``255 - alpha``;
        * otherwise the whole masked image is color-inverted, which turns the
          masked-out area white and also inverts every visible color.

        NOTE(review): the inversion in the second case looks like a
        deliberate heuristic (presumably for OCR input). Confirm before
        relying on the exact output colors.
        """
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)

        mean_color = np.mean(new_img)
        if mean_color <= 0.0:
            new_img = cv2.add(new_img, not_a)
        else:
            new_img = cv2.bitwise_not(new_img)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise LoadImageError if ``file_path`` does not exist on disk."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
    """Error raised when an input cannot be loaded or normalized by LoadImage."""