import io
from pathlib import Path
from typing import Optional, Union

import requests
from tqdm import tqdm

from .logger import Logger

# Directory holding this module, with symlinks resolved.
_PACKAGE_DIR = Path(__file__).resolve().parent
# Project root: one level above the module's directory.
PROJECT_DIR = _PACKAGE_DIR.parent
# Where models are stored when the caller does not pass a directory.
DEFAULT_MODEL_DIR = PROJECT_DIR.joinpath("models")


class DownloadModel:
    """Download model files over HTTP and cache them in a local directory."""

    logger = Logger(logger_name=__name__).get_log()

    @classmethod
    def download(
        cls,
        model_full_url: Union[str, Path],
        save_dir: Union[str, Path, None] = None,
        save_model_name: Optional[str] = None,
    ) -> str:
        """Fetch ``model_full_url`` into ``save_dir`` unless it is already there.

        Args:
            model_full_url: URL of the model file.
            save_dir: Target directory, created if missing. Defaults to
                ``DEFAULT_MODEL_DIR``.
            save_model_name: File name to save under. Defaults to the last
                path component of ``model_full_url``.

        Returns:
            The path of the local model file, as a string.

        Raises:
            DownloadModelError: If downloading or saving fails. The
                underlying exception is chained as ``__cause__``.
        """
        # Normalise to ``Path``: the signature accepts ``str``, which has
        # neither ``mkdir`` nor the ``/`` operator.
        save_dir = Path(DEFAULT_MODEL_DIR if save_dir is None else save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        if save_model_name is None:
            save_model_name = Path(model_full_url).name

        save_file_path = save_dir / save_model_name
        if save_file_path.exists():
            cls.logger.info("%s already exists", save_file_path)
            return str(save_file_path)

        try:
            cls.logger.info("Download %s to %s", model_full_url, save_dir)
            file = cls.download_as_bytes_with_progress(model_full_url, save_model_name)
            cls.save_file(save_file_path, file)
        except Exception as exc:
            raise DownloadModelError(
                f"Failed to download {model_full_url} to {save_file_path}"
            ) from exc
        return str(save_file_path)

    @staticmethod
    def download_as_bytes_with_progress(
        url: Union[str, Path], name: Optional[str] = None
    ) -> bytes:
        """Download ``url`` into memory while showing a tqdm progress bar.

        Args:
            url: Address to fetch; converted with ``str()`` before the request.
            name: Label shown next to the progress bar.

        Returns:
            The complete response body.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
        """
        # The context manager releases the streamed connection even when the
        # transfer fails part-way through.
        with requests.get(
            str(url), stream=True, allow_redirects=True, timeout=180
        ) as resp:
            # Without this check an HTTP error page would be returned (and
            # later saved) as if it were the model file.
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", 0))
            bio = io.BytesIO()
            # unit is "B": the counter is advanced by byte counts, not bits.
            with tqdm(
                desc=name, total=total, unit="B", unit_scale=True, unit_divisor=1024
            ) as pbar:
                for chunk in resp.iter_content(chunk_size=65536):
                    pbar.update(len(chunk))
                    bio.write(chunk)
        return bio.getvalue()

    @staticmethod
    def save_file(save_path: Union[str, Path], file: bytes) -> None:
        """Write ``file`` to ``save_path`` without exposing a partial file.

        The data is written to a sibling ``<name>.part`` file and then moved
        over ``save_path``. ``download`` treats any existing file as a valid
        cached model, so a truncated file left by an interrupted write must
        never appear under the final name.

        Args:
            save_path: Destination file path.
            file: Raw bytes to write.
        """
        target = Path(save_path)
        partial = target.with_name(target.name + ".part")
        try:
            with open(partial, "wb") as f:
                f.write(file)
            partial.replace(target)
        finally:
            # Only present here if the write or the move failed.
            if partial.exists():
                partial.unlink()


class DownloadModelError(Exception):
    """Raised by ``DownloadModel.download`` when fetching or saving fails."""