diff --git a/website/db/enterprise/enterprise.py b/website/db/enterprise/enterprise.py index b1791ab..823498d 100644 --- a/website/db/enterprise/enterprise.py +++ b/website/db/enterprise/enterprise.py @@ -1,7 +1,7 @@ from website.handler import BaseHandler from sqlalchemy import text from typing import Any -from website.db_mysql import get_session, get_async_session, to_json +from website.db_mysql import get_session, get_async_session, to_json, to_json_list import json @@ -103,11 +103,10 @@ async def get_enterprise_model_and_device_count( # if res_device: # device_count = res_device["device_count"] - sql = """ - SELECT - (SELECT base_models FROM enterprise_busi_model WHERE {where_clause}) as base_models, - (SELECT COUNT(*) FROM enterprise_device WHERE {where_clause}) AS device_count - """ + + + sql_device_count = "SELECT COUNT(*) AS device_count FROM enterprise_device WHERE {where_clause} " + sql_base_model = "SELECT base_models FROM enterprise_busi_model WHERE {where_clause} " where_clause = "" params = {} @@ -118,12 +117,20 @@ async def get_enterprise_model_and_device_count( where_clause = "entity_suid = :entity_suid" params["entity_suid"] = entity_suid - sql = sql.format(where_clause=where_clause) - result = await session.execute(text(sql), params) - # row = result.fetchone() - row = to_json(result) - base_models, device_count = row["base_models"], row["device_count"] - if base_models: - model_count = len(json.loads(base_models)) + sql_device_count = sql_device_count.format(where_clause=where_clause) + result_device_count = await session.execute(text(sql_device_count), params) + device_count = to_json(result_device_count)["device_count"] + + sql_base_model = sql_base_model.format(where_clause=where_clause) + result_base_model = await session.execute(text(sql_base_model), params) + base_models = to_json_list(result_base_model) + base_model_ids = [] + for item in base_models: + base_models = json.loads(item["base_models"]) + for base_model in base_models: + if base_model["id"] not in base_model_ids: + base_model_ids.append(base_model["id"]) + + model_count = len(base_model_ids) return model_count, device_count diff --git a/website/db/enterprise_busi_model/enterprise_busi_model.py b/website/db/enterprise_busi_model/enterprise_busi_model.py index a8b0811..0f569ae 100644 --- a/website/db/enterprise_busi_model/enterprise_busi_model.py +++ b/website/db/enterprise_busi_model/enterprise_busi_model.py @@ -204,15 +204,16 @@ class EnterpriseBusiModelNodeRepository(object): link_node_ids = [int(node_id) for node_id in data['node_ids'].split(',')] with get_session() as session: for node_id in link_node_ids: + logging.info("node_id: %s") node_db = DB_Node.EnterpriseNodeRepository() - node = node_db.get_node_by_id(node_id) + node = node_db.get_entity_suid_by_node_id(node_id) node_suid = node["suid"] entity_suid = node["entity_suid"] model_node = EnterpriseBusiModelNode( suid=shortuuid.ShortUUID().random(10), entity_suid=entity_suid, - busi_model_id=data['busi_model_id'], - busi_model_suid=data['busi_model_suid'], + busi_model_id=data['busimodel_id'], + busi_model_suid=data['busimodel_suid'], node_id=node_id, node_suid=node_suid, ) diff --git a/website/db/enterprise_node/enterprise_node.py b/website/db/enterprise_node/enterprise_node.py index b3325db..fc350ae 100644 --- a/website/db/enterprise_node/enterprise_node.py +++ b/website/db/enterprise_node/enterprise_node.py @@ -176,7 +176,7 @@ class EnterpriseNodeRepository(object): def get_entity_suid_by_node_id(self, node_id: int) -> Union[dict, None]: with get_session() as session: res = session.execute( - text("select entity_suid from enterprise_node where id=:id"), + text("select suid, entity_suid from enterprise_node where id=:id"), {"id": node_id}, ) entity = to_json(res)