import torch
from transformers import AutoModelForObjectDetection
from surya.settings import settings
import numpy as np


class MaxResize:
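    """Resize a PIL image so its longer side equals ``max_size``, preserving aspect ratio."""
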
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))

        return resized_image


def to_tensor(image):
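    """Convert a PIL image to a float32 CHW tensor with values scaled to [0.0, 1.0]."""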
    # Convert PIL Image to NumPy array
    np_image = np.array(image).astype(np.float32)

    # Rearrange dimensions to [C, H, W] format
    np_image = np_image.transpose((2, 0, 1))

    # Normalize to [0.0, 1.0]
    np_image /= 255.0

    return torch.from_numpy(np_image)


def normalize(tensor, mean, std):
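    """Normalize a CHW tensor channel-wise, in place: (t - mean) / std per channel."""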
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor


def structure_transform(image):
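    """Preprocess an image for the TATR structure model: resize to a max side of
    1000 px, convert to a tensor, and apply ImageNet mean/std normalization."""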
    image = image.convert("RGB")  # to_tensor assumes a 3-channel image
    image = MaxResize(1000)(image)
    tensor = to_tensor(image)
    normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return normalized_tensor


def box_cxcywh_to_xyxy(x):
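    """Convert boxes from (center_x, center_y, width, height) to (x1, y1, x2, y2)."""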
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
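    """Convert normalized cxcywh boxes to xyxy pixel coordinates for an image of
    the given (width, height)."""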
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
    return boxes


def outputs_to_objects(outputs, img_sizes, id2label):
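    """Decode DETR-style detection outputs into per-image row/column lists.

    For each image, keeps the highest-scoring class per query, drops the
    "no object" background class, and splits the remaining detections into
    "table row" and "table column" entries.
    """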
    m = outputs.logits.softmax(-1).max(-1)
    batch_labels = list(m.indices.detach().cpu().numpy())
    batch_scores = list(m.values.detach().cpu().numpy())
    batch_bboxes = outputs.pred_boxes.detach().cpu()

    batch_objects = []
    for i in range(len(img_sizes)):
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])]
        pred_scores = batch_scores[i]
        pred_labels = batch_labels[i]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({
                    'label': class_label,
                    'score': float(score),
                    'bbox': [float(elem) for elem in bbox],
                })

        rows = []
        cols = []
        for cell in objects:
            if cell["label"] == "table column":
                cols.append(cell)
            elif cell["label"] == "table row":
                rows.append(cell)
        batch_objects.append({
            "rows": rows,
            "cols": cols
        })

    return batch_objects


def load_tatr():
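    """Load the Table Transformer (TATR) structure recognition model onto the
    configured device."""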
    return AutoModelForObjectDetection.from_pretrained(
        "microsoft/table-transformer-structure-recognition-v1.1-all"
    ).to(settings.TORCH_DEVICE_MODEL)


def batch_inference_tatr(model, images, batch_size):
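    """Run TATR structure recognition over ``images`` in chunks of ``batch_size``,
    returning one {"rows": [...], "cols": [...]} dict per input image."""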
    device = model.device
    # Copy id2label once and append the DETR background class, rather than
    # mutating the model config (and re-appending "no object") on every batch.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"
    rows_cols = []
    for i in range(0, len(images), batch_size):
        batch_images = images[i:i + batch_size]
        pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
    return rows_cols
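

if __name__ == "__main__":
    # Minimal usage sketch. "table.png" is an illustrative path, not part of
    # the original module; it should be a cropped image of a single table.
    from PIL import Image

    model = load_tatr()
    image = Image.open("table.png").convert("RGB")
    results = batch_inference_tatr(model, [image], batch_size=1)
    print(f"rows: {len(results[0]['rows'])}, cols: {len(results[0]['cols'])}")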