# -*- coding: utf-8 -*-
"""Enterprise (entity) management API handlers."""
import asyncio
import json
import logging

from sqlalchemy import bindparam, text

from website import consts
from website import db_mysql
from website import errors
from website import settings
from website.db.alg_model.alg_model import ModelRepositry
from website.db.enterprise import enterprise
from website.db.enterprise_device.enterprise_device import EnterpriseDeviceRepository
from website.handler import APIHandler, authenticated
from website.util import shortuuid, aes

# Upper bound (in MB) on the estimated decoded logo size. The user-facing
# message says "1M"; the check keeps the original 1.2 MB tolerance.
_LOGO_MAX_MB = 1.2


def _validate_enterprise_form(
    name, province, city, addr, industry, contact, phone, summary, logo
):
    """Validate the fields shared by the add/edit enterprise forms.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when a required field is empty,
            the industry is not a key of ``consts.industry_map``, or the logo
            is too large.
    """
    if not all((name, province, city, addr, industry, contact, phone, summary)):
        raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
    if industry not in consts.industry_map:
        raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请选择行业类型")
    # len * 0.75 estimates the decoded byte size; presumably the logo arrives
    # as a base64 string — TODO confirm with the frontend.
    if logo and len(logo) * 0.75 / 1024 / 1024 > _LOGO_MAX_MB:
        raise errors.HTTPAPIError(
            errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
        )


class EntityIndexHandler(APIHandler):
    """Home page: paginated list of non-deleted enterprises.

    Request args:
        pageNo: 1-based page number (default 1; values below 1 are treated as 1).
        pageSize: rows per page (default 10; negative values are treated as 0).
        name: optional substring filter on the enterprise name.

    Response: ``{"count": <total matching rows>, "data": [...]}`` where each
    item also carries ``modelCount`` and ``deviceCount``.
    """

    @authenticated
    async def post(self):
        page_no = max(self.get_int_argument("pageNo", 1), 1)
        page_size = max(self.get_int_argument("pageSize", 10), 0)
        name = self.tostr(self.get_escaped_argument("name", ""))

        # Shared WHERE clause for both the page query and the count query.
        where_clause = " where del=0"
        filter_params = {}
        if name:
            where_clause += " and name like :name"
            filter_params["name"] = "%{}%".format(name)

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    "select id, suid, name, industry, logo, create_time from enterprise"
                    + where_clause
                    + " order by id desc limit :pageSize offset :offset"
                ),
                {
                    **filter_params,
                    "pageSize": page_size,
                    "offset": (page_no - 1) * page_size,
                },
            )
            rows = db_mysql.to_json_list(cur)

            cur_count = conn.execute(
                text("select count(id) c from enterprise" + where_clause),
                filter_params,
            )
            count_row = db_mysql.to_json(cur_count)
        count = count_row["c"] if count_row else 0

        # Fetch model/device counts for every row of the page concurrently;
        # each result is indexed as (model_count, device_count).
        count_results = await asyncio.gather(
            *(
                enterprise.get_enterprise_model_and_device_count(entity_id=row["id"])
                for row in rows
            )
        )

        data = [
            {
                "id": row["id"],
                "name": row["name"],
                "industry": consts.industry_map[row["industry"]],
                "modelCount": counts[0],
                "deviceCount": counts[1],
                "logo": row["logo"],
                "createTime": str(row["create_time"]),
            }
            for row, counts in zip(rows, count_results)
        ]
        self.finish({"count": count, "data": data})


class EntityIndexBasecountHandler(APIHandler):
    """Home page totals: enterprise, model and device counts."""

    @authenticated
    def post(self):
        entity_count = enterprise.get_enterprise_entity_count(self.app_mysql)
        model_count = ModelRepositry().get_model_count()
        device_count = EnterpriseDeviceRepository().get_all_device_count()
        self.finish({"entity": entity_count, "model": model_count, "device": device_count})


class EntityAddHandler(APIHandler):
    """Create an enterprise.

    All of name/province/city/addr/industry/contact/phone/summary are required;
    logo is optional. A random 10-char ``suid`` is generated, and the account
    is ``admin`` with a random 8-char password stored AES-encrypted (it is
    decrypted again by EntityPwdcheckHandler).
    """

    @authenticated
    def post(self):
        name = self.tostr(self.get_escaped_argument("name", ""))
        province = self.get_escaped_argument("province", "")
        city = self.get_escaped_argument("city", "")
        addr = self.get_escaped_argument("addr", "")
        industry = self.get_int_argument("industry")
        contact = self.get_escaped_argument("contact", "")
        phone = self.get_escaped_argument("phone", "")
        summary = self.get_escaped_argument("summary", "")
        logo = self.get_escaped_argument("logo", "")

        _validate_enterprise_form(
            name, province, city, addr, industry, contact, phone, summary, logo
        )

        initial_password = shortuuid.ShortUUID().random(length=8)
        pwd = aes.encrypt(settings.enterprise_aes_key, initial_password)

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "insert into enterprise(suid, name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) "
                    "values(:suid, :name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)"
                ),
                {
                    "suid": shortuuid.ShortUUID().random(length=10),
                    "name": name,
                    "province": province,
                    "city": city,
                    "addr": addr,
                    "industry": industry,
                    "contact": contact,
                    "phone": phone,
                    "summary": summary,
                    "logo": logo,
                    "account": "admin",
                    "pwd": pwd,
                },
            )
            conn.commit()

        self.finish()


class EntityEditHandler(APIHandler):
    """Update an enterprise's profile fields (and account name) by ``id``."""

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")
        name = self.tostr(self.get_escaped_argument("name", ""))
        province = self.get_escaped_argument("province", "")
        city = self.get_escaped_argument("city", "")
        addr = self.get_escaped_argument("addr", "")
        industry = self.get_int_argument("industry")
        contact = self.get_escaped_argument("contact", "")
        phone = self.get_escaped_argument("phone", "")
        summary = self.get_escaped_argument("summary", "")
        logo = self.get_escaped_argument("logo", "")
        account = self.get_escaped_argument("account", "")

        _validate_enterprise_form(
            name, province, city, addr, industry, contact, phone, summary, logo
        )

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update enterprise set name=:name, province=:province, city=:city, addr=:addr, industry=:industry, contact"
                    "=:contact, phone=:phone, summary=:summary, logo=:logo, account=:account where id=:id",
                ),
                {
                    "name": name,
                    "province": province,
                    "city": city,
                    "addr": addr,
                    "industry": industry,
                    "contact": contact,
                    "phone": phone,
                    "summary": summary,
                    "logo": logo,
                    "account": account,
                    "id": eid,
                },
            )
            conn.commit()

        self.finish()


class EntityInfoHandler(APIHandler):
    """Return one enterprise's profile by ``id``.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when no row matches ``id``.
    """

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select * from enterprise where id=:id"), {"id": eid}
            )
            row = db_mysql.to_json(cur)
            cur.close()

        if not row:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")

        data = {
            "name": row["name"],
            "province": row["province"],
            "city": row["city"],
            "addr": row["addr"],
            "industry": row["industry"],
            "contact": row["contact"],
            "phone": row["phone"],
            "summary": row["summary"],
            "logo": row["logo"],
            "createTime": str(row["create_time"]),
            "account": row["account"],  # enterprise login account
        }
        self.finish(data)


class ModelsHandler(APIHandler):
    """List the distinct base models referenced by an enterprise's business models.

    ``enterprise_busi_model.base_models`` is parsed as a JSON list whose items
    carry an ``id``. Response: ``{"count": <distinct model ids>, "data": [...]}``.
    """

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")
        model_ids = set()
        data = []
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select base_models from enterprise_busi_model where entity_id=:eid"),
                {"eid": eid},
            )
            rows = db_mysql.to_json_list(cur)
            cur.close()
            for row in rows:
                # Skip NULL/empty columns instead of crashing in json.loads.
                if not row["base_models"]:
                    continue
                model_ids.update(item["id"] for item in json.loads(row["base_models"]))

            # An empty IN list would render invalid SQL ("in ()"), so only query
            # when there is something to look up.
            if model_ids:
                stmt = text(
                    """
                    select m.name, m.model_type, mc.name as classification_name, m.default_version
                    from model m, model_classification mc
                    where m.classification=mc.id and m.id in :model_ids
                    """
                ).bindparams(bindparam("model_ids", expanding=True))
                cur = conn.execute(stmt, {"model_ids": list(model_ids)})
                rows = db_mysql.to_json_list(cur)
                cur.close()
                data = [
                    {
                        "name": row["name"],
                        "model_type": consts.model_type_map[row["model_type"]],
                        "classification_name": row["classification_name"],
                        "default_version": row["default_version"],
                    }
                    for row in rows
                ]

        self.finish({"count": len(model_ids), "data": data})


class EntityDeleteHandler(APIHandler):
    """Soft-delete an enterprise by setting ``del=1``."""

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")
        with self.app_mysql.connect() as conn:
            conn.execute(text("update enterprise set del=1 where id=:id"), {"id": eid})
            conn.commit()
        self.finish()


class EntityPwdcheckHandler(APIHandler):
    """Return an enterprise's password, AES-decrypted with ``settings.enterprise_aes_key``.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when no row matches ``id``.
    """

    @authenticated
    def post(self):
        eid = self.get_int_argument("id")
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select pwd from enterprise where id=:id"), {"id": eid}
            )
            row = db_mysql.to_json(cur)
            # Close before the not-found check so the cursor is released on
            # every path.
            cur.close()

        if not row:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")

        pwd_decrypted = aes.decrypt(settings.enterprise_aes_key, row["pwd"])
        self.finish({"pwd": pwd_decrypted})


class IndustryMapHandler(APIHandler):
    """Return ``consts.industry_map`` (keyed by industry code) as-is."""

    @authenticated
    def post(self):
        self.finish(consts.industry_map)