You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

68 lines
2.0 KiB
Python

1 month ago
import io
from pathlib import Path
from typing import Optional, Union
import requests
from tqdm import tqdm
from .logger import Logger
# Directory two levels above this module's resolved (absolute) location.
PROJECT_DIR = Path(__file__).resolve().parent.parent
# Folder used for downloaded model files when the caller supplies no save_dir.
DEFAULT_MODEL_DIR = PROJECT_DIR.joinpath("models")
class DownloadModel:
    """Download a model file over HTTP and store it on disk.

    All methods are class/static methods; the class is used as a namespace
    and is not meant to be instantiated.
    """

    # Shared logger for every download performed through this class.
    logger = Logger(logger_name=__name__).get_log()

    @classmethod
    def download(
        cls,
        model_full_url: Union[str, Path],
        save_dir: Union[str, Path, None] = None,
        save_model_name: Optional[str] = None,
    ) -> str:
        """Fetch ``model_full_url`` into ``save_dir`` unless it is already there.

        Args:
            model_full_url: Full URL of the model file.
            save_dir: Target directory (created if missing). Defaults to
                ``DEFAULT_MODEL_DIR`` when ``None``.
            save_model_name: File name to store the model under. Defaults to
                the last path component of ``model_full_url``.

        Returns:
            The path of the stored model file, as a string.

        Raises:
            DownloadModelError: If downloading or writing the file fails; the
                underlying exception is attached as ``__cause__``.
        """
        # Accept plain strings as the annotation promises; ``mkdir`` and the
        # ``/`` operator below require a ``Path``.
        save_dir = DEFAULT_MODEL_DIR if save_dir is None else Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        if save_model_name is None:
            save_model_name = Path(model_full_url).name

        save_file_path = save_dir / save_model_name
        if save_file_path.exists():
            cls.logger.info("%s already exists", save_file_path)
            return str(save_file_path)

        try:
            cls.logger.info("Download %s to %s", model_full_url, save_dir)
            file = cls.download_as_bytes_with_progress(model_full_url, save_model_name)
            cls.save_file(save_file_path, file)
        except Exception as exc:
            raise DownloadModelError(
                f"Failed to download {model_full_url} to {save_file_path}"
            ) from exc
        return str(save_file_path)

    @staticmethod
    def download_as_bytes_with_progress(
        url: Union[str, Path], name: Optional[str] = None
    ) -> bytes:
        """Download ``url`` into memory while displaying a tqdm progress bar.

        Args:
            url: Address to fetch; converted with ``str()`` before the request.
            name: Label shown next to the progress bar.

        Returns:
            The complete response body.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
            requests.RequestException: On connection problems or timeout.
        """
        # The context manager releases the streamed connection even when the
        # transfer is interrupted part-way.
        with requests.get(
            str(url), stream=True, allow_redirects=True, timeout=180
        ) as resp:
            # Without this check an HTML error page would be returned (and
            # subsequently saved) as if it were the model file.
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", 0))
            bio = io.BytesIO()
            with tqdm(
                desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024
            ) as pbar:
                for chunk in resp.iter_content(chunk_size=65536):
                    pbar.update(len(chunk))
                    bio.write(chunk)
        return bio.getvalue()

    @staticmethod
    def save_file(save_path: Union[str, Path], file: bytes):
        """Write ``file`` to ``save_path`` in binary mode, overwriting any existing file."""
        with open(save_path, "wb") as f:
            f.write(file)
class DownloadModelError(Exception):
    """Raised by ``DownloadModel.download`` when fetching or saving a model fails."""