# -*- coding: utf-8 -*-

import logging
from typing import Union

from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy.ext.declarative import declarative_base

from website.db_mysql import get_session

# Declarative base class; `Model` below subclasses it.
# NOTE(review): `sqlalchemy.ext.declarative.declarative_base` is deprecated since
# SQLAlchemy 1.4 in favour of `sqlalchemy.orm.declarative_base` - consider
# switching once the pinned SQLAlchemy version allows it.
Base = declarative_base()

# Reference DDL for the MySQL `model` table that `Model` mirrors. This bare string
# is evaluated and discarded at import time; it only documents the schema.
"""
CREATE TABLE `model` (
  `id` int NOT NULL AUTO_INCREMENT,
  `suid` varchar(10) DEFAULT NULL COMMENT 'short uuid',
  `name` varchar(255) NOT NULL DEFAULT '',
  `model_type` int DEFAULT '1002' COMMENT '模型类型,1001/经典算法,1002/深度学习',
  `classification` int DEFAULT '0' COMMENT '模型分类的id',
  `comment` varchar(255) DEFAULT '' COMMENT '备注',
  `default_version` varchar(100) DEFAULT '',
  `del` tinyint(1) DEFAULT '0' COMMENT '删除状态,1/删除,0/正常',
  `create_time` datetime DEFAULT CURRENT_TIMESTAMP,
  `update_time` datetime DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='模型表';
"""


class Model(Base):
    """ORM mapping for the MySQL `model` table ("model table").

    The `comment=` strings on the columns are SQLAlchemy column comments and are
    kept verbatim (Chinese); their English meaning is given in the `#` comment
    above each column.
    """

    __tablename__ = "model"

    # Auto-increment integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Short uuid; nullable (DDL: varchar(10) DEFAULT NULL).
    suid = Column(String(10), comment="short uuid")
    # Required; empty string when not provided.
    name = Column(String(255), nullable=False, default="")
    # Model type: 1001 = classical algorithm, 1002 = deep learning (the default).
    model_type = Column(Integer, default=1002, comment="模型类型,1001/经典算法,1002/深度学习")
    # Id of the model's classification (category); 0 when not set.
    classification = Column(Integer, default=0, comment="模型分类的id")
    # Free-text remark.
    comment = Column(String(255), default="", comment="备注")
    # Default version, stored as a free-form string of up to 100 chars.
    default_version = Column(String(100), default="")
    # Deletion flag stored in the `del` column (the attribute is named `delete`
    # because `del` is a Python keyword): 1 = deleted, 0 = normal.
    # The DDL declares it tinyint(1); it is mapped as Integer here.
    delete = Column("del", Integer, default=0, comment="删除状态,1/删除,0/正常")
    # SQL-expression default: SQLAlchemy renders now() into INSERTs that omit it.
    create_time = Column(DateTime, default=func.now())
    # Set to now() only on UPDATEs issued through SQLAlchemy; the DDL has no
    # ON UPDATE clause, so updates made outside SQLAlchemy leave it unchanged.
    update_time = Column(DateTime, onupdate=func.now())

    def __repr__(self) -> str:
        """Return a short debug representation with id, name and model_type."""
        return f"Model(id={self.id}, name='{self.name}', model_type={self.model_type})"


class ModelRepositry(object):
    """Query helpers for ``Model`` rows.

    Each lookup runs inside its own ``get_session()`` context. The class name's
    spelling ("Repositry") is kept as-is because callers import it by that name.
    """

    def get_suid(self, model_id: int) -> str:
        """Return the short uuid (``suid``) of the model with id ``model_id``.

        Returns ``""`` when no row has that id or when its ``suid`` is NULL/empty.
        """
        with get_session() as session:
            model = session.query(Model).filter(Model.id == model_id).first()
            # A missing row and a NULL/empty suid both collapse to "".
            return model.suid if model and model.suid else ""

    def get_model_by_id(self, model_id: int) -> Union[Model, None]:
        """Return the ``Model`` whose primary key is ``model_id``, or None if absent.

        Rows flagged as deleted (``delete == 1``) are NOT filtered out. The
        returned instance is detached from the session with its column values
        already loaded, so it can be read after the session block ends.
        """
        with get_session() as session:
            model = session.query(Model).filter(Model.id == model_id).first()
            if model is not None:
                # get_session() is defined elsewhere; if it commits on exit
                # (expire_on_commit=True by default) the instance would be expired
                # and then detached on close, making attribute access raise
                # DetachedInstanceError. Expunging first keeps the loaded values.
                session.expunge(model)
            return model

    def get_model_dict_by_id(self, model_id: int) -> dict:
        """Return the model with id ``model_id`` as a plain dict, or ``{}`` if absent.

        Values are copied while the session is open, so the dict is safe to use
        after the session ends. Keys mirror the ``Model`` attribute names (note
        ``'delete'`` for the ``del`` column). Deleted rows are not filtered out.
        """
        with get_session() as session:
            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logging.info("model id is : %s", model_id)
            model = session.query(Model).filter(Model.id == model_id).first()
            if model is None:
                return {}

            model_dict = {
                'id': model.id,
                'suid': model.suid,
                'name': model.name,
                'model_type': model.model_type,
                'classification': model.classification,
                'comment': model.comment,
                'default_version': model.default_version,
                'delete': model.delete,
                'create_time': model.create_time,
                'update_time': model.update_time
            }

            logging.info("model dict is : %s", model_dict)
            return model_dict

    def get_model_by_ids(self, model_ids: list) -> list:
        """Return the ``Model`` rows whose ids appear in ``model_ids``.

        Returns ``[]`` when ``model_ids`` is empty or nothing matches. Ids with no
        row are simply absent; result order is whatever the database returns (no
        ORDER BY), not the order of ``model_ids``. Deleted rows are included.
        The instances are detached with their column values loaded, so they can
        be read after the session block ends.
        """
        if not model_ids:
            # Nothing can match: skip the DB round trip (and the empty-IN warning
            # emitted by older SQLAlchemy versions).
            return []
        with get_session() as session:
            models = session.query(Model).filter(Model.id.in_(model_ids)).all()
            for model in models:
                # Detach before get_session() exits so a commit-on-exit cannot
                # expire these instances (avoids DetachedInstanceError on access).
                session.expunge(model)
            return models

    def get_model_count(self) -> int:
        """Return the number of models not flagged as deleted (``del == 0``)."""
        with get_session() as session:
            return session.query(Model).filter(Model.delete == 0).count()