# -*- coding: utf-8 -*-
import asyncio
import json
import logging

from sqlalchemy import text

from website import consts
from website import db_mysql
from website import errors
from website import settings
from website.db.alg_model.alg_model import ModelRepositry
from website.db.enterprise import enterprise
from website.db.enterprise_device.enterprise_device import EnterpriseDeviceRepository
from website.handler import APIHandler, authenticated
from website.util import shortuuid, aes


# from concurrent.futures import ThreadPoolExecutor
# from functools import partial


class EntityIndexHandler(APIHandler):
    """Home page: paginated list of non-deleted enterprises.

    Request arguments:
        pageNo:   1-based page number (default 1; values below 1 are treated as 1).
        pageSize: rows per page (default 10; negative values are treated as 0).
        name:     optional substring filter applied with ``LIKE``.

    Response:
        ``{"count": <total matching rows>, "data": [...]}``. Each item holds
        the enterprise id, name, industry label (via ``consts.industry_map``),
        model/device counts, logo and creation time.
    """

    @authenticated
    async def post(self):
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", 10)
        name = self.tostr(self.get_escaped_argument("name", ""))

        # A page number below 1 (or a negative page size) would produce a
        # negative OFFSET/LIMIT, which MySQL rejects as a syntax error.
        pageNo = max(pageNo, 1)
        pageSize = max(pageSize, 0)

        # The page query and the count query share one WHERE clause and one
        # set of filter parameters, so the two can never drift apart.
        where_sql = " where del=0"
        where_param = {}
        if name:
            where_sql += " and name like :name"
            where_param["name"] = "%{}%".format(name)

        sql_text = (
            "select id, suid, name, industry, logo, create_time from enterprise"
            + where_sql
            + " order by id desc limit :pageSize offset :offset"
        )
        param = dict(where_param, pageSize=pageSize, offset=(pageNo - 1) * pageSize)
        count_sql_text = "select count(id) c from enterprise" + where_sql

        with self.app_mysql.connect() as conn:
            cur = conn.execute(text(sql_text), param)
            result = db_mysql.to_json_list(cur)

            cur_count = conn.execute(text(count_sql_text), where_param)
            count = db_mysql.to_json(cur_count)
            count = count["c"] if count else 0

        # Fetch the (model count, device count) pair of every enterprise on
        # this page concurrently. asyncio.gather returns results in the order
        # the awaitables were passed, so they line up with ``result``.
        count_results = await asyncio.gather(
            *(
                enterprise.get_enterprise_model_and_device_count(entity_id=item["id"])
                for item in result
            )
        )

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                "industry": consts.industry_map[item["industry"]],
                "modelCount": counts[0],
                "deviceCount": counts[1],
                "logo": item["logo"],
                "createTime": str(item["create_time"]),
            }
            for item, counts in zip(result, count_results)
        ]

        self.finish({"count": count, "data": data})


class EntityIndexBasecountHandler(APIHandler):
    """Home page summary statistics.

    Responds with ``{"entity": ..., "model": ..., "device": ...}``: the
    enterprise count, the model count and the total device count, each
    obtained from the corresponding repository/helper.
    """

    @authenticated
    def post(self):
        # Dict literals evaluate their values left to right, so the three
        # lookups run in the same order as before.
        totals = {
            "entity": enterprise.get_enterprise_entity_count(self.app_mysql),
            "model": ModelRepositry().get_model_count(),
            "device": EnterpriseDeviceRepository().get_all_device_count(),
        }
        self.finish(totals)


class EntityAddHandler(APIHandler):
    """Create a new enterprise.

    Required arguments: name, province, city, addr, industry (must be a key
    of ``consts.industry_map``), contact, phone, summary.
    Optional argument: logo.

    The new row gets a random 10-character ``suid``, the account ``"admin"``
    and a random 8-character initial password. The password is stored
    AES-encrypted with ``settings.enterprise_aes_key``.

    Raises:
        errors.HTTPAPIError(ERROR_BAD_REQUEST): a required argument is
            missing, the industry is unknown, or the logo is too large.
    """

    @authenticated
    def post(self):
        name = self.tostr(self.get_escaped_argument("name", ""))
        province = self.get_escaped_argument("province", "")
        city = self.get_escaped_argument("city", "")
        addr = self.get_escaped_argument("addr", "")
        industry = self.get_int_argument("industry")
        contact = self.get_escaped_argument("contact", "")
        phone = self.get_escaped_argument("phone", "")
        summary = self.get_escaped_argument("summary", "")
        logo = self.get_escaped_argument("logo", "")

        required = (name, province, city, addr, industry, contact, phone, summary)
        if not all(required):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        if industry not in consts.industry_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请选择行业类型")

        # len * 0.75 estimates the decoded byte size of a base64 payload
        # (assumes logo is base64 text — confirm with the frontend). The
        # threshold is 1.2 MiB although the message says 1M — presumably
        # deliberate slack; confirm before tightening.
        if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
            raise errors.HTTPAPIError(
                errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
            )

        # Random initial password; only its AES-encrypted form is stored.
        short_uid = shortuuid.ShortUUID().random(length=8)
        pwd = aes.encrypt(settings.enterprise_aes_key, short_uid)

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "insert into enterprise(suid, name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) "
                    "values(:suid, :name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)"
                ),
                {
                    "suid": shortuuid.ShortUUID().random(length=10),
                    "name": name,
                    "province": province,
                    "city": city,
                    "addr": addr,
                    "industry": industry,
                    "contact": contact,
                    "phone": phone,
                    "summary": summary,
                    "logo": logo,
                    "account": "admin",
                    "pwd": pwd,
                },
            )
            conn.commit()

        self.finish()


class EntityEditHandler(APIHandler):
    """Update an existing enterprise identified by ``id``.

    Takes the same required arguments as enterprise creation (name, province,
    city, addr, industry, contact, phone, summary), plus optional ``logo``
    and ``account``.

    An empty/omitted ``account`` keeps the stored account instead of
    overwriting it with an empty string.

    Raises:
        errors.HTTPAPIError(ERROR_BAD_REQUEST): the id or a required argument
            is missing, the industry is unknown, or the logo is too large.
    """

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")
        name = self.tostr(self.get_escaped_argument("name", ""))
        province = self.get_escaped_argument("province", "")
        city = self.get_escaped_argument("city", "")
        addr = self.get_escaped_argument("addr", "")
        industry = self.get_int_argument("industry")
        contact = self.get_escaped_argument("contact", "")
        phone = self.get_escaped_argument("phone", "")
        summary = self.get_escaped_argument("summary", "")
        logo = self.get_escaped_argument("logo", "")
        account = self.get_escaped_argument("account", "")

        # Without an id the UPDATE would match no row (where id=NULL) and the
        # request would silently report success.
        required = (eid, name, province, city, addr, industry, contact, phone, summary)
        if not all(required):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        if industry not in consts.industry_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请选择行业类型")

        # len * 0.75 estimates the decoded size of a base64 payload (assumed
        # encoding — confirm); the 1.2 MiB threshold is kept as-is.
        if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
            raise errors.HTTPAPIError(
                errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
            )

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update enterprise set name=:name, province=:province, city=:city, addr=:addr, "
                    "industry=:industry, contact=:contact, phone=:phone, summary=:summary, logo=:logo, "
                    # nullif turns '' into NULL so coalesce falls back to the
                    # stored account rather than blanking it.
                    "account=coalesce(nullif(:account, ''), account) where id=:id"
                ),
                {
                    "name": name,
                    "province": province,
                    "city": city,
                    "addr": addr,
                    "industry": industry,
                    "contact": contact,
                    "phone": phone,
                    "summary": summary,
                    "logo": logo,
                    "account": account,
                    "id": eid,
                },
            )
            conn.commit()

        self.finish()


class EntityInfoHandler(APIHandler):
    """Return the details of one enterprise, looked up by ``id``.

    Raises:
        errors.HTTPAPIError(ERROR_BAD_REQUEST): no enterprise row has that id.
    """

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")

        with self.app_mysql.connect() as conn:
            cursor = conn.execute(
                text("select * from enterprise where id=:id"), {"id": eid}
            )
            record = db_mysql.to_json(cursor)
            cursor.close()

        if not record:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")

        # Key order is kept stable so the serialized response is unchanged.
        self.finish(
            {
                "name": record["name"],
                "province": record["province"],
                "city": record["city"],
                "addr": record["addr"],
                "industry": consts.industry_map[record["industry"]],
                "expire_at": "",
                "contact": record["contact"],
                "phone": record["phone"],
                "summary": record["summary"],
                "logo": record["logo"],
                "createTime": str(record["create_time"]),
                "account": record["account"],  # enterprise login account
            }
        )


class ModelsHandler(APIHandler):
    """List the base models used by an enterprise's business models.

    Collects the ``id`` of every entry in the JSON ``base_models`` column of
    all ``enterprise_busi_model`` rows whose ``entity_id`` matches the
    request ``id``, de-duplicates them, and returns
    ``{"count": <distinct model ids>, "data": [...]}``. Each data item holds
    the model name, type label, classification name and default version.
    """

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")

        model_rows = []
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select base_models from enterprise_busi_model where entity_id=:eid"),
                {"eid": eid},
            )
            rows = db_mysql.to_json_list(cur)
            cur.close()

            unique_ids = set()
            for row in rows:
                # A NULL/empty column would make json.loads raise; skip it.
                if not row["base_models"]:
                    continue
                unique_ids.update(item["id"] for item in json.loads(row["base_models"]))
            model_ids = list(unique_ids)

            # The list is bound directly to "in :model_ids"; an empty list
            # renders as "in ()", a MySQL syntax error, so skip the query
            # when there is nothing to look up.
            if model_ids:
                cur = conn.execute(text(
                    """
                    select m.name, m.model_type, mc.name as classification_name, m.default_version
                    from model m, model_classification mc
                    where m.classification=mc.id and m.id in :model_ids
                    """
                ), {"model_ids": model_ids})
                model_rows = db_mysql.to_json_list(cur)
                cur.close()

        data = [
            {
                "name": row["name"],
                "model_type": consts.model_type_map[row["model_type"]],
                "classification_name": row["classification_name"],
                "default_version": row["default_version"],
            }
            for row in model_rows
        ]

        self.finish({"count": len(model_ids), "data": data})


class EntityDeleteHandler(APIHandler):
    """Soft-delete an enterprise by setting its ``del`` flag to 1."""

    @authenticated
    def post(self):
        params = {"id": self.get_int_argument("id")}
        statement = text("update enterprise set del=1 where id=:id")

        with self.app_mysql.connect() as conn:
            conn.execute(statement, params)
            conn.commit()

        self.finish()


class EntityPwdcheckHandler(APIHandler):
    """Return an enterprise's login password in plain text.

    The password is stored AES-encrypted and is decrypted here with
    ``settings.enterprise_aes_key``.

    Raises:
        errors.HTTPAPIError(ERROR_BAD_REQUEST): no enterprise has that id.
    """

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select pwd from enterprise where id=:id"), {"id": eid}
            )
            row = db_mysql.to_json(cur)
            # Close before the not-found check so the cursor is also
            # released on the error path.
            cur.close()

        if not row:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")

        # Decryption needs no database access, so it runs after the
        # connection has been returned.
        pwd_dcrypt = aes.decrypt(settings.enterprise_aes_key, row["pwd"])

        self.finish({"pwd": pwd_dcrypt})


class IndustryMapHandler(APIHandler):
    """Return ``consts.industry_map`` (industry id -> display value) as-is."""

    @authenticated
    def post(self):
        industry_options = consts.industry_map
        self.finish(industry_options)