You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
132 lines
3.9 KiB
Python
132 lines
3.9 KiB
Python
# -*- encoding: utf-8 -*-
|
|
# @Author: SWHL
|
|
# @Contact: liekkaskono@163.com
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from typing import Any, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import requests
|
|
from PIL import Image, UnidentifiedImageError
|
|
|
|
from .utils import is_url
|
|
|
|
# Absolute, symlink-resolved directory that contains this module.
root_dir = Path(__file__).resolve().parents[0]

# All input types accepted by LoadImage.__call__. The member order matters:
# ``InputType.__args__`` is interpolated into the unsupported-type error message.
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
|
|
|
|
|
|
class LoadImage:
    """Load an image from several input types as a 3-channel ndarray.

    Accepted inputs are the members of ``InputType``: a local file path
    (``str`` or ``Path``), a URL string (as recognised by ``is_url``),
    encoded image ``bytes``, a ``PIL.Image.Image`` or an ``np.ndarray``.
    Images decoded through PIL come back in BGR channel order.
    """

    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load ``img`` and normalise it to a 3-channel image.

        Args:
            img: Value whose type is one of the members of ``InputType``.

        Returns:
            A 3-channel array. Inputs decoded through PIL (paths, URLs,
            bytes, PIL images) are converted RGB -> BGR; a 3-channel
            ``np.ndarray`` is returned unchanged.

        Raises:
            LoadImageError: If the input type is unsupported, a local file is
                missing, the data cannot be decoded as an image, a download
                returns an HTTP error status, or the array is not 2-D/3-D
                with 1-4 channels. Network errors raised by ``requests``
                (timeouts, connection failures) propagate unchanged.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        # Remember the original type: convert_img uses it to decide whether
        # 3-channel data is in PIL's RGB order or already BGR.
        origin_img_type = type(img)
        img = self.load_img(img)
        img = self.convert_img(img, origin_img_type)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode ``img`` into an ndarray without reordering channels.

        Paths and URLs are opened with PIL, ``bytes`` are decoded in memory,
        PIL images go through ``img_to_ndarray`` and ndarrays are returned
        as-is.

        Raises:
            LoadImageError: If a local file is missing, the data is not a
                recognised image format, a download fails with an HTTP error
                status, or the type is unsupported.
        """
        if isinstance(img, (str, Path)):
            if is_url(img):
                return self._decode(BytesIO(self._download(img)), img)
            self.verify_exist(img)
            return self._decode(img, img)

        if isinstance(img, bytes):
            return self._decode(BytesIO(img), "<bytes input>")

        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, Image.Image):
            return self.img_to_ndarray(img)

        raise LoadImageError(f"{type(img)} is not supported!")

    def _decode(
        self, source: Union[str, Path, BytesIO], desc: Union[str, Path]
    ) -> np.ndarray:
        """Open ``source`` with PIL and convert it via ``img_to_ndarray``.

        ``desc`` is only used to name the source in the error message.
        """
        # PIL detects an unknown format in Image.open itself, so the try
        # block must wrap the open call, not the later pixel conversion.
        try:
            pil_img = Image.open(source)
        except UnidentifiedImageError as e:
            raise LoadImageError(f"cannot identify image file {desc}") from e
        # The context manager closes any file handle PIL opened for a path.
        with pil_img:
            return self.img_to_ndarray(pil_img)

    @staticmethod
    def _download(url: Union[str, Path]) -> bytes:
        """Fetch ``url`` and return the response body.

        Uses ``response.content`` rather than ``response.raw`` so that a
        gzip/deflate Content-Encoding is decoded and the connection is
        released once the body has been read.

        Raises:
            LoadImageError: If the server answers with an HTTP error status.
        """
        response = requests.get(str(url), timeout=60)
        if not response.ok:
            raise LoadImageError(
                f"failed to download {url}: HTTP {response.status_code}"
            )
        return response.content

    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
        """Convert a PIL image to an ndarray, normalising awkward modes first.

        - ``"1"`` (bilevel) becomes 8-bit ``"L"``.
        - ``"P"`` / ``"PA"`` store palette indices rather than colours, so
          they are expanded to ``"RGB"``, or to ``"RGBA"`` when the image
          carries transparency.
        - ``"CMYK"`` and ``"YCbCr"`` are not RGB(A), so they are converted to
          ``"RGB"`` for the later RGB -> BGR step to yield correct colours.
        """
        if img.mode == "1":
            img = img.convert("L")
        elif img.mode in ("P", "PA"):
            has_alpha = img.mode == "PA" or "transparency" in img.info
            img = img.convert("RGBA" if has_alpha else "RGB")
        elif img.mode in ("CMYK", "YCbCr"):
            img = img.convert("RGB")
        return np.array(img)

    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
        """Convert a decoded array to 3 channels.

        Args:
            img: Array returned by ``load_img``.
            origin_img_type: Type of the value originally passed to
                ``__call__``. 3-channel data that came through PIL
                (``str``/``Path``/``bytes``/``Image.Image``) is RGB and is
                reordered to BGR; ``np.ndarray`` inputs are assumed to be BGR
                already.

        Raises:
            LoadImageError: If ``img`` is not 2-D or 3-D, or has a channel
                count outside 1-4.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 3:
                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                return img

            if channel == 4:
                # NOTE(review): 4-channel data is always treated as RGBA, also
                # for ndarray inputs, while 3-channel ndarrays are assumed BGR.
                # Confirm callers pass RGBA (not BGRA) arrays.
                return self.cvt_four_to_three(img)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR

        Pixels with zero alpha become white. Every other pixel keeps its gray
        value plus ``255 - alpha``, saturated at 255 by ``cv2.add``, so fully
        opaque pixels are unchanged. Presumably expects uint8 data (OpenCV
        masks must be 8-bit) — confirm for other dtypes.
        """
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        # Zero out pixels whose alpha is 0, then lift them to white.
        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR

        Reorders RGB to BGR and zeroes pixels whose alpha is 0. If the result
        is entirely black, transparent areas are filled with white (adding
        ``255 - alpha``, saturated). Otherwise the BGR values are inverted
        with ``bitwise_not``.
        """
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        # Keep colour only where alpha is non-zero.
        new_img = cv2.bitwise_and(new_img, new_img, mask=a)

        # NOTE(review): heuristic. An all-black result (e.g. dark content on a
        # transparent background) gets a white background; any other image is
        # colour-inverted. Confirm the inversion is intended for opaque RGBA.
        mean_color = np.mean(new_img)
        if mean_color <= 0.0:
            new_img = cv2.add(new_img, not_a)
        else:
            new_img = cv2.bitwise_not(new_img)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise LoadImageError if ``file_path`` does not exist on disk."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
|
|
|
|
|
|
class LoadImageError(Exception):
    """Raised when an input cannot be loaded or converted into an image array."""
|