更新代码

main
周平 11 months ago
parent ada8cb2053
commit 5c524390b9

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from sqlalchemy.ext.declarative import declarative_base import logging
from typing import Union
from sqlalchemy import Column, Integer, String, DateTime, func from sqlalchemy import Column, Integer, String, DateTime, func
from typing import Any, Dict, List, Optional, Tuple, Union from sqlalchemy.ext.declarative import declarative_base
from website.db_mysql import get_session from website.db_mysql import get_session
from website.util import shortuuid
Base = declarative_base() Base = declarative_base()
@ -63,6 +64,29 @@ class ModelRepositry(object):
return model return model
def get_model_dict_by_id(self, model_id: int) -> dict:
with get_session() as session:
logging.info(f"model id is : {model_id}")
model = session.query(Model).filter(Model.id == model_id).first()
# if not model:
# return {}
model_dict = {
'id': model.id,
'suid': model.suid,
'name': model.name,
'model_type': model.model_type,
'classification': model.classification,
'comment': model.comment,
'default_version': model.default_version,
'delete': model.delete,
'create_time': model.create_time,
'update_time': model.update_time
}
logging.info(f"model dict is : {model_dict}")
return model_dict
def get_model_by_ids(self, model_ids: list) -> list: def get_model_by_ids(self, model_ids: list) -> list:
with get_session() as session: with get_session() as session:
models = session.query(Model).filter(Model.id.in_(model_ids)).all() models = session.query(Model).filter(Model.id.in_(model_ids)).all()
@ -71,7 +95,6 @@ class ModelRepositry(object):
return models return models
def get_model_count(self) -> int: def get_model_count(self) -> int:
with get_session() as session: with get_session() as session:
return session.query(Model).count() return session.query(Model).count()

@ -67,7 +67,10 @@ class EnterpriseBusiModel(Base):
def __repr__(self): def __repr__(self):
return f"EnterpriseBusiModel(id={self.id}, suid='{self.suid}', name='{self.name}')" return f"EnterpriseBusiModel(id={self.id}, suid='{self.suid}', name='{self.name}')"
def __init__(self, **kwargs):
valid_columns = {col.name for col in self.__table__.columns}
filtered_data = {key: value for key, value in kwargs.items() if key in valid_columns}
super().__init__(**filtered_data)
class EnterpriseBusiModelRepository(object): class EnterpriseBusiModelRepository(object):
@ -82,15 +85,19 @@ class EnterpriseBusiModelRepository(object):
data['entity_suid'] = entity_suid data['entity_suid'] = entity_suid
base_model_ids = [int(model_id) for model_id in data['basemodel_ids'].split(',')] base_model_ids = [int(model_id) for model_id in data['basemodel_ids'].split(',')]
base_model_db = DB_alg_model.ModelRepositry() logging.info("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
logging.info(base_model_ids)
base_model = [] base_model = []
for base_model_id in base_model_ids: for base_model_id in base_model_ids:
base_model_db = DB_alg_model.ModelRepositry()
# base_model_suid = base_model_db.get_suid(base_model_id) # base_model_suid = base_model_db.get_suid(base_model_id)
base_model_info = base_model_db.get_model_by_id(base_model_id) base_model_info = base_model_db.get_model_dict_by_id(base_model_id)
logging.info("#####################")
logging.info(base_model_info)
base_model.append({ base_model.append({
'id': base_model_id, 'id': base_model_id,
'suid': base_model_info.suid, 'suid': base_model_info["suid"],
'name': base_model_info.name, 'name': base_model_info["name"],
}) })
data['base_models'] = json.dumps(base_model) data['base_models'] = json.dumps(base_model)
new_data = copy.copy(data) new_data = copy.copy(data)
@ -104,13 +111,16 @@ class EnterpriseBusiModelRepository(object):
def edit_busi_model(self, data: Dict): def edit_busi_model(self, data: Dict):
base_model_ids = [int(model_id) for model_id in data['basemodel_ids'].split(',')] base_model_ids = [int(model_id) for model_id in data['basemodel_ids'].split(',')]
base_model_db = DB_alg_model.ModelRepositry()
base_model = [] base_model = []
for base_model_id in base_model_ids: for base_model_id in base_model_ids:
base_model_suid = base_model_db.get_suid(base_model_id) base_model_db = DB_alg_model.ModelRepositry()
# base_model_suid = base_model_db.get_suid(base_model_id)
base_model_info = base_model_db.get_model_dict_by_id(base_model_id)
base_model.append({ base_model.append({
'id': base_model_id, "id": base_model_id,
'suid': base_model_suid "suid": base_model_info["suid"],
"name": base_model_info["name"],
}) })
data['base_models'] = json.dumps(base_model) data['base_models'] = json.dumps(base_model)
@ -197,9 +207,10 @@ class EnterpriseBusiModelNodeRepository(object):
node_db = DB_Node.EnterpriseNodeRepository() node_db = DB_Node.EnterpriseNodeRepository()
node = node_db.get_node_by_id(node_id) node = node_db.get_node_by_id(node_id)
node_suid = node["suid"] node_suid = node["suid"]
entity_suid = node["entity_suid"]
model_node = EnterpriseBusiModelNode( model_node = EnterpriseBusiModelNode(
suid=shortuuid.ShortUUID().random(10), suid=shortuuid.ShortUUID().random(10),
entity_suid=data['entity_suid'], entity_suid=entity_suid,
busi_model_id=data['busi_model_id'], busi_model_id=data['busi_model_id'],
busi_model_suid=data['busi_model_suid'], busi_model_suid=data['busi_model_suid'],
node_id=node_id, node_id=node_id,

@ -1,11 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging import logging
from typing import List from typing import List, Union, Optional, Tuple
from sqlalchemy import Column, Integer, String, DateTime, func, text from sqlalchemy import Column, Integer, String, DateTime, func, text
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from website.db_mysql import get_session, to_json_list from website.db_mysql import get_session, to_json_list, Row
""" """
CREATE TABLE `enterprise_busi_model_node_device` ( CREATE TABLE `enterprise_busi_model_node_device` (
@ -97,7 +97,15 @@ class EnterpriseBusiModelNodeDeviceRepository(object):
) )
return to_json_list(res) return to_json_list(res)
def get_busi_model_by_device(self, device_id: int = 0, device_suid: str = "") -> list: def get_busi_model_by_device(self, device_id: int = 0, device_suid: str = "", pagination: bool = False,
page_no: int = 0,
page_size: int = 0) -> Union[list, Tuple[Optional[List[Row]], int]]:
if not device_id and not device_suid:
logging.error("get_busi_model_by_device error: device_id and device_suid is null")
return []
res = []
count = 0
with get_session() as session: with get_session() as session:
sql = """ sql = """
select d.busi_model_id, m.name, m.base_models select d.busi_model_id, m.name, m.base_models
@ -105,13 +113,34 @@ class EnterpriseBusiModelNodeDeviceRepository(object):
where d.busi_model_id=m.id""" where d.busi_model_id=m.id"""
p = {} p = {}
sql_count = """
select count(1) from enterprise_busi_model_node_device d, enterprise_busi_model m where d.busi_model_id=m.id
"""
p_count = {}
if device_id: if device_id:
sql += " and d.device_id=:device_id" sql += " and d.device_id=:device_id"
p.update({"device_id": device_id}) p.update({"device_id": device_id})
sql_count += " and d.device_id=:device_id"
p_count.update({"device_id": device_id})
if device_suid: if device_suid:
sql += " and d.device_suid=:device_suid" sql += " and d.device_suid=:device_suid"
p.update({"device_suid": device_suid}) p.update({"device_suid": device_suid})
sql_count += " and d.device_suid=:device_suid"
p_count.update({"device_suid": device_suid})
if pagination:
sql += " order by d.id desc"
if page_no > 0:
sql += " limit :pageno, :pagesize"
p.update({"pageno": (page_no - 1) * page_size, "pagesize": page_size})
count = session.execute(text(sql_count), p_count).scalar()
res = session.execute(text(sql), p) res = session.execute(text(sql), p)
return to_json_list(res) res = to_json_list(res)
return res, count

@ -245,6 +245,11 @@ class EnterpriseDeviceRepository(object):
device_dict = {} device_dict = {}
if device: if device:
device_dict = { device_dict = {
"id": device.id,
"entity_id": device.entity_id,
"entity_suid": device.entity_suid,
"node_id": device.node_id,
"node_suid": device.node_suid,
"suid": device.suid, "suid": device.suid,
"name": device.name, "name": device.name,
"addr": device.addr, "addr": device.addr,
@ -277,12 +282,17 @@ class EnterpriseDeviceRepository(object):
for device in devices: for device in devices:
device_dict = { device_dict = {
"id": device.id, "id": device.id,
"entity_id": device.entity_id,
"entity_suid": device.entity_suid,
"node_id": device.node_id,
"node_suid": device.node_suid,
"suid": device.suid, "suid": device.suid,
"name": device.name, "name": device.name,
"addr": device.addr, "addr": device.addr,
"device_model": device.device_model, "device_model": device.device_model,
"param": device.param, "param": device.param,
"comment": device.comment, "comment": device.comment,
"classification": device.classification,
} }
device_dicts.append(device_dict) device_dicts.append(device_dict)
return device_dicts return device_dicts

@ -9,7 +9,7 @@ from website import db_mysql
from website import errors from website import errors
from website import settings from website import settings
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
from website.util import md5 from website.util import md5, shortuuid
class ClassificationAddHandler(APIHandler): class ClassificationAddHandler(APIHandler):
@ -28,13 +28,19 @@ class ClassificationAddHandler(APIHandler):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text("select id from model_classification where name=:name"), {"name": name}) cur = conn.execute(
text("select id from model_classification where name=:name"),
{"name": name},
)
row = db_mysql.to_json(cur) row = db_mysql.to_json(cur)
if row: if row:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
conn.execute( conn.execute(
text("""insert into model_classification (name, create_time) values (:name, NOW())"""), {"name": name} text(
"""insert into model_classification (name, create_time) values (:name, NOW())"""
),
{"name": name},
) )
conn.commit() conn.commit()
@ -53,8 +59,10 @@ class ClassificationEditHandler(APIHandler):
if not classification_id or not name: if not classification_id or not name:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
conn.execute(text("""update model_classification set name=:name where id=:id"""), conn.execute(
{"name": name, "id": classification_id}) text("""update model_classification set name=:name where id=:id"""),
{"name": name, "id": classification_id},
)
conn.commit() conn.commit()
@ -87,7 +95,10 @@ class ClassificationDeleteHandler(APIHandler):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
conn.execute(text("""DELETE FROM model_classification WHERE id=:id"""), {"id": classification_id}) conn.execute(
text("""DELETE FROM model_classification WHERE id=:id"""),
{"id": classification_id},
)
conn.commit() conn.commit()
self.finish() self.finish()
@ -106,8 +117,10 @@ class ListHandler(APIHandler):
result = [] result = []
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
sql = "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time " \ sql = (
"from model m left join model_classification mc on m.classification=mc.id where m.del=0 " "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time "
"from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
)
param = {} param = {}
@ -132,18 +145,31 @@ class ListHandler(APIHandler):
data = [] data = []
for item in result: for item in result:
data.append({ data.append(
"id": item["id"], {
"name": item["name"], "id": item["id"],
"model_type": consts.model_type_map[item["model_type"]], "name": item["name"],
"classification_name": item["classification_name"], "model_type": consts.model_type_map[item["model_type"]],
"default_version": item["default_version"], "classification_name": item["classification_name"],
"update_time": str(item["update_time"]) "default_version": item["default_version"],
}) "update_time": str(item["update_time"]),
}
)
self.finish({"data": data, "count": count}) self.finish({"data": data, "count": count})
class ListSimpleHandler(APIHandler):
@authenticated
def post(self):
with self.app_mysql.connect() as conn:
sql = "select id, name from model where del=0"
cur = conn.execute(text(sql))
res = db_mysql.to_json_list(cur)
self.finish({"result": res})
class AddHandler(APIHandler): class AddHandler(APIHandler):
""" """
添加模型 添加模型
@ -152,7 +178,9 @@ class AddHandler(APIHandler):
@authenticated @authenticated
def post(self): def post(self):
name = self.get_escaped_argument("name", "") name = self.get_escaped_argument("name", "")
model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法1002/深度学习 model_type = self.get_int_argument(
"model_type", consts.model_type_machine
) # 1001/经典算法1002/深度学习
classification = self.get_int_argument("classification") classification = self.get_int_argument("classification")
comment = self.get_escaped_argument("comment", "") comment = self.get_escaped_argument("comment", "")
@ -167,9 +195,16 @@ class AddHandler(APIHandler):
conn.execute( conn.execute(
text( text(
"""insert into model (name, model_type, classification, comment, create_time, update_time) """insert into model (suid, name, model_type, classification, comment, create_time, update_time)
values (:name, :model_type, :classification, :comment, NOW(), NOW())"""), values (:suid, :name, :model_type, :classification, :comment, NOW(), NOW())"""
{"name": name, "model_type": model_type, "classification": classification, "comment": comment} ),
{
"suid": shortuuid.ShortUUID().random(10),
"name": name,
"model_type": model_type,
"classification": classification,
"comment": comment,
},
) )
conn.commit() conn.commit()
@ -186,7 +221,9 @@ class EditHandler(APIHandler):
def post(self): def post(self):
mid = self.get_int_argument("id") mid = self.get_int_argument("id")
name = self.get_escaped_argument("name", "") name = self.get_escaped_argument("name", "")
model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法1002/深度学习 model_type = self.get_int_argument(
"model_type", consts.model_type_machine
) # 1001/经典算法1002/深度学习
classification = self.get_int_argument("classification") classification = self.get_int_argument("classification")
comment = self.get_escaped_argument("comment", "") comment = self.get_escaped_argument("comment", "")
@ -199,9 +236,15 @@ class EditHandler(APIHandler):
"""update model """update model
set name=:name, model_type=:model_type, classification=:classification, comment=:comment, set name=:name, model_type=:model_type, classification=:classification, comment=:comment,
update_time=NOW() update_time=NOW()
where id=:id"""), where id=:id"""
{"name": name, "model_type": model_type, "classification": classification, "comment": comment, ),
"id": mid} {
"name": name,
"model_type": model_type,
"classification": classification,
"comment": comment,
"id": mid,
},
) )
conn.commit() conn.commit()
@ -222,14 +265,16 @@ class InfoHandler(APIHandler):
result = {} result = {}
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute( cur = conn.execute(
text(""" text(
select """
m.name, m.model_type, m.comment, m.update_time, select
mc.id as classification_id, mc.name as classification_name m.name, m.model_type, m.comment, m.update_time,
from model m, model_classification mc mc.id as classification_id, mc.name as classification_name
where m.id=:id and m.classification=mc.id from model m, model_classification mc
"""), where m.id=:id and m.classification=mc.id
{"id": mid} """
),
{"id": mid},
) )
result = db_mysql.to_json(cur) result = db_mysql.to_json(cur)
@ -242,7 +287,7 @@ class InfoHandler(APIHandler):
"classification_id": result["classification_id"], "classification_id": result["classification_id"],
"classification_name": result["classification_name"], "classification_name": result["classification_name"],
"comment": result["comment"], "comment": result["comment"],
"update_time": str(result["update_time"]) "update_time": str(result["update_time"]),
} }
self.finish(data) self.finish(data)
@ -306,9 +351,15 @@ class VersionAddHandler(APIHandler):
(model_id, version, comment,model_file, config_file, config_str, create_time, update_time) (model_id, version, comment,model_file, config_file, config_str, create_time, update_time)
values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())""" values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"""
), ),
{
{"model_id": mid, "version": version, "comment": comment, "model_file": model_file, "model_id": mid,
"config_file": config_file, "config_str": config_str}) "version": version,
"comment": comment,
"model_file": model_file,
"config_file": config_file,
"config_str": config_str,
},
)
conn.commit() conn.commit()
self.finish() self.finish()
@ -349,9 +400,16 @@ class VersionEditHandler(APIHandler):
text( text(
"update model_version " "update model_version "
"set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, " "set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
"config_str=:config_str, update_time=NOW() where id=:id"), "config_str=:config_str, update_time=NOW() where id=:id"
{"version": version, "comment": comment, "model_file": model_file, "config_file": config_file, ),
"config_str": config_str, "id": version_id} {
"version": version,
"comment": comment,
"model_file": model_file,
"config_file": config_file,
"config_str": config_str,
"id": version_id,
},
) )
self.finish() self.finish()
@ -407,10 +465,10 @@ class VersionListHandler(APIHandler):
order by mv.id desc limit :offset, :limit order by mv.id desc limit :offset, :limit
""" """
), ),
{"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize} {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize},
) )
result = db_mysql.to_json(cur) result = db_mysql.to_json_list(cur)
# 获取记录数量 # 获取记录数量
count = conn.execute( count = conn.execute(
@ -418,21 +476,23 @@ class VersionListHandler(APIHandler):
""" """
select count(*) select count(*)
from model_version mv from model_version mv
where mv.del=0 where mv.del=0 and model_id=:mid
""" """
), ),
{"mid": model_id} {"mid": model_id},
).scalar() ).scalar()
for item in result: for item in result:
data.append({ data.append(
"version_id": item["version_id"], # 版本id {
"version": item["version"], # 版本号 "version_id": item["version_id"], # 版本id
"path": item["filepath"], # 文件路径 "version": item["version"], # 版本号
"size": item["filesize"], "path": item["filepath"], # 文件路径
"update_time": str(item["update_time"]), "size": item["filesize"],
"is_default": item["is_default"] # 是否默认1/默认0/非默认 "update_time": str(item["update_time"]),
}) "is_default": item["is_default"], # 是否默认1/默认0/非默认
}
)
self.finish({"count": count, "data": data}) self.finish({"count": count, "data": data})
@ -475,15 +535,16 @@ class VersionInfoHandler(APIHandler):
"config_file_name": "", "config_file_name": "",
"config_file_size": 0, "config_file_size": 0,
"config_file_md5": "", "config_file_md5": "",
"config_str": "" "config_str": "",
} }
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute( cur = conn.execute(
text( text(
"""select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str """select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str
from model_version mv, model m where mv.id=:id and mv.model_id=m.id"""), from model_version mv, model m where mv.id=:id and mv.model_id=m.id"""
{"id": version_id} ),
{"id": version_id},
) )
result = db_mysql.to_json(cur) result = db_mysql.to_json(cur)
@ -502,21 +563,31 @@ class VersionInfoHandler(APIHandler):
# 获取文件信息 # 获取文件信息
if model_file: if model_file:
cur_model_file = conn.execute( cur_model_file = conn.execute(
text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": model_file} text("select filename, filesize from files where md5_str=:md5_str"),
{"md5_str": model_file},
) )
model_file_info = db_mysql.to_json(cur_model_file) model_file_info = db_mysql.to_json(cur_model_file)
response["model_file_name"] = model_file_info["filename"] response["model_file_name"] = (
response["model_file_size"] = model_file_info["filesize"] model_file_info["filename"] if model_file_info else ""
)
response["model_file_size"] = (
model_file_info["filesize"] if model_file_info else 0
)
if config_file: if config_file:
cur_config_file = conn.execute( cur_config_file = conn.execute(
text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": config_file} text("select filename, filesize from files where md5_str=:md5_str"),
{"md5_str": config_file},
) )
config_file_info = db_mysql.to_json(cur_config_file) config_file_info = db_mysql.to_json(cur_config_file)
response["config_file_name"] = config_file_info["filename"] response["config_file_name"] = (
response["config_file_size"] = config_file_info["filesize"] config_file_info["filename"] if config_file_info else ""
)
response["config_file_size"] = (
config_file_info["filesize"] if config_file_info else 0
)
response["config_str"] = result["config_str"] response["config_str"] = self.unescape_string(result["config_str"])
self.finish(response) self.finish(response)
@ -543,11 +614,13 @@ class VersionSetDefaultHandler(APIHandler):
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
conn.execute( conn.execute(
text("update model_version set is_default=0 where model_id=:model_id"), text("update model_version set is_default=0 where model_id=:model_id"),
{"model_id": model_id}) {"model_id": model_id},
)
conn.execute( conn.execute(
text("update model_version set is_default=1 where id=:id"), text("update model_version set is_default=1 where id=:id"),
{"id": version_id}) {"id": version_id},
)
conn.commit() conn.commit()
self.finish() self.finish()
@ -573,8 +646,12 @@ class VersionDeleteHandler(APIHandler):
row = {} row = {}
# 获取模型对应的model_file, config_file使用model_file, config_file删除对应的存储文件 # 获取模型对应的model_file, config_file使用model_file, config_file删除对应的存储文件
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text("select model_id, model_file, config_file from model_version where id=:id"), cur = conn.execute(
{"id": version_id}) text(
"select model_id, model_file, config_file from model_version where id=:id"
),
{"id": version_id},
)
row = db_mysql.to_json(cur) row = db_mysql.to_json(cur)
if not row: if not row:
@ -585,20 +662,29 @@ class VersionDeleteHandler(APIHandler):
# 清空模型默认版本 # 清空模型默认版本
conn.execute( conn.execute(
text("update model set default_version='' where id=:id"), {"id": model_id} text("update model set default_version='' where id=:id"),
{"id": model_id},
) )
# 删除文件 # 删除文件
try: try:
conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": model_file}) conn.execute(
conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": config_file}) text("delete from files where md5_str=:md5_str"),
{"md5_str": model_file},
)
conn.execute(
text("delete from files where md5_str=:md5_str"),
{"md5_str": config_file},
)
os.remove(settings.file_upload_dir + "model/" + model_file) os.remove(settings.file_upload_dir + "model/" + model_file)
os.remove(settings.file_upload_dir + "model/" + config_file) os.remove(settings.file_upload_dir + "model/" + config_file)
except Exception as e: except Exception as e:
logging.info(e) logging.info(e)
conn.execute(text("delete from model_version where id=:id"), {"id": version_id}) conn.execute(
text("delete from model_version where id=:id"), {"id": version_id}
)
conn.commit() conn.commit()

@ -7,6 +7,7 @@ handlers = [
("/model/classification/delete", handler.ClassificationDeleteHandler), ("/model/classification/delete", handler.ClassificationDeleteHandler),
("/model/list", handler.ListHandler), ("/model/list", handler.ListHandler),
("/model/list/simple", handler.ListSimpleHandler),
("/model/add", handler.AddHandler), ("/model/add", handler.AddHandler),
("/model/edit", handler.EditHandler), ("/model/edit", handler.EditHandler),
("/model/info", handler.InfoHandler), ("/model/info", handler.InfoHandler),

@ -11,6 +11,7 @@ from website import errors
from website import settings from website import settings
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
class ListHandler(APIHandler): class ListHandler(APIHandler):
""" """
- 描述 模型运行库列表 - 描述 模型运行库列表
@ -35,6 +36,7 @@ class ListHandler(APIHandler):
} }
``` ```
""" """
@authenticated @authenticated
def post(self): def post(self):
pageNo = self.get_int_argument("pageNo", 1) pageNo = self.get_int_argument("pageNo", 1)
@ -45,20 +47,20 @@ class ListHandler(APIHandler):
count = 0 count = 0
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
sql = "select id, name, create_time, update_time from model_hub where 1=1" sql = "select id, name, path, create_time, update_time from model_hub where 1=1"
param = {} param = {}
sql_count = "select count(id) from model where 1=1" sql_count = "select count(id) from model_hub where 1=1"
param_count = {} param_count = {}
if name: if name:
sql += "and m.name like :name" sql += " and name like :name"
param["name"] = "%{}%".format(name) param["name"] = "%{}%".format(name)
sql_count += "and m.name like :name" sql_count += " and name like :name"
param_count["name"] = "%{}%".format(name) param_count["name"] = "%{}%".format(name)
sql += " order by m.id desc limit :pageSize offset :offset" sql += " order by id desc limit :pageSize offset :offset"
param["pageSize"] = pageSize param["pageSize"] = pageSize
param["offset"] = (pageNo - 1) * pageSize param["offset"] = (pageNo - 1) * pageSize
@ -69,55 +71,63 @@ class ListHandler(APIHandler):
data = [] data = []
for item in result: for item in result:
data.append({ data.append(
"id": item["id"], {
"name": item["name"], "id": item["id"],
"create_time": item["create_time"].strftime("%Y-%m-%d %H:%M:%S"), "name": item["name"],
"update_time": item["update_time"].strftime("%Y-%m-%d %H:%M:%S") "path": item["path"],
}) "create_time": item["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
"update_time": item["update_time"].strftime("%Y-%m-%d %H:%M:%S"),
}
)
self.finish({"count": count, "data": data}) self.finish({"count": count, "data": data})
class SyncHandler(APIHandler): class SyncHandler(APIHandler):
""" """
- 描述 查询docker registry中的镜像 - 描述 查询docker registry中的镜像
- 请求方式post - 请求方式post
- 请求参数 - 请求参数
> - host, string, ip地址 > - host, string, ip地址
> - port, int, 端口 > - port, int, 端口
- 返回值 - 返回值
```
{
"data": [
"xxx", # docker registry中docker images的地址
"xxx",
...
]
}
``` ```
{
"data": [
"xxx", # docker registry中docker images的地址
"xxx",
...
]
}
```
""" """
@authenticated @authenticated
def post(self): def post(self):
host = self.get_escaped_argument("host", "") host = self.get_escaped_argument("host", "")
port = self.get_int_argument("port") port = self.get_int_argument("port")
if not host or not port: if not host or not port:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "host and port must be provided.") raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "host and port must be provided."
)
images = [] images = []
# 查询docker registry中的镜像 # 查询docker registry中的镜像
repositories = requests.get("http://{}:{}/v2/_catalog".format(host, port)).json()["repositories"] repositories = requests.get(
"http://{}:{}/v2/_catalog".format(host, port)
).json()["repositories"]
for repository in repositories: for repository in repositories:
# 查询docker registry中的镜像的tag # 查询docker registry中的镜像的tag
tags = requests.get("http://{}:{}/v2/{}/tags/list".format(host, port, repository)).json()["tags"] tags = requests.get(
"http://{}:{}/v2/{}/tags/list".format(host, port, repository)
).json()["tags"]
for tag in tags: for tag in tags:
image_name = "{}:{}/{}:{}".format(host, port, repository, tag) image_name = "{}:{}/{}:{}".format(host, port, repository, tag)
images.append(image_name) images.append(image_name)
self.finish({"data": images}) self.finish({"data": images})
class AddHandler(APIHandler): class AddHandler(APIHandler):
""" """
- 描述 新建模型运行库 - 描述 新建模型运行库
@ -130,6 +140,7 @@ class AddHandler(APIHandler):
> - comment, string, 备注 > - comment, string, 备注
- 返回值 - 返回值
""" """
@authenticated @authenticated
def post(self): def post(self):
name = self.get_escaped_argument("name", "") name = self.get_escaped_argument("name", "")
@ -138,12 +149,24 @@ class AddHandler(APIHandler):
path = self.get_escaped_argument("path", "") path = self.get_escaped_argument("path", "")
comment = self.get_escaped_argument("comment", "") comment = self.get_escaped_argument("comment", "")
if not name or not host or not port: if not name or not host or not port:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "name and host and port must be provided.") raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "name and host and port must be provided."
)
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
conn.execute(text("""insert into model_hub (name, host, port, path, comment, create_time, update_time) conn.execute(
values (:name, :host, :port, :path, :comment, NOW(), NOW())"""), text(
{"name": name, "host": host, "port": port, "path": path, "comment": comment}) """insert into model_hub (name, host, port, path, comment, create_time, update_time)
values (:name, :host, :port, :path, :comment, NOW(), NOW())"""
),
{
"name": name,
"host": host,
"port": port,
"path": path,
"comment": comment,
},
)
conn.commit() conn.commit()
@ -163,6 +186,7 @@ class EditHandler(APIHandler):
> - comment, string, 备注 > - comment, string, 备注
- 返回值 - 返回值
""" """
@authenticated @authenticated
def post(self): def post(self):
id = self.get_int_argument("id") id = self.get_int_argument("id")
@ -174,8 +198,20 @@ class EditHandler(APIHandler):
if not id or not name or not host or not port or path: if not id or not name or not host or not port or path:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "parameter error") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "parameter error")
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
conn.execute(text("""update model_hub set name=:name, host=:host, port=:port, path=:path, comment=:comment, update_time=NOW() conn.execute(
where id=:id"""), {"id": id, "name": name, "host": host, "port": port, "path": path, "comment": comment}) text(
"""update model_hub set name=:name, host=:host, port=:port, path=:path, comment=:comment, update_time=NOW()
where id=:id"""
),
{
"id": id,
"name": name,
"host": host,
"port": port,
"path": path,
"comment": comment,
},
)
conn.commit() conn.commit()
self.finish() self.finish()
@ -197,6 +233,7 @@ class InfoHandler(APIHandler):
} }
``` ```
""" """
@authenticated @authenticated
def post(self): def post(self):
hid = self.get_int_argument("id") hid = self.get_int_argument("id")
@ -205,11 +242,18 @@ class InfoHandler(APIHandler):
result = {} result = {}
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text("""select name, host, port, path, comment from model_hub where id=:id"""), {"id": hid}) cur = conn.execute(
text(
"""select name, host, port, path, comment from model_hub where id=:id"""
),
{"id": hid},
)
result = db_mysql.to_json(cur) result = db_mysql.to_json(cur)
if not result: if not result:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "model hub not found") raise errors.HTTPAPIError(
self.finish({"data": result}) errors.ERROR_BAD_REQUEST, "model hub not found"
)
self.finish(result)
class DeleteHandler(APIHandler): class DeleteHandler(APIHandler):
@ -220,6 +264,7 @@ class DeleteHandler(APIHandler):
> - id, int > - id, int
- 返回值 - 返回值
""" """
@authenticated @authenticated
def post(self): def post(self):
hid = self.get_int_argument("id") hid = self.get_int_argument("id")

@ -35,14 +35,14 @@ class ListHandler(APIHandler):
""" """
@authenticated @authenticated
def post(self): def post(self):
pageNo = self.get_argument("pageNo", 1) pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_argument("pageSize", 10) pageSize = self.get_int_argument("pageSize", 10)
entity_id = self.get_argument("entity_id") entity_id = self.get_int_argument("entity_id")
if not entity_id: if not entity_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数错误") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数错误")
db_busimodel = DB_BusiModel.EnterpriseBusiModelRepository() db_busimodel = DB_BusiModel.EnterpriseBusiModelRepository()
result = db_busimodel.list_enterprise_busi_model(entity_id, pageNo, pageSize) result = db_busimodel.list_enterprise_busi_model(entity_id=entity_id, page_no=pageNo, page_size=pageSize)
self.finish({"count": result["count"], "data": result["data"]}) self.finish({"count": result["count"], "data": result["data"]})
@ -63,7 +63,7 @@ class AddHandler(APIHandler):
""" """
@authenticated @authenticated
def post(self): def post(self):
entity_id = self.get_argument("entity_id") entity_id = self.get_int_argument("entity_id")
name = self.get_escaped_argument("name", "") name = self.get_escaped_argument("name", "")
comment = self.get_escaped_argument("comment", "") comment = self.get_escaped_argument("comment", "")
basemodel_ids = self.get_escaped_argument("basemodel_ids", "") basemodel_ids = self.get_escaped_argument("basemodel_ids", "")
@ -131,7 +131,7 @@ class InfoHandler(APIHandler):
""" """
@authenticated @authenticated
def post(self): def post(self):
busimodel_id = self.get_argument("id") busimodel_id = self.get_int_argument("id")
if not busimodel_id: if not busimodel_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数错误") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数错误")

@ -2,17 +2,21 @@
import json import json
import logging import logging
import random import random
from sqlalchemy import text from sqlalchemy import text
from website import consts
from website import db_mysql, errors from website import db_mysql, errors
from website.db.enterprise_device import enterprise_device as DB_Device from website.db.alg_model.alg_model import ModelRepositry as DB_AlgModel
from website.db.device_classification import ( from website.db.device_classification import (
device_classification as DB_DeviceClassification, device_classification as DB_DeviceClassification,
) )
from website.db.enterprise_busi_model import enterprise_busi_model_node_device as DB_BusiModelNodeDevice from website.db.enterprise_busi_model import (
enterprise_busi_model as DB_BusiModel,
)
from website.db.enterprise_device import enterprise_device as DB_Device
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
from website.util import shortuuid from website.util import shortuuid
from website import consts
class DeviceClassificationAddHandler(APIHandler): class DeviceClassificationAddHandler(APIHandler):
@ -228,8 +232,6 @@ class DeviceEditHandler(APIHandler):
class DeviceDeleteHandler(APIHandler): class DeviceDeleteHandler(APIHandler):
""" """
### /enterprise/entity/nodes/device/delete
- 描述企业节点删除设备 - 描述企业节点删除设备
- 请求方式post - 请求方式post
- 请求参数 - 请求参数
@ -338,6 +340,115 @@ class DeviceInfoHandler(APIHandler):
self.finish(device) self.finish(device)
class DeviceBasemodelListHandler(APIHandler):
"""
- 描述企业节点节点信息 -> 设备列表 -> 基础模型配置 -> 模型列表
- 请求方式post
- 请求参数
> - pageNo
> - pageSize
> - device_id, int, 设备id
- 返回值
```
{
"count": 123,
"data": [
{
"busi_model": "xxx", # 业务模型的名称
"base_model": [
{
"model_id": 123, # 基础模型id
"model_name": "xxx" # 基础模型name
"model_version": "xxx", # 基础模型的版本
"model_hub_image": "xxx", # 运行库镜像
},
...
]
},
...
]
}
```
"""
@authenticated
def post(self):
device_id = self.get_int_argument("device_id")
if not device_id:
raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "企业节点或设备不能为空"
)
pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_int_argument("pageSize", 10)
db_busi_model = DB_BusiModelNodeDevice.EnterpriseBusiModelNodeDeviceRepository()
busi_models, count = db_busi_model.get_busi_model_by_device(device_id=device_id, pagination=True,
page_no=pageNo,
page_size=pageSize)
for item in busi_models:
busi_model_id = item["busi_model_id"]
busi_model_name = item["name"]
base_model_list = json.loads(item["base_models"])
for base_model in base_model_list:
base_model_id = base_model["id"]
base_model_suid = base_model["suid"]
base_model_name = base_model["name"]
self.finish()
class DeviceBaseModelCustomConfigHandler(APIHandler):
"""
- 描述企业节点节点信息 -> 设备列表 -> 基础模型配置 -> 基础模型参数配置
- 请求方式post
- 请求参数
> - device_id, int, 设备id
> - node_id, int, 节点id
> - base_model_id, int, 基础模型id
> - busi_conf_file, string, 业务参数配置文件md5
> - busi_conf_str, string, 业务参数配置json字符串eg:'{"key1": "value1", "key2": "value2"}'
> - model_hub_image, string, 运行库地址
> - model_conf_file, string, 模型参数配置文件md5
> - model_conf_str, string, 模型参数配置json字符串eg:'{"key1": "value1", "key2": "value2"}'
- 返回值
"""
@authenticated
def post(self):
device_id = self.get_int_argument("device_id")
node_id = self.get_int_argument("node_id")
busi_model_id = self.get_int_argument("busi_model_id")
base_model_id = self.get_int_argument("base_model_id")
busi_conf_file = self.get_escaped_argument("busi_conf_file")
busi_conf_str = self.get_escaped_argument("busi_conf_str")
model_hub_image = self.get_escaped_argument("model_hub_image")
model_conf_file = self.get_escaped_argument("model_conf_file")
model_conf_str = self.get_escaped_argument("model_conf_str")
db_device = DB_Device.EnterpriseDeviceRepository()
device = db_device.get_device(device_id=device_id)
if not device:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "device not exist")
if device["node_id"] != node_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "device not belong to this node")
db_alg_model = DB_AlgModel()
base_model = db_alg_model.get_model_by_id(base_model_id)
if not base_model:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "base model not exist")
device_suid = device["suid"]
node_suid = device["node_suid"]
entity_id = device["entity_id"]
entity_suid = device["entity_suid"]
db_busi_model = DB_BusiModel.EnterpriseBusiModelRepository()
busi_model = db_busi_model.get_busi_model_by_id(busi_model_id)
busi_model_suid = busi_model.suid
self.finish()
class StatusListHandler(APIHandler): class StatusListHandler(APIHandler):
""" """
- 描述设备状态列表 - 描述设备状态列表
@ -379,26 +490,32 @@ class StatusListHandler(APIHandler):
if not entity_id: if not entity_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "企业节点不能为空") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "企业节点不能为空")
if status not in consts.device_status_map: if status not in consts.device_status_map:
raise errors.HTTPAPIError( raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "状态参数错误")
errors.ERROR_BAD_REQUEST, "状态参数错误")
db_device = DB_Device.EnterpriseDeviceRepository() db_device = DB_Device.EnterpriseDeviceRepository()
res = db_device.list_entity_devices(entity_id=entity_id, pageNo=pageNo, pageSize=pageSize, res = db_device.list_entity_devices(
classification=classification, status=status) entity_id=entity_id,
pageNo=pageNo,
pageSize=pageSize,
classification=classification,
status=status,
)
count = res["count"] count = res["count"]
devices = res["devices"] devices = res["devices"]
data = [] data = []
for item in devices: for item in devices:
data.append({ data.append(
"id": item.id, {
"name": item.name, "id": item.id,
"status": item.status, "name": item.name,
"cpu": random.randint(20, 30), "status": item.status,
"mem": random.randint(20, 30), "cpu": random.randint(20, 30),
"storage": random.randint(20, 30), "mem": random.randint(20, 30),
"gpu": random.randint(20, 30), "storage": random.randint(20, 30),
}) "gpu": random.randint(20, 30),
}
)
self.finish({"count": count, "data": data}) self.finish({"count": count, "data": data})
@ -418,11 +535,10 @@ class StatusInfoHandler(APIHandler):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "设备不存在") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "设备不存在")
res = res[0] res = res[0]
db_busi_model = DB_BusiModelNodeDevice.EnterpriseBusiModelNodeDeviceRepository() db_busi_model = DB_BusiModelNodeDevice.EnterpriseBusiModelNodeDeviceRepository()
busi_models = db_busi_model.get_busi_model_by_device(device_id=device_id) busi_models, _ = db_busi_model.get_busi_model_by_device(device_id=device_id)
for busi_model in busi_models: for busi_model in busi_models:
base_model_ids = json.loads(busi_model["base_models"]) base_model_ids = json.loads(busi_model["base_models"])
self.finish() self.finish()
@ -431,5 +547,4 @@ class StatusLogHandler(APIHandler):
@authenticated @authenticated
def post(self): def post(self):
self.finish() self.finish()

@ -15,7 +15,11 @@ handlers = [
("/enterprise/entity/nodes/device/list", handler.DeviceListHandler), ("/enterprise/entity/nodes/device/list", handler.DeviceListHandler),
("/enterprise/entity/nodes/device/list/simple", handler.DeviceListSimpleHandler), ("/enterprise/entity/nodes/device/list/simple", handler.DeviceListSimpleHandler),
("/enterprise/entity/nodes/device/info", handler.DeviceInfoHandler), ("/enterprise/entity/nodes/device/info", handler.DeviceInfoHandler),
("/enterprise/entity/nodes/device/basemodel/list", handler.DeviceBasemodelListHandler),
(
"/enterprise/entity/nodes/device/basemodel/custom/config",
handler.DeviceBaseModelCustomConfigHandler,
),
("/enterprise/device/status/list", handler.StatusListHandler), ("/enterprise/device/status/list", handler.StatusListHandler),
("/enterprise/device/status/info", handler.StatusInfoHandler), ("/enterprise/device/status/info", handler.StatusInfoHandler),
("/enterprise/device/status/log", handler.StatusLogHandler), ("/enterprise/device/status/log", handler.StatusLogHandler),

Loading…
Cancel
Save