From c1f4dc7778ce93ddb8361626f3e1c239f1fdf1a8 Mon Sep 17 00:00:00 2001 From: zhouping Date: Mon, 27 May 2024 17:36:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=BC=82=E6=AD=A5=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- website/db_mysql.py | 42 +++++- website/handlers/enterprise_entity/handler.py | 132 ++++++++++++++---- website/handlers/enterprise_entity/url.py | 6 +- website/service/enterprise.py | 107 +++++++++++++- 4 files changed, 242 insertions(+), 45 deletions(-) diff --git a/website/db_mysql.py b/website/db_mysql.py index 74ccdcc..5e40b68 100644 --- a/website/db_mysql.py +++ b/website/db_mysql.py @@ -7,6 +7,8 @@ from typing import List, Any, Optional from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from contextlib import asynccontextmanager from website import settings @@ -41,11 +43,11 @@ def to_json(cursor): app_engine = create_engine( - 'mysql+pymysql://{}:{}@{}/{}?charset=utf8mb4'.format( - settings.mysql_app['user'], - settings.mysql_app['password'], - settings.mysql_app['host'], - settings.mysql_app['database'] + "mysql+pymysql://{}:{}@{}/{}?charset=utf8mb4".format( + settings.mysql_app["user"], + settings.mysql_app["password"], + settings.mysql_app["host"], + settings.mysql_app["database"], ), # SQLAlchemy 数据库连接串,格式见下面 echo=bool(settings.SQLALCHEMY_ECHO), # 是不是要把所执行的SQL打印出来,一般用于调试 pool_pre_ping=True, @@ -69,3 +71,33 @@ def get_session(): raise e finally: s.close() + + +async_app_engine = create_async_engine( + "mysql+aiomysql://{}:{}@{}/{}?charset=utf8mb4".format( + settings.mysql_app["user"], + settings.mysql_app["password"], + settings.mysql_app["host"], + settings.mysql_app["database"], + ), + echo=bool(settings.SQLALCHEMY_ECHO), + pool_pre_ping=True, + pool_size=int(settings.SQLALCHEMY_POOL_SIZE), + max_overflow=int(settings.SQLALCHEMY_POOL_MAX_SIZE), + pool_recycle=int(settings.SQLALCHEMY_POOL_RECYCLE), +) + +AsyncSession = sessionmaker(bind=async_app_engine, class_=AsyncSession) + + +@asynccontextmanager +async def get_async_session(): + async with AsyncSession() as s: + try: + yield s + await s.commit() + except Exception as e: + await s.rollback() + raise e + finally: + await s.close() diff --git a/website/handlers/enterprise_entity/handler.py b/website/handlers/enterprise_entity/handler.py index 9f323a6..e2ff87f 100644 --- a/website/handlers/enterprise_entity/handler.py +++ b/website/handlers/enterprise_entity/handler.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import logging +import asyncio from sqlalchemy import text @@ -10,22 +11,24 @@ from website import settings from website.handler import APIHandler, authenticated from website.service import enterprise from website.util import shortuuid, aes +from concurrent.futures import ThreadPoolExecutor +from functools import partial class EntityIndexHandler(APIHandler): """首页""" @authenticated - def post(self): + async def post(self): pageNo = self.get_int_argument("pageNo", 1) pageSize = self.get_int_argument("pageSize", 10) name = self.tostr(self.get_escaped_argument("name", "")) with self.app_mysql.connect() as conn: - sql_text = "select id, name, industry, logo, create_time from enterprise where 1=1 " + sql_text = "select id, suid, name, industry, logo, create_time from enterprise where 1=1 " param = {} - count_sql_text = "select count(id) from enterprise " + count_sql_text = "select count(id) c from enterprise where 1=1 " count_param = {} if name: @@ -45,26 +48,64 @@ class EntityIndexHandler(APIHandler): cur = conn.execute(text(sql_text), param) result = db_mysql.to_json_list(cur) - count = conn.execute(text(count_sql_text), count_param).fetchone()[0] - logging.info(count) - logging.info(result) + cur_count = conn.execute(text(count_sql_text), count_param) + count = db_mysql.to_json(cur_count) + count = count["c"] if count else 0 + # data_index = [item["id"] for item in result] data = [] - for item in result: - modelCount = enterprise.get_enterprise_model_count(item["id"]) - deviceCount = enterprise.get_enterprise_device_count(item["id"]) - - data.append( - { - "id": item["id"], - "name": item["name"], - "industry": item["industry"], - "modelCount": modelCount, - "deviceCount": deviceCount, - "logo": item["logo"], - "createTime": str(item["create_time"]) - } + # for item in result: + # modelCount = enterprise.get_enterprise_model_count(item["id"]) + # deviceCount = enterprise.get_enterprise_device_count(item["id"]) + + # data.append( + # { + # "id": item["id"], + # "name": item["name"], + # "industry": consts.industry_map[item["industry"]], + # "modelCount": modelCount, + # "deviceCount": deviceCount, + # "logo": item["logo"], + # "createTime": str(item["create_time"]), + # } + # ) + + # with ThreadPoolExecutor() as executor: + + # get_count = partial(enterprise.get_enterprise_model_and_device_count) + # futures = [ + # executor.submit(get_count, entity_id=item["id"], entity_suid="") + # for item in result + # ] + # results = [future.result() for future in futures] + + # model_counts = [result[0] for result in results] + # device_counts = [result[1] for result in results] + + count_results = await asyncio.gather( + *[ + enterprise.get_enterprise_model_and_device_count(entity_id=item["id"]) + for item in result + ] + ) + + model_counts = [result[0] for result in count_results] + device_counts = [result[1] for result in count_results] + + data = [ + { + "id": item["id"], + "name": item["name"], + "industry": consts.industry_map[item["industry"]], + "modelCount": model_count, + "deviceCount": device_count, + "logo": item["logo"], + "createTime": str(item["create_time"]), + } + for item, model_count, device_count in zip( + result, model_counts, device_counts ) + ] self.finish({"count": count, "data": data}) @@ -96,14 +137,25 @@ class EntityAddHandler(APIHandler): summary = self.get_escaped_argument("summary", "") logo = self.get_escaped_argument("logo", "") - if not name or not province or not city or not addr or not industry or not contact or not phone or not summary: + if ( + not name + or not province + or not city + or not addr + or not industry + or not contact + or not phone + or not summary + ): raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") if industry not in consts.industry_map: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型") if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2: - raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制") + raise errors.HTTPAPIError( + errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制" + ) short_uid = shortuuid.ShortUUID().random(length=8) pwd = aes.encrypt(settings.enterprise_aes_key, short_uid) @@ -127,7 +179,7 @@ class EntityAddHandler(APIHandler): "logo": logo, "account": "admin", "pwd": pwd, - } + }, ) conn.commit() @@ -157,21 +209,31 @@ class EntityEditHandler(APIHandler): logo = self.get_escaped_argument("logo", "") account = self.get_escaped_argument("account", "") - if not name or not province or not city or not addr or not industry or not contact or not phone or not summary: + if ( + not name + or not province + or not city + or not addr + or not industry + or not contact + or not phone + or not summary + ): raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") if industry not in consts.industry_map: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型") if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2: - raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制") + raise errors.HTTPAPIError( + errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制" + ) with self.app_mysql.connect() as conn: conn.execute( text( # "insert into enterprise(name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) " # "values(:name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)" - "update enterprise set name=:name, province=:province, city=:city, addr=:addr, industry=:industry, contact" "=:contact, phone=:phone, summary=:summary, logo=:logo, account=:account where id=:id", ), @@ -186,7 +248,8 @@ class EntityEditHandler(APIHandler): "summary": summary, "logo": logo, "account": account, - } + "id": eid, + }, ) conn.commit() @@ -222,6 +285,7 @@ class EntityInfoHandler(APIHandler): "name": row["name"], "province": row["province"], "city": row["city"], + "addr": row["addr"], "industry": row["industry"], "contact": row["contact"], "phone": row["phone"], @@ -242,9 +306,7 @@ class EntityDeleteHandler(APIHandler): eid = self.get_int_argument("id") with self.app_mysql.connect() as conn: - conn.execute( - text("update enterprise set del=1 where id=:id"), {"id": eid} - ) + conn.execute(text("update enterprise set del=1 where id=:id"), {"id": eid}) conn.commit() @@ -259,7 +321,9 @@ class EntityPwdcheckHandler(APIHandler): eid = self.get_int_argument("id") with self.app_mysql.connect() as conn: - cur = conn.execute(text("select pwd from enterprise where id=:id"), {"id": eid}) + cur = conn.execute( + text("select pwd from enterprise where id=:id"), {"id": eid} + ) # row = cur.fetchone() logging.info(cur) row = db_mysql.to_json(cur) @@ -271,3 +335,9 @@ class EntityPwdcheckHandler(APIHandler): pwd_dcrypt = aes.decrypt(settings.enterprise_aes_key, pwd) self.finish({"pwd": pwd_dcrypt}) + + +class IndustryMapHandler(APIHandler): + @authenticated + def post(self): + self.finish(consts.industry_map) diff --git a/website/handlers/enterprise_entity/url.py b/website/handlers/enterprise_entity/url.py index 882be28..c1ca292 100644 --- a/website/handlers/enterprise_entity/url.py +++ b/website/handlers/enterprise_entity/url.py @@ -9,9 +9,7 @@ handlers = [ ("/enterprise/entity/info", handler.EntityInfoHandler), ("/enterprise/entity/delete", handler.EntityDeleteHandler), ("/enterprise/entity/pwdcheck", handler.EntityPwdcheckHandler), - + ("/enterprise/industrymap", handler.IndustryMapHandler), ] -page_handlers = [ - -] \ No newline at end of file +page_handlers = [] diff --git a/website/service/enterprise.py b/website/service/enterprise.py index b1e17e8..7088ee6 100644 --- a/website/service/enterprise.py +++ b/website/service/enterprise.py @@ -1,6 +1,9 @@ from website.handler import BaseHandler from sqlalchemy import text from typing import Any +from website.db_mysql import get_session, get_async_session, to_json +import json + # 获取企业模型数量 def get_enterprise_model_count(id: int) -> int: @@ -20,13 +23,107 @@ def get_enterprise_entity_count(engine: Any) -> int: count = conn.execute(text(count_sql_text)).fetchone() if count: return count[0] - + return 0 + # 获取所有企业模型数量 -def get_enterprise_model_count() -> int: - return 0 +def get_enterprise_model_count(entity_id: int = 0, entity_suid: str = "") -> int: + with get_session() as session: + sql = "select base_models from enterprise_busi_model where " + param = {} + if entity_id: + sql += "entity_id = :entity_id" + param["entity_id"] = entity_id + elif entity_suid: + sql += "entity_suid = :entity_suid" + param["entity_suid"] = entity_suid + + cur = session.execute(text(sql), param) + res = to_json(cur) + if res: + base_model_list = json.loads(res) + return len(base_model_list) + + return 0 + # 获取所有企业设备数量 -def get_enterprise_device_count() -> int: - return 0 \ No newline at end of file +def get_enterprise_device_count(entity_id: int = 0, entity_suid: str = "") -> int: + with get_session() as session: + sql = "select count(id) as device_count from enterprise_device " + param = {} + if entity_id: + sql += "where entity_id = :entity_id" + param = {"entity_id": entity_id} + elif entity_suid: + sql += "where entity_suid = :entity_suid" + param = {"entity_suid": entity_suid} + + cur = session.execute(text(sql), param) + res = to_json(cur) + if res: + return res["device_count"] + return 0 + + +async def get_enterprise_model_and_device_count( + entity_id: int = 0, entity_suid: str = "" +) -> (int, int): + model_count = 0 + device_count = 0 + + async with get_async_session() as session: + # sql_model = "select base_models from enterprise_busi_model where " + # param_model = {} + # if entity_id: + # sql_model += "entity_id = :entity_id" + # param_model["entity_id"] = entity_id + # elif entity_suid: + # sql_model += "entity_suid = :entity_suid" + # param_model["entity_suid"] = entity_suid + + # cur_model = await session.execute(text(sql_model), param_model) + # res_model = to_json(cur_model) + # if res_model: + # base_model_list = json.loads(res_model) + # model_count = len(base_model_list) + + # sql_device = "select count(id) as device_count from enterprise_device where " + # param_device = {} + # if entity_id: + # sql_device += "entity_id = :entity_id" + # param_device["entity_id"] = entity_id + # elif entity_suid: + # sql_device += "where entity_suid = :entity_suid" + # param_device = {"entity_suid": entity_suid} + + # cur_device = await session.execute(text(sql_device), param_device) + # res_device = to_json(cur_device) + # if res_device: + # device_count = res_device["device_count"] + + sql = """ + SELECT + (SELECT base_models FROM enterprise_busi_model WHERE {where_clause}) as base_models, + (SELECT COUNT(*) FROM enterprise_device WHERE {where_clause}) AS device_count + """ + + where_clause = "" + params = {} + if entity_id: + where_clause = "entity_id = :entity_id" + params["entity_id"] = entity_id + elif entity_suid: + where_clause = "entity_suid = :entity_suid" + params["entity_suid"] = entity_suid + + sql = sql.format(where_clause=where_clause) + result = await session.execute(text(sql), params) + # row = result.fetchone() + row = to_json(result) + base_models, device_count = row["base_models"], row["device_count"] + if base_models: + model_count = len(json.loads(base_models)) + + return model_count, device_count