General-AI-Platform-Backend/website/db/enterprise_busi_model/enterprise_busi_model.py

298 lines
12 KiB
Python

# -*- coding: utf-8 -*-
import copy
import json
import logging
from typing import Any, Dict, List, Optional
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import text
from sqlalchemy.ext.declarative import declarative_base
from website.db.alg_model import alg_model as DB_alg_model
from website.db.enterprise_entity.enterprise_entity import EnterpriseEntityRepository
from website.db.enterprise_node import enterprise_node as DB_Node
from website.db_mysql import get_session, to_json_list, Row, dict_to_obj
from website.util import shortuuid
# Declarative base shared by the ORM classes defined in this module.
# NOTE: sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 in favour of sqlalchemy.orm.declarative_base.
Base = declarative_base()
# Reference DDL for the two tables mapped below. The string is not referenced
# by any code in this module; it is kept purely as schema documentation.
# NOTE(review): the DDL names the linked-model column `basemodel_ids`, while
# EnterpriseBusiModel maps it as `base_models` -- confirm which one matches the
# live schema.
# NOTE(review): both tables declare create_time with
# ON UPDATE CURRENT_TIMESTAMP, so the "creation" time is rewritten on every
# update -- confirm whether that is intended.
"""
CREATE TABLE `enterprise_busi_model` (
`id` int NOT NULL AUTO_INCREMENT,
`suid` varchar(10) NOT NULL DEFAULT '' COMMENT 'short uuid',
`entity_id` int NOT NULL COMMENT '企业id',
`entity_suid` varchar(10) NOT NULL COMMENT '企业uuid',
`name` varchar(255) NOT NULL,
`comment` varchar(255) DEFAULT '',
`basemodel_ids` varchar(255) NOT NULL COMMENT '关联模型json list, [{"id":123,"suid":"xxx"},...]',
`business_logic` varchar(32) DEFAULT NULL COMMENT '业务代码压缩包的md5',
`business_conf_file` varchar(32) DEFAULT NULL COMMENT '业务配置参数压缩包的文件md5',
`business_conf_param` varchar(255) DEFAULT NULL COMMENT '业务配置的参数json字符串eg: ''{"a":1, "b":2}''',
`delete` tinyint(1) DEFAULT '0',
`create_time` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`update_time` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='企业业务模型表';
CREATE TABLE `enterprise_busi_model_node` (
`id` int NOT NULL AUTO_INCREMENT,
`suid` varchar(10) NOT NULL,
`entity_suid` varchar(10) DEFAULT NULL COMMENT '企业suid',
`busi_model_id` int DEFAULT NULL,
`busi_model_suid` varchar(10) DEFAULT NULL,
`node_id` int DEFAULT NULL,
`node_suid` varchar(10) DEFAULT NULL,
`create_time` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
"""
class EnterpriseBusiModel(Base):
    """ORM mapping for ``enterprise_busi_model``: one business model of an enterprise."""

    __tablename__ = 'enterprise_busi_model'

    id = Column(Integer, autoincrement=True, primary_key=True)
    suid = Column(String(10), default='', nullable=False, comment='short uuid')
    # enterprise id / enterprise suid
    entity_id = Column(Integer, comment='企业id', nullable=False)
    entity_suid = Column(String(10), comment='企业uuid', nullable=False)
    name = Column(String(255), nullable=False)
    comment = Column(String(255), default='', nullable=True)
    # JSON list of linked base models, e.g. [{"id":123,"suid":"xxx"},...]
    base_models = Column(String(255), comment='关联的基础模型json list, [{"id":123,"suid":"xxx"},...]', nullable=False)
    # md5 of the business-logic code archive
    business_logic = Column(String(32), comment='业务代码压缩包的md5', nullable=True)
    # md5 of the business configuration archive
    business_conf_file = Column(String(32), comment='业务配置参数压缩包的文件md5', nullable=True)
    # business configuration parameters as a JSON string, e.g. '{"a":1, "b":2}'
    business_conf_param = Column(String(255), comment='业务配置的参数json字符串eg: ''{"a":1, "b":2}''', nullable=True)
    # soft-delete flag: 0 = active, 1 = deleted (see EnterpriseBusiModelRepository.delete_by_id)
    delete = Column(Integer, default=0)
    # NOTE(review): create_time also carries onupdate, so it is refreshed on
    # every ORM update just like update_time -- confirm intent.
    create_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
    update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)

    def __init__(self, **kwargs):
        # Drop any keyword that is not a table column, so request payloads can
        # be passed straight through without tripping the declarative constructor.
        column_names = set(column.name for column in self.__table__.columns)
        super().__init__(**{key: kwargs[key] for key in kwargs if key in column_names})

    def __repr__(self):
        return f"EnterpriseBusiModel(id={self.id}, suid='{self.suid}', name='{self.name}')"
class EnterpriseBusiModelRepository(object):
    """Data-access helpers for the ``enterprise_busi_model`` table."""

    @staticmethod
    def _parse_id_list(raw_ids: str) -> List[int]:
        """Parse a comma-separated id string (e.g. ``"1,2,3"``) into ints.

        Blank segments (an empty string, a trailing comma) are skipped instead
        of failing on ``int('')``.
        """
        return [int(part) for part in raw_ids.split(',') if part.strip()]

    @staticmethod
    def _format_time(value) -> Optional[str]:
        """Render a datetime as ``%Y-%m-%d %H:%M:%S``; ``None`` stays ``None``."""
        return value.strftime("%Y-%m-%d %H:%M:%S") if value is not None else None

    def _build_base_models(self, basemodel_ids: str) -> str:
        """Resolve base-model ids into the JSON string stored in ``base_models``.

        Args:
            basemodel_ids: comma-separated base-model ids, e.g. ``"1,2"``.

        Returns:
            str: JSON list of ``{"id": ..., "suid": ..., "name": ...}`` entries.
        """
        model_repo = DB_alg_model.ModelRepositry()  # one repository for all lookups
        base_models = []
        for base_model_id in self._parse_id_list(basemodel_ids):
            # NOTE(review): assumes the lookup returns a dict with "suid" and
            # "name"; an unknown id currently fails here -- confirm desired handling.
            info = model_repo.get_model_dict_by_id(base_model_id)
            base_models.append({
                "id": base_model_id,
                "suid": info["suid"],
                "name": info["name"],
            })
        return json.dumps(base_models)

    def get_by_id(self, id: int) -> Optional[EnterpriseBusiModel]:
        """Return the ORM row whose primary key is ``id``, or ``None``.

        The instance is returned outside the ``get_session()`` block, so only
        attributes already loaded by the query should be relied upon.
        """
        with get_session() as session:
            return session.query(EnterpriseBusiModel).filter(EnterpriseBusiModel.id == id).first()

    def insert_busi_model(self, data: Dict):
        """Create a business model row.

        Args:
            data: payload containing at least ``entity_id`` and a comma-separated
                ``basemodel_ids`` string. Keys that are not table columns are
                ignored by ``EnterpriseBusiModel.__init__``. ``entity_suid``,
                ``suid`` and ``base_models`` are written back into ``data``
                (kept for backward compatibility with existing callers).

        Returns:
            tuple: ``(id, suid)`` of the newly inserted row.
        """
        data['entity_suid'] = EnterpriseEntityRepository().get_entity_suid(data['entity_id'])
        data['suid'] = shortuuid.ShortUUID().random(10)
        data['base_models'] = self._build_base_models(data['basemodel_ids'])
        logging.info("insert busi model, base_models: %s", data['base_models'])
        with get_session() as session:
            model = EnterpriseBusiModel(**data)
            session.add(model)
            session.commit()
            # Read generated values while the session is still open: after
            # commit the instance may be expired and need a refresh.
            return model.id, model.suid

    def edit_busi_model(self, data: Dict):
        """Update an existing business model identified by ``data['id']``.

        ``basemodel_ids`` is re-resolved into ``base_models``; other keys that
        are table columns (except the primary key) are written as-is. A failed
        UPDATE is logged and rolled back rather than raised.
        """
        data['base_models'] = self._build_base_models(data['basemodel_ids'])
        valid_columns = {col.name for col in EnterpriseBusiModel.__table__.columns}
        values = {key: value for key, value in data.items() if key in valid_columns and key != 'id'}
        with get_session() as session:
            try:
                session.query(EnterpriseBusiModel).filter(
                    EnterpriseBusiModel.id == data['id']).update(values)
            except Exception:
                # Best-effort edit: log with traceback and discard the failed
                # transaction instead of committing it.
                logging.exception("Failed to edit busi model, id: %s", data.get('id'))
                session.rollback()
                return
            session.commit()
        return

    def get_busi_model_by_id(self, id: int) -> Optional[Any]:
        """Return the row with primary key ``id`` converted via ``dict_to_obj``, or ``None``.

        Soft-deleted rows are returned as well.
        """
        with get_session() as session:
            model = session.query(EnterpriseBusiModel).filter(EnterpriseBusiModel.id == id).first()
            if model is None:
                return None
            # Copy instead of popping from model.__dict__: removing
            # _sa_instance_state from a live instance corrupts it while it is
            # still tracked by the session.
            model_dict = {key: value for key, value in model.__dict__.items() if key != '_sa_instance_state'}
            return dict_to_obj(model_dict)

    def list_enterprise_busi_model(self, entity_id: int, page_no: int, page_size: int) -> Dict[Any, Any]:
        """List the non-deleted business models of an enterprise, one page at a time.

        Args:
            entity_id (int): enterprise id
            page_no (int): 1-based page number; values below 1 are treated as 1
            page_size (int): rows per page

        Returns:
            dict: ``{"count": total, "data": [{"model_id", "model_name", "create_time"}, ...]}``
        """
        page_no = max(page_no, 1)  # avoid a negative OFFSET
        with get_session() as session:
            total_count = session.query(func.count(EnterpriseBusiModel.id)).filter(
                EnterpriseBusiModel.entity_id == entity_id).filter(EnterpriseBusiModel.delete == 0).scalar()
            models = (
                session.query(
                    EnterpriseBusiModel.id.label("model_id"),
                    EnterpriseBusiModel.name.label("model_name"),
                    EnterpriseBusiModel.create_time,
                )
                .filter(EnterpriseBusiModel.entity_id == entity_id)
                .filter(EnterpriseBusiModel.delete == 0)
                # Explicit ordering so OFFSET/LIMIT pages are stable.
                .order_by(EnterpriseBusiModel.id)
                .offset((page_no - 1) * page_size)
                .limit(page_size)
                .all()
            )
            return {
                "count": total_count,
                "data": [
                    {
                        "model_id": model.model_id,
                        "model_name": model.model_name,
                        "create_time": self._format_time(model.create_time),
                    }
                    for model in models
                ],
            }

    def delete_by_id(self, model_id: int):
        """Soft-delete a business model by setting its ``delete`` flag to 1.

        Unknown ids are ignored. Node links in ``enterprise_busi_model_node``
        are not touched here.

        Args:
            model_id (int): business model id
        """
        with get_session() as session:
            model = session.query(EnterpriseBusiModel).filter(EnterpriseBusiModel.id == model_id).first()
            if model is not None:
                model.delete = 1
                session.commit()
        return
class EnterpriseBusiModelNode(Base):
    """ORM mapping for ``enterprise_busi_model_node``: links a business model to an enterprise node."""

    __tablename__ = 'enterprise_busi_model_node'

    id = Column(Integer, primary_key=True)
    suid = Column(String(10), default='', nullable=False)
    entity_suid = Column(String(10))  # suid of the owning enterprise
    busi_model_id = Column(Integer)
    busi_model_suid = Column(String(10))
    node_id = Column(Integer)
    node_suid = Column(String(10))
    create_time = Column(DateTime, default=func.current_timestamp())

    def __repr__(self):
        row_id, row_suid = self.id, self.suid
        return f'<EnterpriseBusiModelNode(id={row_id}, suid={row_suid})>'
class EnterpriseBusiModelNodeRepository(object):
    """Data-access helpers for the ``enterprise_busi_model_node`` link table."""

    @staticmethod
    def _format_time(value) -> Optional[str]:
        """Render a datetime as ``%Y-%m-%d %H:%M:%S``; ``None`` stays ``None``."""
        return value.strftime("%Y-%m-%d %H:%M:%S") if value is not None else None

    def insert_busi_model_nodes(self, data: Dict):
        """Link one business model to every node listed in ``data['node_ids']``.

        Args:
            data: payload with ``busimodel_id``, ``busimodel_suid`` and a
                comma-separated ``node_ids`` string. Blank segments are skipped.
                All link rows are committed together in one session.
        """
        # Kept for backward compatibility: the caller's dict receives a suid,
        # although each link row below gets its own fresh suid.
        data['suid'] = shortuuid.ShortUUID().random(10)
        link_node_ids = [int(part) for part in data['node_ids'].split(',') if part.strip()]
        node_repo = DB_Node.EnterpriseNodeRepository()  # one repository for all lookups
        with get_session() as session:
            for node_id in link_node_ids:
                logging.info("node_id: %s", node_id)
                node = node_repo.get_entity_suid_by_node_id(node_id)
                node_suid = node["suid"]
                entity_suid = node["entity_suid"]
                session.add(EnterpriseBusiModelNode(
                    suid=shortuuid.ShortUUID().random(10),
                    entity_suid=entity_suid,
                    busi_model_id=data['busimodel_id'],
                    busi_model_suid=data['busimodel_suid'],
                    node_id=node_id,
                    node_suid=node_suid,
                ))
            session.commit()
        return

    def get_nodes_by_busi_model(self, busi_model_id: int, entity_suid: str) -> Optional[List[Row]]:
        """Return ``node_id`` / ``node_name`` rows linked to a business model of one enterprise."""
        sql = ("select mn.node_id, n.name node_name from enterprise_busi_model_node mn, enterprise_node n "
               "where mn.node_id=n.id "
               " and mn.busi_model_id = :busi_model_id"
               " and mn.entity_suid = :entity_suid")
        with get_session() as session:
            nodes = session.execute(text(sql), {"busi_model_id": busi_model_id, "entity_suid": entity_suid})
            return to_json_list(nodes)

    def delete_by_busi_model_id(self, busi_model_id: int) -> None:
        """Hard-delete every node link of the given business model."""
        with get_session() as session:
            session.query(EnterpriseBusiModelNode).filter(
                EnterpriseBusiModelNode.busi_model_id == busi_model_id).delete()
            session.commit()

    def get_busi_model_by_node_id(self, node_id: int, page_no: int = 1, page_size: int = 10) -> Dict[Any, Any]:
        """List the non-deleted business models linked to a node, one page at a time.

        ``count`` and ``data`` are computed over the same join and filters, so
        the total always matches the rows that pagination can return.

        Args:
            node_id: enterprise node id
            page_no: 1-based page number; values below 1 are treated as 1
            page_size: rows per page

        Returns:
            dict: ``{"count": total, "data": [{"busi_model_id", "busi_model_name", "create_time"}, ...]}``
        """
        page_no = max(page_no, 1)  # avoid a negative OFFSET
        link_condition = EnterpriseBusiModel.id == EnterpriseBusiModelNode.busi_model_id
        filters = (
            EnterpriseBusiModelNode.node_id == node_id,
            # Soft-deleted business models are hidden, consistent with
            # EnterpriseBusiModelRepository.list_enterprise_busi_model.
            EnterpriseBusiModel.delete == 0,
        )
        with get_session() as session:
            total_count = (
                session.query(func.count(EnterpriseBusiModelNode.id))
                .select_from(EnterpriseBusiModelNode)
                .join(EnterpriseBusiModel, link_condition)
                .filter(*filters)
                .scalar()
            )
            models = (
                session.query(
                    EnterpriseBusiModel.id.label("busi_model_id"),
                    EnterpriseBusiModel.name.label("busi_model_name"),
                    EnterpriseBusiModel.create_time,
                )
                .join(EnterpriseBusiModelNode, link_condition)
                .filter(*filters)
                # Explicit ordering so OFFSET/LIMIT pages are stable.
                .order_by(EnterpriseBusiModel.id)
                .offset((page_no - 1) * page_size)
                .limit(page_size)
                .all()
            )
            return {
                "count": total_count,
                "data": [
                    {
                        "busi_model_id": model.busi_model_id,
                        "busi_model_name": model.busi_model_name,
                        "create_time": self._format_time(model.create_time),
                    } for model in models
                ],
            }