You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
146 lines
4.6 KiB
Python
146 lines
4.6 KiB
Python
# -*- encoding: utf-8 -*-
|
|
# @Author: SWHL
|
|
# @Contact: liekkaskono@163.com
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from .load_image import LoadImage
|
|
|
|
|
|
class VisTable:
    """Visualize table-recognition results: HTML, cell boxes and row/col spans."""

    def __init__(self):
        # Loader from the sibling ``load_image`` module; used in ``__call__``
        # to read ``img_path`` into an image array.
        self.load_img = LoadImage()

    def __call__(
        self,
        img_path: Union[str, Path],
        table_results,
        save_html_path: Optional[str] = None,
        save_drawed_path: Optional[str] = None,
        save_logic_path: Optional[str] = None,
    ):
        """Render (and optionally save) the visualizations for one table.

        :param img_path: source image, passed to ``LoadImage``.
        :param table_results: recognition output exposing ``pred_html``,
            ``cell_bboxes`` and ``logic_points`` attributes.
        :param save_html_path: if given, write ``pred_html`` with an injected
            border stylesheet to this path.
        :param save_drawed_path: if given, write the image with cell boxes drawn.
        :param save_logic_path: if given and logic points are present, write an
            image annotated with each cell's row/col span.
        :return: the image with cell boxes drawn, or ``None`` when
            ``table_results.cell_bboxes`` is ``None``.
        :raises ValueError: if cell boxes have neither 4 nor 8 coordinates.
        """
        if save_html_path:
            html_with_border = self.insert_border_style(table_results.pred_html)
            self.save_html(save_html_path, html_with_border)

        table_cell_bboxes = table_results.cell_bboxes
        if table_cell_bboxes is None:
            return None

        img = self.load_img(img_path)

        dims_bboxes = table_cell_bboxes.shape[1]
        if dims_bboxes == 4:
            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
        elif dims_bboxes == 8:
            drawed_img = self.draw_polylines(img, table_cell_bboxes)
        else:
            raise ValueError(
                f"Shape of table bounding boxes must be 4 or 8, got {dims_bboxes}."
            )

        if save_drawed_path:
            self.save_img(save_drawed_path, drawed_img)

        # ``logic_points`` may be a NumPy array, whose truth value is ambiguous,
        # so presence is tested explicitly rather than via truthiness.
        if save_logic_path and self._has_logic_points(table_results.logic_points):
            # Works for both 4- and 8-coordinate boxes.
            polygons = self._bboxes_to_rects(table_cell_bboxes)
            self.plot_rec_box_with_logic_info(
                img, save_logic_path, table_results.logic_points, polygons
            )
        return drawed_img

    def insert_border_style(self, table_html_str: str) -> str:
        """Inject a CSS block that gives every table cell a visible border.

        The stylesheet is inserted immediately before the first ``<body>`` tag.
        If there is no ``<body>`` tag, the stylesheet is prepended instead of
        raising.
        """
        style_res = """<meta charset="UTF-8"><style>
        table {
            border-collapse: collapse;
            width: 100%;
        }
        th, td {
            border: 1px solid black;
            padding: 8px;
            text-align: center;
        }
        th {
            background-color: #f2f2f2;
        }
        </style>"""

        # partition() splits on the first occurrence only, so multiple or
        # missing <body> tags no longer break tuple unpacking.
        prefix_table, sep, suffix_table = table_html_str.partition("<body>")
        if not sep:
            return f"{style_res}{table_html_str}"
        return f"{prefix_table}{style_res}<body>{suffix_table}"

    def plot_rec_box_with_logic_info(
        self, img: np.ndarray, output_path, logic_points, sorted_polygons
    ):
        """Draw each cell's rectangle and row/col span label, then save it.

        :param img: image to annotate. It is not modified, because
            ``copyMakeBorder`` returns a new, padded image that is drawn on.
        :param output_path: destination file. Parent directories are created.
        :param logic_points: per-cell ``[row_start, row_end, col_start, col_end]``,
            indexed in the same order as ``sorted_polygons``.
        :param sorted_polygons: per-cell ``[xmin, ymin, xmax, ymax]``.
        """
        # Pad 100 px of white on the right-hand side.
        img = cv2.copyMakeBorder(
            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
        )
        font_scale = 0.9
        thickness = 1
        color = (0, 0, 255)

        for idx, polygon in enumerate(sorted_polygons):
            # round(float) returns a plain Python int, which cv2 point
            # arguments accept across versions (NumPy scalars may not be).
            x0, y0, x1, y1 = (round(float(v)) for v in polygon[:4])
            cv2.rectangle(img, (x0, y0), (x1, y1), color, 1)

            logic_point = logic_points[idx]
            cv2.putText(
                img,
                f"row: {logic_point[0]}-{logic_point[1]}",
                (x0 + 3, y0 + 8),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                color,
                thickness,
            )
            cv2.putText(
                img,
                f"col: {logic_point[2]}-{logic_point[3]}",
                (x0 + 3, y0 + 18),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                color,
                thickness,
            )

        # save_img creates the parent directory. This also works for bare
        # filenames, where os.path.dirname() is "" and os.makedirs("") raised.
        self.save_img(output_path, img)

    @staticmethod
    def _has_logic_points(logic_points) -> bool:
        """Return True when ``logic_points`` is not None and not empty."""
        return logic_points is not None and len(logic_points) > 0

    @staticmethod
    def _bboxes_to_rects(bboxes: np.ndarray) -> np.ndarray:
        """Reduce cell boxes to ``[xmin, ymin, xmax, ymax]`` rows.

        4-coordinate boxes are returned unchanged. For 8-coordinate boxes,
        columns 0, 1, 4 and 5 (the 1st and 3rd corners) are used. This assumes
        those are the top-left and bottom-right corners — TODO confirm against
        the box producer.
        """
        bboxes = np.asarray(bboxes)
        if bboxes.shape[1] == 8:
            return bboxes[:, [0, 1, 4, 5]]
        return bboxes[:, :4]

    @staticmethod
    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
        """Return a copy of ``img`` with each ``[x1, y1, x2, y2]`` box drawn."""
        img_copy = img.copy()
        # .tolist() yields plain Python ints for the cv2 point tuples.
        for x1, y1, x2, y2 in boxes.astype(int).tolist():
            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
        """Return a copy of ``img`` with each 8-value box drawn as a closed quad."""
        img_copy = img.copy()
        # cv2.polylines expects int32 point arrays. The platform int (often
        # int64) can fail its point-vector check.
        quads = np.asarray(points).astype(np.int32).reshape(-1, 4, 2)
        for quad in quads:
            cv2.polylines(img_copy, [quad], True, (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def save_img(save_path: Union[str, Path], img: np.ndarray):
        """Write ``img`` with ``cv2.imwrite``, creating parent directories first."""
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        cv2.imwrite(str(save_path), img)

    @staticmethod
    def save_html(save_path: Union[str, Path], html: str):
        """Write ``html`` as UTF-8 text, creating parent directories first."""
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(html)
|