You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

396 lines
13 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import asyncio
import json
import logging
from sqlalchemy import text
from website import consts
from website import db_mysql
from website import errors
from website import settings
from website.db.alg_model.alg_model import ModelRepositry
from website.db.enterprise import enterprise
from website.db.enterprise_device.enterprise_device import EnterpriseDeviceRepository
from website.handler import APIHandler, authenticated
from website.util import shortuuid, aes
# from concurrent.futures import ThreadPoolExecutor
# from functools import partial
class EntityIndexHandler(APIHandler):
    """Home page: paginated enterprise list."""

    @authenticated
    async def post(self):
        """Respond with one page of non-deleted enterprises, highest id first.

        Request arguments:
            pageNo: 1-based page number (default 1); values below 1 are
                treated as 1.
            pageSize: rows per page (default 10); negative values are
                treated as 0.
            name: optional substring matched against the enterprise name.

        The response is ``{"count": <total matching rows>, "data": [...]}``
        where every entry carries the enterprise's model and device counts.
        """
        # Clamp the paging values: a page number below 1 would otherwise
        # produce a negative OFFSET (and a negative size a negative LIMIT),
        # which the database rejects.
        page_no = max(self.get_int_argument("pageNo", 1), 1)
        page_size = max(self.get_int_argument("pageSize", 10), 0)
        name = self.tostr(self.get_escaped_argument("name", ""))

        # The list query and the count query share the same filter clause.
        where_sql = "where 1=1 "
        filter_params = {}
        if name:
            where_sql += "and name like :name"
            filter_params["name"] = "%{}%".format(name)
        where_sql += " and del=0"

        list_sql = (
            "select id, suid, name, industry, logo, create_time from enterprise "
            + where_sql
            + " order by id desc limit :pageSize offset :offset"
        )
        list_params = dict(
            filter_params,
            pageSize=page_size,
            offset=(page_no - 1) * page_size,
        )
        count_sql = "select count(id) c from enterprise " + where_sql

        with self.app_mysql.connect() as conn:
            result = db_mysql.to_json_list(conn.execute(text(list_sql), list_params))
            count_row = db_mysql.to_json(conn.execute(text(count_sql), filter_params))
        count = count_row["c"] if count_row else 0

        # Fetch the per-enterprise counts concurrently; gather() returns the
        # results in the same order as ``result``.
        count_results = await asyncio.gather(
            *[
                enterprise.get_enterprise_model_and_device_count(entity_id=item["id"])
                for item in result
            ]
        )

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                "industry": consts.industry_map[item["industry"]],
                "modelCount": counts[0],
                "deviceCount": counts[1],
                "logo": item["logo"],
                "createTime": str(item["create_time"]),
            }
            for item, counts in zip(result, count_results)
        ]
        self.finish({"count": count, "data": data})
class EntityIndexBasecountHandler(APIHandler):
    """Home page: basic statistics."""

    @authenticated
    def post(self):
        """Respond with the total enterprise, model and device counts."""
        totals = {
            "entity": enterprise.get_enterprise_entity_count(self.app_mysql),
            "model": ModelRepositry().get_model_count(),
            "device": EnterpriseDeviceRepository().get_all_device_count(),
        }
        self.finish(totals)
class EntityAddHandler(APIHandler):
    """Create an enterprise."""

    # Factor applied to the logo string length before comparing with the
    # limit below. Presumably converts a base64 length to decoded bytes --
    # TODO confirm the logo is sent base64-encoded.
    _LOGO_SIZE_FACTOR = 0.75
    # Threshold in MiB; the user-facing message says 1M, so the extra 0.2
    # looks like deliberate tolerance -- TODO confirm.
    _LOGO_MAX_MB = 1.2

    @authenticated
    def post(self):
        """Validate the submitted fields and insert a new enterprise row.

        Required arguments: name, province, city, addr, industry, contact,
        phone, summary. ``logo`` is optional.

        The row is stored with a random 10-character ``suid``, the fixed
        account name ``"admin"`` and an AES-encrypted random 8-character
        password.

        Raises:
            errors.HTTPAPIError: with ``ERROR_BAD_REQUEST`` when a required
                argument is missing or empty, when ``industry`` is not a key
                of ``consts.industry_map``, or when the logo is too large.
        """
        name = self.tostr(self.get_escaped_argument("name", ""))
        province = self.get_escaped_argument("province", "")
        city = self.get_escaped_argument("city", "")
        addr = self.get_escaped_argument("addr", "")
        industry = self.get_int_argument("industry")
        contact = self.get_escaped_argument("contact", "")
        phone = self.get_escaped_argument("phone", "")
        summary = self.get_escaped_argument("summary", "")
        logo = self.get_escaped_argument("logo", "")

        required = (name, province, city, addr, industry, contact, phone, summary)
        if not all(required):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if industry not in consts.industry_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请选择行业类型")
        if logo and len(logo) * self._LOGO_SIZE_FACTOR / 1024 / 1024 > self._LOGO_MAX_MB:
            raise errors.HTTPAPIError(
                errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
            )

        # The initial password is random and stored encrypted.
        short_uid = shortuuid.ShortUUID().random(length=8)
        pwd = aes.encrypt(settings.enterprise_aes_key, short_uid)

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "insert into enterprise(suid, name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) "
                    "values(:suid, :name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)"
                ),
                {
                    "suid": shortuuid.ShortUUID().random(length=10),
                    "name": name,
                    "province": province,
                    "city": city,
                    "addr": addr,
                    "industry": industry,
                    "contact": contact,
                    "phone": phone,
                    "summary": summary,
                    "logo": logo,
                    "account": "admin",
                    "pwd": pwd,
                },
            )
            conn.commit()
        self.finish()
class EntityEditHandler(APIHandler):
    """Edit an enterprise."""

    @authenticated
    def post(self):
        """Validate the submitted fields and update the enterprise row.

        Required arguments: id, name, province, city, addr, industry,
        contact, phone, summary. ``logo`` and ``account`` are optional.

        Raises:
            errors.HTTPAPIError: with ``ERROR_BAD_REQUEST`` when a required
                argument is missing or empty, when ``industry`` is not a key
                of ``consts.industry_map``, or when the logo is too large.
        """
        eid = self.get_int_argument("id")
        name = self.tostr(self.get_escaped_argument("name", ""))
        province = self.get_escaped_argument("province", "")
        city = self.get_escaped_argument("city", "")
        addr = self.get_escaped_argument("addr", "")
        industry = self.get_int_argument("industry")
        contact = self.get_escaped_argument("contact", "")
        phone = self.get_escaped_argument("phone", "")
        summary = self.get_escaped_argument("summary", "")
        logo = self.get_escaped_argument("logo", "")
        # NOTE(review): an omitted ``account`` argument overwrites the stored
        # account with an empty string -- confirm this is intended.
        account = self.get_escaped_argument("account", "")

        # ``eid`` is required too: without it the UPDATE matches no row and
        # the request would report success without changing anything.
        required = (eid, name, province, city, addr, industry, contact, phone, summary)
        if not all(required):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if industry not in consts.industry_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请选择行业类型")
        # The 0.75 factor presumably converts a base64 string length to
        # decoded bytes -- TODO confirm the logo is sent base64-encoded.
        if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
            raise errors.HTTPAPIError(
                errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
            )

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update enterprise set name=:name, province=:province, city=:city, addr=:addr, industry=:industry, contact"
                    "=:contact, phone=:phone, summary=:summary, logo=:logo, account=:account where id=:id",
                ),
                {
                    "name": name,
                    "province": province,
                    "city": city,
                    "addr": addr,
                    "industry": industry,
                    "contact": contact,
                    "phone": phone,
                    "summary": summary,
                    "logo": logo,
                    "account": account,
                    "id": eid,
                },
            )
            conn.commit()
        self.finish()
class EntityInfoHandler(APIHandler):
    """Enterprise details."""

    @authenticated
    def post(self):
        """Respond with the stored fields of the enterprise identified by ``id``.

        Raises:
            errors.HTTPAPIError: with ``ERROR_BAD_REQUEST`` when no row
                matches the given id.
        """
        enterprise_id = self.get_int_argument("id")
        lookup = text("select * from enterprise where id=:id")

        record = {}
        with self.app_mysql.connect() as conn:
            cursor = conn.execute(lookup, {"id": enterprise_id})
            record = db_mysql.to_json(cursor)
            cursor.close()

        if not record:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")

        # The industry code is translated to its display value.
        industry_label = consts.industry_map[record["industry"]]
        created = str(record["create_time"])
        self.finish(
            {
                "name": record["name"],
                "province": record["province"],
                "city": record["city"],
                "addr": record["addr"],
                "industry": industry_label,
                "expire_at": "",
                "contact": record["contact"],
                "phone": record["phone"],
                "summary": record["summary"],
                "logo": record["logo"],
                "createTime": created,
                "account": record["account"],  # enterprise account
            }
        )
class ModelsHandler(APIHandler):
    """Enterprise models."""

    @authenticated
    def post(self):
        """Respond with the base models referenced by an enterprise's business models.

        The ``base_models`` JSON column of every ``enterprise_busi_model`` row
        of the enterprise is parsed, the referenced model ids are
        de-duplicated, and the matching models are looked up.

        The response is ``{"count": <distinct model ids>, "data": [...]}``.
        """
        # Local import keeps this block self-contained; sqlalchemy is already
        # a dependency of this module (see ``text``).
        from sqlalchemy import bindparam

        eid = self.get_int_argument("id")
        model_ids = []
        rows = []
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select base_models from enterprise_busi_model where entity_id=:eid"),
                {"eid": eid},
            )
            busi_rows = db_mysql.to_json_list(cur)
            cur.close()

            for busi_row in busi_rows:
                raw_models = busi_row["base_models"]
                if not raw_models:
                    # A NULL/empty column cannot be parsed as JSON; skip it.
                    continue
                model_ids.extend(item["id"] for item in json.loads(raw_models))
            # Order-preserving de-duplication.
            model_ids = list(dict.fromkeys(model_ids))

            # An empty id list would render as ``IN ()``, which is a SQL
            # syntax error, so only run the lookup when there are ids.
            if model_ids:
                stmt = text(
                    """
                    select m.name, m.model_type, mc.name as classification_name, m.default_version
                    from model m, model_classification mc
                    where m.classification=mc.id and m.id in :model_ids
                    """
                ).bindparams(bindparam("model_ids", expanding=True))
                cur = conn.execute(stmt, {"model_ids": model_ids})
                rows = db_mysql.to_json_list(cur)
                cur.close()

        data = [
            {
                "name": row["name"],
                "model_type": consts.model_type_map[row["model_type"]],
                "classification_name": row["classification_name"],
                "default_version": row["default_version"],
            }
            for row in rows
        ]
        self.finish({"count": len(model_ids), "data": data})
class EntityDeleteHandler(APIHandler):
    """Delete an enterprise."""

    @authenticated
    def post(self):
        """Soft-delete: set ``del=1`` on the enterprise row identified by ``id``."""
        bind_values = {"id": self.get_int_argument("id")}
        soft_delete = text("update enterprise set del=1 where id=:id")
        with self.app_mysql.connect() as conn:
            conn.execute(soft_delete, bind_values)
            conn.commit()
        self.finish()
class EntityPwdcheckHandler(APIHandler):
    """View an enterprise's password."""

    @authenticated
    def post(self):
        """Respond with the decrypted password of the enterprise identified by ``id``.

        The stored value is decrypted with ``settings.enterprise_aes_key``.

        Raises:
            errors.HTTPAPIError: with ``ERROR_BAD_REQUEST`` when no row
                matches the given id.
        """
        eid = self.get_int_argument("id")
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select pwd from enterprise where id=:id"), {"id": eid}
            )
            try:
                row = db_mysql.to_json(cur)
            finally:
                # Close the cursor on every path, including "row not found".
                cur.close()

        if not row:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")

        pwd_dcrypt = aes.decrypt(settings.enterprise_aes_key, row["pwd"])
        self.finish({"pwd": pwd_dcrypt})
class IndustryMapHandler(APIHandler):
    """Industry lookup table."""

    @authenticated
    def post(self):
        """Respond with ``consts.industry_map`` as the payload."""
        industries = consts.industry_map
        self.finish(industries)