You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

184 lines
6.2 KiB
Python

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/dataset/transforms.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import math
import cv2
import numpy as np
import albumentations as A
from PIL import Image
class LatexTrainTransform:
    """Training-time augmentation for LaTeX formula images.

    With probability ``bitmap_prob`` the image is first binarised: every
    pixel value that is not exactly 255 is set to 0. The image is then run
    through an albumentations pipeline:

    - with probability 0.15, small geometric jitter (shift/scale/rotate,
      then grid distortion with p=0.5), with the border filled with white;
    - an RGB shift;
    - Gaussian noise;
    - a brightness/contrast change;
    - image compression;
    - a grayscale conversion that is always applied.

    Args:
        bitmap_prob: Probability of applying the binarisation step.
        **kwargs: Ignored; accepted so the op can be built from a config.
    """

    def __init__(self, bitmap_prob=0.04, **kwargs):
        self.bitmap_prob = bitmap_prob
        self.train_transform = A.Compose(
            [
                A.Compose(
                    [
                        A.ShiftScaleRotate(
                            shift_limit=0,
                            scale_limit=(-0.15, 0),
                            rotate_limit=1,
                            border_mode=0,
                            interpolation=3,
                            value=[255, 255, 255],
                            p=1,
                        ),
                        A.GridDistortion(
                            distort_limit=0.1,
                            border_mode=0,
                            interpolation=3,
                            value=[255, 255, 255],
                            p=0.5,
                        ),
                    ],
                    p=0.15,
                ),
                A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3),
                A.GaussNoise(10, p=0.2),
                A.RandomBrightnessContrast(0.05, (-0.2, 0), True, p=0.2),
                A.ImageCompression(95, p=0.3),
                A.ToGray(always_apply=True),
            ]
        )

    def __call__(self, data):
        """Augment ``data["image"]`` and store the result back into ``data``.

        Args:
            data: Dict holding the image array under the ``"image"`` key.

        Returns:
            The same dict, with ``"image"`` replaced by the augmented image.
        """
        img = data["image"]
        if np.random.random() < self.bitmap_prob:
            # Binarise a copy so the caller's array is never modified in
            # place. This also keeps read-only input arrays working.
            img = img.copy()
            img[img != 255] = 0
        data["image"] = self.train_transform(image=img)["image"]
        return data
class LatexTestTransform:
    """Evaluation-time transform for LaTeX formula images.

    Applies a single, always-on grayscale conversion (albumentations
    ``ToGray``) and stores the result back into the sample dict.

    Args:
        **kwargs: Ignored; accepted so the op can be built from a config.
    """

    def __init__(self, **kwargs):
        steps = [A.ToGray(always_apply=True)]
        self.test_transform = A.Compose(steps)

    def __call__(self, data):
        """Replace ``data["image"]`` with its grayscale version; return ``data``."""
        result = self.test_transform(image=data["image"])
        data["image"] = result["image"]
        return data
class MinMaxResize:
    """Crop a formula image to its content and fit it into a size window.

    Images whose size already lies within ``[min_dimensions, max_dimensions]``
    are passed through untouched. Both bounds are given as ``(width, height)``.

    Any other image is processed in three steps:

    1. Crop it to the bounding box of its text and pad it to multiples of 32
       (``pad_``).
    2. Downscale it to fit ``max_dimensions``, keeping the aspect ratio.
    3. Pad it with white up to ``min_dimensions``.

    The result is stored as a 3-channel ``uint8`` array.

    Args:
        min_dimensions: Minimum ``(width, height)``.
        max_dimensions: Maximum ``(width, height)``.
        **kwargs: Ignored; accepted so the op can be built from a config.
    """

    def __init__(
        self, min_dimensions=(32, 32), max_dimensions=(672, 192), **kwargs
    ):
        # Immutable defaults are copied into fresh lists. Instances therefore
        # never share, or alias the caller's, mutable dimension lists.
        self.min_dimensions = None if min_dimensions is None else list(min_dimensions)
        self.max_dimensions = None if max_dimensions is None else list(max_dimensions)

    def pad_(self, img, divable=32):
        """Crop ``img`` to its text and pad it to a multiple of ``divable``.

        The content channel is either the luminance (when alpha is constant)
        or the inverted alpha (when alpha varies). That channel is
        contrast-stretched to 0..255 and thresholded at 128 to locate the
        text. The output always has dark text on a white background: images
        whose mean is at or below the threshold are inverted. A blank
        (uniform) image has no text, so it is kept whole and comes out white.

        Args:
            img: PIL image.
            divable: Width and height of the result are rounded up to a
                multiple of this value.

        Returns:
            A mode ``"L"`` PIL image. The cropped content sits at the
            top-left corner, with white (255) padding to the right and bottom.
        """
        threshold = 128
        data = np.array(img.convert("LA"))
        if data[..., -1].var() == 0:
            # Constant alpha: the content lives in the luminance channel.
            data = (data[..., 0]).astype(np.uint8)
        else:
            # Varying alpha: the content is encoded by opacity.
            data = (255 - data[..., -1]).astype(np.uint8)
        lo, hi = data.min(), data.max()
        if hi > lo:
            data = (data - lo) / (hi - lo) * 255
        else:
            # Uniform image: avoid 0/0 (NaN) and treat it as flat background.
            data = np.zeros(data.shape, dtype=np.float64)
        if data.mean() > threshold:
            # Light background: the text is the dark pixels (white in mask).
            gray = 255 * (data < threshold).astype(np.uint8)
        else:
            # Dark background: the text is the light pixels. Invert so the
            # output has dark text on white.
            gray = 255 * (data > threshold).astype(np.uint8)
            data = 255 - data
        coords = cv2.findNonZero(gray)  # All text pixels (None if there are none).
        if coords is None or len(coords) == 0:
            # Nothing to crop to: keep the whole image instead of crashing
            # in boundingRect.
            a, b = 0, 0
            h, w = gray.shape[:2]
        else:
            a, b, w, h = cv2.boundingRect(coords)  # Tight box around the text.
        rect = data[b : b + h, a : a + w]
        im = Image.fromarray(rect).convert("L")
        # Round each side up to the next multiple of ``divable``.
        dims = [(x + divable - 1) // divable * divable for x in (w, h)]
        padded = Image.new("L", dims, 255)
        padded.paste(im, (0, 0, im.size[0], im.size[1]))
        return padded

    def minmax_size_(self, img, max_dimensions, min_dimensions):
        """Downscale ``img`` to fit ``max_dimensions``, then pad up to ``min_dimensions``.

        Args:
            img: Mode ``"L"`` PIL image (the output of ``pad_``).
            max_dimensions: ``(width, height)`` upper bound, or None to skip.
            min_dimensions: ``(width, height)`` lower bound, or None to skip.

        Returns:
            The resized and/or padded image. It is ``img`` itself when no
            change is needed.
        """
        if max_dimensions is not None:
            ratios = [a / b for a, b in zip(img.size, max_dimensions)]
            if any(r > 1 for r in ratios):
                # Keep the aspect ratio: the most oversized side sets the scale.
                size = np.array(img.size) // max(ratios)
                # Never collapse a side to 0 pixels; PIL rejects empty sizes.
                size = np.maximum(size, 1).astype(int)
                img = img.resize(tuple(int(s) for s in size), Image.BILINEAR)
        if min_dimensions is not None:
            padded_size = [
                max(img_dim, min_dim)
                for img_dim, min_dim in zip(img.size, min_dimensions)
            ]
            if padded_size != list(img.size):
                padded_im = Image.new("L", padded_size, 255)
                # Anchor at the top-left corner with a 2-tuple. A 4-tuple box
                # must match img.size exactly, and img.getbbox() shrinks below
                # that whenever an edge row/column is entirely black.
                padded_im.paste(img, (0, 0))
                img = padded_im
        return img

    def __call__(self, data):
        """Fit ``data["image"]`` into the configured size window.

        Args:
            data: Dict holding an ``(H, W[, C])`` array under ``"image"``.

        Returns:
            ``data``. The image is unchanged if already in range; otherwise
            it is replaced by a 3-channel ``uint8`` array.
        """
        img = data["image"]
        h, w = img.shape[:2]
        if (
            self.min_dimensions[0] <= w <= self.max_dimensions[0]
            and self.min_dimensions[1] <= h <= self.max_dimensions[1]
        ):
            return data
        im = Image.fromarray(np.uint8(img))
        im = self.minmax_size_(
            self.pad_(im), self.max_dimensions, self.min_dimensions
        )
        im = np.array(im)
        # Replicate the single gray channel into 3 identical channels.
        data["image"] = np.dstack((im, im, im))
        return data
class LatexImageFormat:
    """Pad a formula image to multiples of 16 and lay it out as ``(1, H, W)``.

    Only the first channel of an ``(H, W, C)`` image is kept. A 2-D
    ``(H, W)`` image is used as-is. Height and width are rounded up to the
    next multiple of 16 by padding at the bottom/right with the constant 1.
    """

    def __init__(self, **kwargs):
        # Stateless; ``kwargs`` is accepted so the op can be built from a config.
        pass

    def __call__(self, data):
        """Pad ``data["image"]`` and add a leading channel axis.

        Args:
            data: Dict holding an ``(H, W)`` or ``(H, W, C)`` array under
                ``"image"``.

        Returns:
            ``data``, with ``"image"`` replaced by an array of shape
            ``(1, ceil(H/16)*16, ceil(W/16)*16)``.
        """
        img = data["image"]
        im_h, im_w = img.shape[:2]
        divide_h = math.ceil(im_h / 16) * 16
        divide_w = math.ceil(im_w / 16) * 16
        if img.ndim == 3:
            # Keep only the first channel.
            img = img[:, :, 0]
        # NOTE(review): the pad value 1 presumably matches the background of
        # the upstream (normalized) image; confirm against the pipeline.
        img = np.pad(
            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
        )
        data["image"] = img[np.newaxis, :, :]
        return data