# -*- coding: utf-8 -*-
import logging
import os

from sqlalchemy import text

from website import consts
from website import db_mysql
from website import errors
from website import settings
from website.handler import APIHandler, authenticated
from website.util import md5, shortuuid


class ClassificationAddHandler(APIHandler):
    """
    Add a model classification.

    Request arguments:
        name: classification name (required, at most 128 characters after escaping).

    Responds with an empty body on success. Raises ``HTTPAPIError`` with
    ``ERROR_BAD_REQUEST`` when the name is missing, too long, or already used.
    """

    @authenticated
    def post(self):
        classification_name = self.get_escaped_argument("name", "")
        if not classification_name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if len(classification_name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        name_param = {"name": classification_name}
        with self.app_mysql.connect() as conn:
            # Reject names that are already taken by another classification
            existing = db_mysql.to_json(
                conn.execute(
                    text("select id from model_classification where name=:name"),
                    name_param,
                )
            )
            if existing:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")

            insert_stmt = text(
                """insert into model_classification (name, create_time) values (:name, NOW())"""
            )
            conn.execute(insert_stmt, name_param)
            conn.commit()

        self.finish()


class ClassificationEditHandler(APIHandler):
    """
    Rename a model classification.

    Request arguments:
        id: classification id (required).
        name: new classification name (required, at most 128 characters).

    Applies the same validation as ``ClassificationAddHandler`` so that an
    edit cannot produce a name an add would reject: a name longer than 128
    characters, or one already used by a *different* classification.
    Raises ``HTTPAPIError`` with ``ERROR_BAD_REQUEST`` on invalid input.
    """

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not classification_id or not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        with self.app_mysql.connect() as conn:
            # Exclude the row being edited so re-saving the same name is allowed
            cur = conn.execute(
                text(
                    "select id from model_classification where name=:name and id!=:id"
                ),
                {"name": name, "id": classification_id},
            )
            if db_mysql.to_json(cur):
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")

            conn.execute(
                text("""update model_classification set name=:name where id=:id"""),
                {"name": name, "id": classification_id},
            )
            conn.commit()

        self.finish()


class ClassificationListHandler(APIHandler):
    """
    List every model classification.

    Responds with ``{"data": [{"id": ..., "name": ...}, ...]}``.
    """

    @authenticated
    def post(self):
        query = text("""select id, name from model_classification""")
        with self.app_mysql.connect() as conn:
            classifications = db_mysql.to_json_list(conn.execute(query))

        self.finish({"data": classifications})


class ClassificationDeleteHandler(APIHandler):
    """
    Delete a model classification by id.

    Request arguments:
        id: classification id (required).

    Only the ``model_classification`` row is removed; ``model`` rows that
    still reference it are not modified here.
    """

    @authenticated
    def post(self):
        target_id = self.get_int_argument("id")
        if not target_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        delete_stmt = text("""DELETE FROM model_classification WHERE id=:id""")
        with self.app_mysql.connect() as conn:
            conn.execute(delete_stmt, {"id": target_id})
            conn.commit()

        self.finish()


class ListHandler(APIHandler):
    """
    Paginated list of non-deleted models.

    Request arguments:
        pageNo: 1-based page number (default 1; values < 1 are treated as 1).
        pageSize: rows per page (default ``consts.PAGE_SIZE``; values < 1 fall
            back to that default).
        name: optional substring filter on the model name.

    Responds with ``{"data": [...], "count": total}``, where ``count`` is the
    number of rows matching the same filter, ignoring pagination.
    """

    @authenticated
    def post(self):
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", consts.PAGE_SIZE)
        name = self.get_escaped_argument("name", "")

        # A negative LIMIT/OFFSET is a MySQL syntax error, so normalize paging input.
        if not pageNo or pageNo < 1:
            pageNo = 1
        if not pageSize or pageSize < 1:
            pageSize = consts.PAGE_SIZE

        sql = (
            "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time "
            "from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
        )
        # The count query aliases ``model`` as ``m`` so the shared name filter
        # (``m.name like :name``) is valid there too.
        sql_count = "select count(m.id) from model m where m.del=0 "
        param = {}
        param_count = {}

        if name:
            pattern = "%{}%".format(name)
            sql += "and m.name like :name"
            sql_count += "and m.name like :name"
            param["name"] = pattern
            param_count["name"] = pattern

        sql += " order by m.id desc limit :pageSize offset :offset"
        param["pageSize"] = pageSize
        param["offset"] = (pageNo - 1) * pageSize

        with self.app_mysql.connect() as conn:
            result = db_mysql.to_json_list(conn.execute(text(sql), param))
            count = conn.execute(text(sql_count), param_count).fetchone()[0]

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                "model_type": consts.model_type_map[item["model_type"]],
                "classification_name": item["classification_name"],
                "default_version": item["default_version"],
                "update_time": str(item["update_time"]),
            }
            for item in result
        ]

        self.finish({"data": data, "count": count})


class ListSimpleHandler(APIHandler):
    """
    Lightweight list of all non-deleted models (id and name only).

    Responds with ``{"result": [{"id": ..., "name": ...}, ...]}``.
    """

    @authenticated
    def post(self):
        with self.app_mysql.connect() as conn:
            models = db_mysql.to_json_list(
                conn.execute(text("select id, name from model where del=0"))
            )

        self.finish({"result": models})


class AddHandler(APIHandler):
    """
    Create a model.

    Request arguments:
        name: model name (required).
        model_type: model type code; must be a key of ``consts.model_type_map``
            (defaults to ``consts.model_type_machine``; 1001 = classic
            algorithm, 1002 = deep learning).
        classification: id of an existing model classification (required).
        comment: optional free-text comment.

    A random 10-character ``shortuuid`` is stored as the new model's ``suid``.
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        # 1001 = classic algorithm, 1002 = deep learning
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")

        if not (name and classification and model_type in consts.model_type_map):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            classification_row = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            ).fetchone()
            if not classification_row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")

            new_model = {
                "suid": shortuuid.ShortUUID().random(10),
                "name": name,
                "model_type": model_type,
                "classification": classification,
                "comment": comment,
            }
            conn.execute(
                text(
                    """insert into model (suid, name, model_type, classification, comment, create_time, update_time) 
                    values (:suid, :name, :model_type, :classification, :comment, NOW(), NOW())"""
                ),
                new_model,
            )
            conn.commit()

        self.finish()


class EditHandler(APIHandler):
    """
    Edit a model.

    Request arguments:
        id: model id (required).
        name: model name (required).
        model_type: model type code; must be a key of ``consts.model_type_map``.
            If omitted it defaults to ``consts.model_type_machine`` and that
            default is written to the row (1001 = classic algorithm,
            1002 = deep learning).
        classification: id of an existing model classification (required).
        comment: optional free-text comment.

    Raises ``HTTPAPIError`` with ``ERROR_BAD_REQUEST`` when a required
    argument is missing or the classification does not exist.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        # 1001 = classic algorithm, 1002 = deep learning
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")

        # Without ``mid`` the update would match no row yet still report success.
        if (
            not mid
            or not name
            or not classification
            or model_type not in consts.model_type_map
        ):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Same guard as AddHandler: never point a model at a missing classification
            cur = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")

            conn.execute(
                text(
                    "update model "
                    "set name=:name, model_type=:model_type, classification=:classification, "
                    "comment=:comment, update_time=NOW() "
                    "where id=:id"
                ),
                {
                    "name": name,
                    "model_type": model_type,
                    "classification": classification,
                    "comment": comment,
                    "id": mid,
                },
            )
            conn.commit()

        self.finish()


class InfoHandler(APIHandler):
    """
    Return the details of one non-deleted model.

    Request arguments:
        id: model id (required).

    Responds with ``name``, ``model_type`` (raw code), ``classification_id``,
    ``classification_name``, ``comment`` and ``update_time``. The two
    classification fields are null when the model's classification row no
    longer exists.

    Raises ``HTTPAPIError`` with ``ERROR_BAD_REQUEST`` if ``id`` is missing or
    the model does not exist / was soft-deleted.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Left join (as ListHandler does) so a deleted classification does not
            # hide the model; ``del=0`` so soft-deleted models are not returned.
            cur = conn.execute(
                text(
                    """
                    select
                        m.name, m.model_type, m.comment, m.update_time,
                        mc.id as classification_id, mc.name as classification_name
                    from model m
                    left join model_classification mc on m.classification=mc.id
                    where m.id=:id and m.del=0
                    """
                ),
                {"id": mid},
            )
            result = db_mysql.to_json(cur)

        if not result:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")

        data = {
            "name": result["name"],
            "model_type": result["model_type"],
            "classification_id": result["classification_id"],
            "classification_name": result["classification_name"],
            "comment": result["comment"],
            "update_time": str(result["update_time"]),
        }

        self.finish(data)


class DeleteHandler(APIHandler):
    """
    Soft-delete a model.

    Request arguments:
        id: model id (required).

    The row is kept and its ``del`` column is set to 1.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        soft_delete = text("update model set del=1 where id=:id")
        with self.app_mysql.connect() as conn:
            conn.execute(soft_delete, {"id": mid})
            conn.commit()

        self.finish()


class VersionAddHandler(APIHandler):
    """
    Add a model version.

    - Method: POST
    - Arguments:
    > - model_id, int, model id (required; must be an existing, non-deleted model)
    > - version, string, version label (required)
    > - comment, string, comment
    > - model_file, string, md5 of the model file (required; a single file only)
    > - config_file, string, md5 of the model configuration file (optional)
    > - config_str, string, model configuration as a JSON string, e.g. '{"a":"xx","b":"xx"}'
    - Response: none

    Raises ``HTTPAPIError`` with ``ERROR_BAD_REQUEST`` when a required argument
    is missing, an md5 value is malformed, or the model does not exist.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("model_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not mid or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            # Do not create orphan versions for unknown or soft-deleted models
            cur = conn.execute(
                text("select id from model where id=:id and del=0"),
                {"id": mid},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")

            conn.execute(
                text(
                    "insert into model_version "
                    "(model_id, version, comment, model_file, config_file, config_str, create_time, update_time) "
                    "values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"
                ),
                {
                    "model_id": mid,
                    "version": version,
                    "comment": comment,
                    "model_file": model_file,
                    "config_file": config_file,
                    "config_str": config_str,
                },
            )
            conn.commit()

        self.finish()


class VersionEditHandler(APIHandler):
    """
    Edit a model version.

    - Method: POST
    - Arguments:
    > - version_id, int, model version id (required)
    > - version, string, version label (required)
    > - comment, string, comment
    > - model_file, string, md5 of the model file (required; a single file only)
    > - config_file, string, md5 of the model configuration file (optional)
    > - config_str, string, model configuration as a JSON string, e.g. '{"a":"xx","b":"xx"}'
    - Response: none

    Raises ``HTTPAPIError`` with ``ERROR_BAD_REQUEST`` when a required argument
    is missing or an md5 value is malformed.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not version_id or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update model_version "
                    "set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
                    "config_str=:config_str, update_time=NOW() where id=:id"
                ),
                {
                    "version": version,
                    "comment": comment,
                    "model_file": model_file,
                    "config_file": config_file,
                    "config_str": config_str,
                    "id": version_id,
                },
            )
            # Every other write handler commits; without this the update is
            # rolled back when the connection is closed.
            conn.commit()

        self.finish()


class VersionListHandler(APIHandler):
    """
    Paginated list of the versions of one model.

    - Method: POST
    - Arguments:
    > - model_id, int, model id (required)
    > - pageNo, int, 1-based page number (default 1; values < 1 are treated as 1)
    > - pageSize, int, rows per page (default 20; values < 1 fall back to 20)
    - Response:
    ```
        {
            "count": 123,
            "data": [
                {
                    "version_id": 213,    # version id
                    "version": "xx",      # version label
                    "path": "xxx",        # stored file path (files.filepath)
                    "size": 123,          # stored file size (files.filesize)
                    "update_time": "xxx",
                    "is_default": 1       # 1 = default version, 0 = not default
                },
                ...
            ]
        }
    ```

    Only versions with ``del=0`` are listed and counted. ``path``/``size`` are
    null when the version's ``model_file`` has no matching ``files`` row.
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("model_id")
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", 20)
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        # A negative LIMIT offset/row count is a MySQL syntax error.
        if not pageNo or pageNo < 1:
            pageNo = 1
        if not pageSize or pageSize < 1:
            pageSize = 20

        with self.app_mysql.connect() as conn:
            # Version rows joined with the stored model file's path and size
            cur = conn.execute(
                text(
                    """
                    select mv.id as version_id, mv.version, mv.model_file, mv.update_time, mv.is_default, f.filepath, f.filesize
                    from model_version mv
                    left join files f on mv.model_file=f.md5_str
                    where mv.model_id=:mid and mv.del=0
                    order by mv.id desc limit :offset, :limit
                    """
                ),
                {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize},
            )
            result = db_mysql.to_json_list(cur)

            # Total number of versions, ignoring pagination
            count = conn.execute(
                text(
                    """
                    select count(*)
                    from model_version mv
                    where mv.del=0 and model_id=:mid
                    """
                ),
                {"mid": model_id},
            ).scalar()

        data = [
            {
                "version_id": item["version_id"],
                "version": item["version"],
                "path": item["filepath"],
                "size": item["filesize"],
                "update_time": str(item["update_time"]),
                "is_default": item["is_default"],
            }
            for item in result
        ]

        self.finish({"count": count, "data": data})


class VersionInfoHandler(APIHandler):
    """
    Details of one model version.

    - Method: POST
    - Arguments:
    > - version_id, int, model version id (required)
    - Response:
    ```
    {
        "model_name": "xxx",
        "version": "xxx",         # version label
        "comment": "xxx",         # comment
        "model_file_name": "xxx", # model file name
        "model_file_size": 123,   # model file size
        "model_file_md5": "xx",   # model file md5
        "config_file_name": "xx", # config file name
        "config_file_size": "xx", # config file size
        "config_file_md5": "xxx", # config file md5
        "config_str": "xxx"       # configuration parameters
    }
    ```

    File name/size fall back to ``""``/``0`` when the md5 is empty or has no
    ``files`` row. Raises ``HTTPAPIError`` with ``ERROR_BAD_REQUEST`` when
    ``version_id`` is missing or the version does not exist.
    """

    @staticmethod
    def _file_info(conn, md5_str):
        """Return ``(filename, filesize)`` for ``md5_str`` from ``files``, or ``("", 0)``."""
        if not md5_str:
            return "", 0
        cur = conn.execute(
            text("select filename, filesize from files where md5_str=:md5_str"),
            {"md5_str": md5_str},
        )
        info = db_mysql.to_json(cur)
        if not info:
            return "", 0
        return info["filename"], info["filesize"]

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    "select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str "
                    "from model_version mv, model m where mv.id=:id and mv.model_id=m.id"
                ),
                {"id": version_id},
            )
            result = db_mysql.to_json(cur)
            if not result:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            model_file = result["model_file"]
            config_file = result["config_file"]
            model_file_name, model_file_size = self._file_info(conn, model_file)
            config_file_name, config_file_size = self._file_info(conn, config_file)

        response = {
            "model_name": result["model_name"],
            "version": result["version"],
            "comment": result["comment"],
            "model_file_name": model_file_name,
            "model_file_size": model_file_size,
            "model_file_md5": model_file,
            "config_file_name": config_file_name,
            "config_file_size": config_file_size,
            "config_file_md5": config_file,
            "config_str": self.unescape_string(result["config_str"]),
        }

        self.finish(response)


class VersionSetDefaultHandler(APIHandler):
    """
    Mark one model version as the default version of its model.

    - Method: POST
    - Arguments:
    > - version_id, int, model version id (required)
    > - model_id, int, model id (required)
    - Response: none

    The version must exist (``del=0``) and belong to ``model_id``. Otherwise
    the request is rejected before anything is modified, so a mismatched pair
    of ids cannot clear this model's default while flagging another model's
    version.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        model_id = self.get_int_argument("model_id")
        if not version_id or not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    "select id from model_version "
                    "where id=:id and model_id=:model_id and del=0"
                ),
                {"id": version_id, "model_id": model_id},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            # Clear the flag on every version of this model, then set it on the chosen one
            conn.execute(
                text("update model_version set is_default=0 where model_id=:model_id"),
                {"model_id": model_id},
            )
            conn.execute(
                text(
                    "update model_version set is_default=1 "
                    "where id=:id and model_id=:model_id"
                ),
                {"id": version_id, "model_id": model_id},
            )
            # NOTE(review): model.default_version (read by ListHandler) is not
            # updated here — confirm whether it should mirror is_default.
            conn.commit()

        self.finish()


class VersionDeleteHandler(APIHandler):
    """
    Delete a model version together with its stored files.

    - Method: POST
    - Arguments:
    > - version_id, int, model version id (required)
    - Response: none

    Steps:
    1. Look up the version's ``model_file`` / ``config_file`` md5 values.
    2. Clear the owning model's ``default_version``.
    3. For each non-empty md5 not referenced by any other version, delete its
       ``files`` row and the file under ``settings.file_upload_dir + "model/"``.
       Removal on disk is best-effort and only logged on failure.
    4. Delete the ``model_version`` row.

    Raises ``HTTPAPIError`` with ``ERROR_BAD_REQUEST`` when ``version_id`` is
    missing or the version does not exist.
    """

    @staticmethod
    def _release_file(conn, md5_str, version_id):
        """Drop the ``files`` row and stored file for ``md5_str`` unless another version still uses it."""
        if not md5_str:
            # config_file is optional; an empty name would make os.remove target the directory
            return

        still_used = conn.execute(
            text(
                "select count(*) from model_version "
                "where id!=:id and (model_file=:md5_str or config_file=:md5_str)"
            ),
            {"id": version_id, "md5_str": md5_str},
        ).scalar()
        if still_used:
            return

        conn.execute(
            text("delete from files where md5_str=:md5_str"),
            {"md5_str": md5_str},
        )
        try:
            os.remove(settings.file_upload_dir + "model/" + md5_str)
        except OSError as e:
            # Best effort: a missing file on disk must not block removing the DB rows
            logging.warning("failed to remove model file %s: %s", md5_str, e)

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    "select model_id, model_file, config_file from model_version where id=:id"
                ),
                {"id": version_id},
            )
            row = db_mysql.to_json(cur)
            if not row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            # Clear the model's default version
            # NOTE(review): this runs even when the deleted version was not the
            # default one — confirm that is intended.
            conn.execute(
                text("update model set default_version='' where id=:id"),
                {"id": row["model_id"]},
            )

            # Each file is handled independently so one failure does not skip the other
            self._release_file(conn, row["model_file"], version_id)
            self._release_file(conn, row["config_file"], version_id)

            conn.execute(
                text("delete from model_version where id=:id"), {"id": version_id}
            )
            conn.commit()

        self.finish()