diff --git a/website/db/alg_model/alg_model.py b/website/db/alg_model/alg_model.py index 764540a..b06d471 100644 --- a/website/db/alg_model/alg_model.py +++ b/website/db/alg_model/alg_model.py @@ -25,41 +25,44 @@ CREATE TABLE `model` ( ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='模型表'; """ + class Model(Base): - __tablename__ = 'model' + __tablename__ = "model" id = Column(Integer, primary_key=True, autoincrement=True) - suid = Column(String(10), comment='short uuid') - name = Column(String(255), nullable=False, default='') - model_type = Column(Integer, default=1002, comment='模型类型,1001/经典算法,1002/深度学习') - classification = Column(Integer, default=0, comment='模型分类的id') - comment = Column(String(255), default='', comment='备注') - default_version = Column(String(100), default='') - delete = Column(Integer, default=0, info={'alias': 'del'}, comment='删除状态,1/删除,0/正常') + suid = Column(String(10), comment="short uuid") + name = Column(String(255), nullable=False, default="") + model_type = Column(Integer, default=1002, comment="模型类型,1001/经典算法,1002/深度学习") + classification = Column(Integer, default=0, comment="模型分类的id") + comment = Column(String(255), default="", comment="备注") + default_version = Column(String(100), default="") + delete = Column("del", Integer, default=0, comment="删除状态,1/删除,0/正常") create_time = Column(DateTime, default=func.now()) update_time = Column(DateTime, onupdate=func.now()) def __repr__(self): return f"Model(id={self.id}, name='{self.name}', model_type={self.model_type})" - + class ModelRepositry(object): def get_suid(self, model_id: int) -> str: - session = get_session() with get_session() as session: model = session.query(Model).filter(Model.id == model_id).first() if not model or not model.suid: return "" - + return model.suid - - def get_model_by_id(self, model_id: int) -> Model | None: - - session = get_session() + + def get_model_by_id(self, model_id: int) -> Union[Model, None]: + with get_session() as session: model = session.query(Model).filter(Model.id == model_id).first() if not model: return None - - return model \ No newline at end of file + + return model + + def get_model_count(self) -> int: + with get_session() as session: + return session.query(Model).count() diff --git a/website/db/enterprise/enterprise.py b/website/db/enterprise/enterprise.py index 7088ee6..b1791ab 100644 --- a/website/db/enterprise/enterprise.py +++ b/website/db/enterprise/enterprise.py @@ -19,7 +19,7 @@ def get_enterprise_device_count(id: int) -> int: # 获取所有企业实体数量 def get_enterprise_entity_count(engine: Any) -> int: with engine.connect() as conn: - count_sql_text = "select count(id) from enterprise " + count_sql_text = "select count(*) from enterprise " count = conn.execute(text(count_sql_text)).fetchone() if count: return count[0] diff --git a/website/db/enterprise_device/enterprise_device.py b/website/db/enterprise_device/enterprise_device.py index 4a1097f..18c9ff3 100644 --- a/website/db/enterprise_device/enterprise_device.py +++ b/website/db/enterprise_device/enterprise_device.py @@ -283,3 +283,14 @@ class EnterpriseDeviceRepository(object): } device_dicts.append(device_dict) return device_dicts + + def get_all_device_count(self) -> int: + """获取所有设备的数量""" + with get_session() as session: + try: + count = session.query(EnterpriseDevice).filter(EnterpriseDevice.delete != 1).count() + return count + except Exception as e: + logging.error("Failed to get all device count, error: {}".format(e)) + + return 0 \ No newline at end of file diff --git a/website/db/enterprise_node/enterprise_node_alert.py b/website/db/enterprise_node/enterprise_node_alert.py index c25cec6..8f56278 100644 --- a/website/db/enterprise_node/enterprise_node_alert.py +++ b/website/db/enterprise_node/enterprise_node_alert.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import List, Dict, Any, Union +from typing import List, Dict, Union from sqlalchemy import text @@ -17,9 +17,10 @@ class EnterpriseNodeAlertRepository(object): cursor = session.execute( sql, {"enterprise_node_id": self.enterprise_node_id} ) + return to_json_list(cursor) def get_one( - self, enterprise_suid: str, enterprise_node_id: int + self, enterprise_suid: str, enterprise_node_id: int ) -> Union[Dict, None]: with get_session() as session: sql = text( diff --git a/website/handlers/enterprise_device/handler.py b/website/handlers/enterprise_device/handler.py index cff1de4..3e7e38f 100644 --- a/website/handlers/enterprise_device/handler.py +++ b/website/handlers/enterprise_device/handler.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- import logging + from sqlalchemy import text from website import db_mysql, errors +from website.db.enterprise_device import enterprise_device as DB_Device from website.handler import APIHandler, authenticated from website.util import shortuuid -from website.db.enterprise_device import enterprise_device as DB_Device class DeviceClassificationAddHandler(APIHandler): @@ -36,14 +37,15 @@ class DeviceClassificationAddHandler(APIHandler): if row: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "设备分类已存在") - conn.execute( + result = conn.execute( text( "INSERT INTO device_classification (name, suid) VALUES (:name, :suid)" ), {"name": name, "suid": shortuuid.ShortUUID().random(10)}, ) conn.commit() - self.finish() + last_id = result.lastrowid + self.finish({"id": last_id}) class DeviceClassificationHandler(APIHandler): diff --git a/website/handlers/enterprise_entity/handler.py b/website/handlers/enterprise_entity/handler.py index 924c1d9..fe22916 100644 --- a/website/handlers/enterprise_entity/handler.py +++ b/website/handlers/enterprise_entity/handler.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -import logging import asyncio +import logging from sqlalchemy import text @@ -8,11 +8,15 @@ from website import consts from website import db_mysql from website import errors from website import settings -from website.handler import APIHandler, authenticated +from website.db.alg_model.alg_model import ModelRepositry from website.db.enterprise import enterprise +from website.db.enterprise_device.enterprise_device import EnterpriseDeviceRepository +from website.handler import APIHandler, authenticated from website.util import shortuuid, aes -from concurrent.futures import ThreadPoolExecutor -from functools import partial + + +# from concurrent.futures import ThreadPoolExecutor +# from functools import partial class EntityIndexHandler(APIHandler): @@ -114,11 +118,13 @@ class EntityIndexBasecountHandler(APIHandler): @authenticated def post(self): - entity = enterprise.get_enterprise_entity_count(self.app_mysql) - model = enterprise.get_enterprise_model_count() - device = enterprise.get_enterprise_device_count() + entity_count = enterprise.get_enterprise_entity_count(self.app_mysql) + model_repository = ModelRepositry() + model_count = model_repository.get_model_count() + device_repository = EnterpriseDeviceRepository() + device_count = device_repository.get_all_device_count() - self.finish({"entity": entity, "model": model, "device": device}) + self.finish({"entity": entity_count, "model": model_count, "device": device_count}) class EntityAddHandler(APIHandler): @@ -137,14 +143,14 @@ class EntityAddHandler(APIHandler): logo = self.get_escaped_argument("logo", "") if ( - not name - or not province - or not city - or not addr - or not industry - or not contact - or not phone - or not summary + not name + or not province + or not city + or not addr + or not industry + or not contact + or not phone + or not summary ): raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") @@ -209,14 +215,14 @@ class EntityEditHandler(APIHandler): account = self.get_escaped_argument("account", "") if ( - not name - or not province - or not city - or not addr - or not industry - or not contact - or not phone - or not summary + not name + or not province + or not city + or not addr + or not industry + or not contact + or not phone + or not summary ): raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") diff --git a/website/handlers/enterprise_node/handler.py b/website/handlers/enterprise_node/handler.py index 42f54b1..3470571 100644 --- a/website/handlers/enterprise_node/handler.py +++ b/website/handlers/enterprise_node/handler.py @@ -1,15 +1,16 @@ # -*- coding: utf-8 -*- import json + from website import errors -from website.handler import APIHandler, authenticated -from website.db.enterprise_entity import enterprise_entity as DB_Entity -from website.db.enterprise_node import enterprise_node as DB_Node from website.db.enterprise_busi_model import enterprise_busi_model as DB_BusiModel from website.db.enterprise_busi_model import ( enterprise_busi_model_node_device as DB_BusiModelNodeDevice, ) from website.db.enterprise_device import enterprise_device as DB_Device +from website.db.enterprise_entity import enterprise_entity as DB_Entity +from website.db.enterprise_node import enterprise_node as DB_Node from website.db.enterprise_node import enterprise_node_alert as DB_NodeAlert +from website.handler import APIHandler, authenticated from website.util import shortuuid