更新代码

main
周平 11 months ago
parent ada8cb2053
commit 5c524390b9

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
from sqlalchemy.ext.declarative import declarative_base
import logging
from typing import Union
from sqlalchemy import Column, Integer, String, DateTime, func
from typing import Any, Dict, List, Optional, Tuple, Union
from sqlalchemy.ext.declarative import declarative_base
from website.db_mysql import get_session
from website.util import shortuuid
Base = declarative_base()
@ -63,6 +64,29 @@ class ModelRepositry(object):
return model
def get_model_dict_by_id(self, model_id: int) -> dict:
with get_session() as session:
logging.info(f"model id is : {model_id}")
model = session.query(Model).filter(Model.id == model_id).first()
# if not model:
# return {}
model_dict = {
'id': model.id,
'suid': model.suid,
'name': model.name,
'model_type': model.model_type,
'classification': model.classification,
'comment': model.comment,
'default_version': model.default_version,
'delete': model.delete,
'create_time': model.create_time,
'update_time': model.update_time
}
logging.info(f"model dict is : {model_dict}")
return model_dict
def get_model_by_ids(self, model_ids: list) -> list:
with get_session() as session:
models = session.query(Model).filter(Model.id.in_(model_ids)).all()
@ -71,7 +95,6 @@ class ModelRepositry(object):
return models
def get_model_count(self) -> int:
with get_session() as session:
return session.query(Model).count()

@ -67,7 +67,10 @@ class EnterpriseBusiModel(Base):
def __repr__(self):
return f"EnterpriseBusiModel(id={self.id}, suid='{self.suid}', name='{self.name}')"
def __init__(self, **kwargs):
valid_columns = {col.name for col in self.__table__.columns}
filtered_data = {key: value for key, value in kwargs.items() if key in valid_columns}
super().__init__(**filtered_data)
class EnterpriseBusiModelRepository(object):
@ -82,15 +85,19 @@ class EnterpriseBusiModelRepository(object):
data['entity_suid'] = entity_suid
base_model_ids = [int(model_id) for model_id in data['basemodel_ids'].split(',')]
base_model_db = DB_alg_model.ModelRepositry()
logging.info("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
logging.info(base_model_ids)
base_model = []
for base_model_id in base_model_ids:
base_model_db = DB_alg_model.ModelRepositry()
# base_model_suid = base_model_db.get_suid(base_model_id)
base_model_info = base_model_db.get_model_by_id(base_model_id)
base_model_info = base_model_db.get_model_dict_by_id(base_model_id)
logging.info("#####################")
logging.info(base_model_info)
base_model.append({
'id': base_model_id,
'suid': base_model_info.suid,
'name': base_model_info.name,
'suid': base_model_info["suid"],
'name': base_model_info["name"],
})
data['base_models'] = json.dumps(base_model)
new_data = copy.copy(data)
@ -104,13 +111,16 @@ class EnterpriseBusiModelRepository(object):
def edit_busi_model(self, data: Dict):
base_model_ids = [int(model_id) for model_id in data['basemodel_ids'].split(',')]
base_model_db = DB_alg_model.ModelRepositry()
base_model = []
for base_model_id in base_model_ids:
base_model_suid = base_model_db.get_suid(base_model_id)
base_model_db = DB_alg_model.ModelRepositry()
# base_model_suid = base_model_db.get_suid(base_model_id)
base_model_info = base_model_db.get_model_dict_by_id(base_model_id)
base_model.append({
'id': base_model_id,
'suid': base_model_suid
"id": base_model_id,
"suid": base_model_info["suid"],
"name": base_model_info["name"],
})
data['base_models'] = json.dumps(base_model)
@ -197,9 +207,10 @@ class EnterpriseBusiModelNodeRepository(object):
node_db = DB_Node.EnterpriseNodeRepository()
node = node_db.get_node_by_id(node_id)
node_suid = node["suid"]
entity_suid = node["entity_suid"]
model_node = EnterpriseBusiModelNode(
suid=shortuuid.ShortUUID().random(10),
entity_suid=data['entity_suid'],
entity_suid=entity_suid,
busi_model_id=data['busi_model_id'],
busi_model_suid=data['busi_model_suid'],
node_id=node_id,

@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
import logging
from typing import List
from typing import List, Union, Optional, Tuple
from sqlalchemy import Column, Integer, String, DateTime, func, text
from sqlalchemy.ext.declarative import declarative_base
from website.db_mysql import get_session, to_json_list
from website.db_mysql import get_session, to_json_list, Row
"""
CREATE TABLE `enterprise_busi_model_node_device` (
@ -97,7 +97,15 @@ class EnterpriseBusiModelNodeDeviceRepository(object):
)
return to_json_list(res)
def get_busi_model_by_device(self, device_id: int = 0, device_suid: str = "") -> list:
def get_busi_model_by_device(self, device_id: int = 0, device_suid: str = "", pagination: bool = False,
page_no: int = 0,
page_size: int = 0) -> Union[list, Tuple[Optional[List[Row]], int]]:
if not device_id and not device_suid:
logging.error("get_busi_model_by_device error: device_id and device_suid is null")
return []
res = []
count = 0
with get_session() as session:
sql = """
select d.busi_model_id, m.name, m.base_models
@ -105,13 +113,34 @@ class EnterpriseBusiModelNodeDeviceRepository(object):
where d.busi_model_id=m.id"""
p = {}
sql_count = """
select count(1) from enterprise_busi_model_node_device d, enterprise_busi_model m where d.busi_model_id=m.id
"""
p_count = {}
if device_id:
sql += " and d.device_id=:device_id"
p.update({"device_id": device_id})
sql_count += " and d.device_id=:device_id"
p_count.update({"device_id": device_id})
if device_suid:
sql += " and d.device_suid=:device_suid"
p.update({"device_suid": device_suid})
sql_count += " and d.device_suid=:device_suid"
p_count.update({"device_suid": device_suid})
if pagination:
sql += " order by d.id desc"
if page_no > 0:
sql += " limit :pageno, :pagesize"
p.update({"pageno": (page_no - 1) * page_size, "pagesize": page_size})
count = session.execute(text(sql_count), p_count).scalar()
res = session.execute(text(sql), p)
return to_json_list(res)
res = to_json_list(res)
return res, count

@ -245,6 +245,11 @@ class EnterpriseDeviceRepository(object):
device_dict = {}
if device:
device_dict = {
"id": device.id,
"entity_id": device.entity_id,
"entity_suid": device.entity_suid,
"node_id": device.node_id,
"node_suid": device.node_suid,
"suid": device.suid,
"name": device.name,
"addr": device.addr,
@ -277,12 +282,17 @@ class EnterpriseDeviceRepository(object):
for device in devices:
device_dict = {
"id": device.id,
"entity_id": device.entity_id,
"entity_suid": device.entity_suid,
"node_id": device.node_id,
"node_suid": device.node_suid,
"suid": device.suid,
"name": device.name,
"addr": device.addr,
"device_model": device.device_model,
"param": device.param,
"comment": device.comment,
"classification": device.classification,
}
device_dicts.append(device_dict)
return device_dicts

@ -9,13 +9,13 @@ from website import db_mysql
from website import errors
from website import settings
from website.handler import APIHandler, authenticated
from website.util import md5
from website.util import md5, shortuuid
class ClassificationAddHandler(APIHandler):
"""
添加模型分类
"""
@authenticated
@ -28,13 +28,19 @@ class ClassificationAddHandler(APIHandler):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")
with self.app_mysql.connect() as conn:
cur = conn.execute(text("select id from model_classification where name=:name"), {"name": name})
cur = conn.execute(
text("select id from model_classification where name=:name"),
{"name": name},
)
row = db_mysql.to_json(cur)
if row:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
conn.execute(
text("""insert into model_classification (name, create_time) values (:name, NOW())"""), {"name": name}
text(
"""insert into model_classification (name, create_time) values (:name, NOW())"""
),
{"name": name},
)
conn.commit()
@ -53,8 +59,10 @@ class ClassificationEditHandler(APIHandler):
if not classification_id or not name:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn:
conn.execute(text("""update model_classification set name=:name where id=:id"""),
{"name": name, "id": classification_id})
conn.execute(
text("""update model_classification set name=:name where id=:id"""),
{"name": name, "id": classification_id},
)
conn.commit()
@ -87,7 +95,10 @@ class ClassificationDeleteHandler(APIHandler):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn:
conn.execute(text("""DELETE FROM model_classification WHERE id=:id"""), {"id": classification_id})
conn.execute(
text("""DELETE FROM model_classification WHERE id=:id"""),
{"id": classification_id},
)
conn.commit()
self.finish()
@ -106,8 +117,10 @@ class ListHandler(APIHandler):
result = []
with self.app_mysql.connect() as conn:
sql = "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time " \
"from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
sql = (
"select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time "
"from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
)
param = {}
@ -132,18 +145,31 @@ class ListHandler(APIHandler):
data = []
for item in result:
data.append({
"id": item["id"],
"name": item["name"],
"model_type": consts.model_type_map[item["model_type"]],
"classification_name": item["classification_name"],
"default_version": item["default_version"],
"update_time": str(item["update_time"])
})
data.append(
{
"id": item["id"],
"name": item["name"],
"model_type": consts.model_type_map[item["model_type"]],
"classification_name": item["classification_name"],
"default_version": item["default_version"],
"update_time": str(item["update_time"]),
}
)
self.finish({"data": data, "count": count})
class ListSimpleHandler(APIHandler):
@authenticated
def post(self):
with self.app_mysql.connect() as conn:
sql = "select id, name from model where del=0"
cur = conn.execute(text(sql))
res = db_mysql.to_json_list(cur)
self.finish({"result": res})
class AddHandler(APIHandler):
"""
添加模型
@ -152,7 +178,9 @@ class AddHandler(APIHandler):
@authenticated
def post(self):
name = self.get_escaped_argument("name", "")
model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法1002/深度学习
model_type = self.get_int_argument(
"model_type", consts.model_type_machine
) # 1001/经典算法1002/深度学习
classification = self.get_int_argument("classification")
comment = self.get_escaped_argument("comment", "")
@ -167,9 +195,16 @@ class AddHandler(APIHandler):
conn.execute(
text(
"""insert into model (name, model_type, classification, comment, create_time, update_time)
values (:name, :model_type, :classification, :comment, NOW(), NOW())"""),
{"name": name, "model_type": model_type, "classification": classification, "comment": comment}
"""insert into model (suid, name, model_type, classification, comment, create_time, update_time)
values (:suid, :name, :model_type, :classification, :comment, NOW(), NOW())"""
),
{
"suid": shortuuid.ShortUUID().random(10),
"name": name,
"model_type": model_type,
"classification": classification,
"comment": comment,
},
)
conn.commit()
@ -186,7 +221,9 @@ class EditHandler(APIHandler):
def post(self):
mid = self.get_int_argument("id")
name = self.get_escaped_argument("name", "")
model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法1002/深度学习
model_type = self.get_int_argument(
"model_type", consts.model_type_machine
) # 1001/经典算法1002/深度学习
classification = self.get_int_argument("classification")
comment = self.get_escaped_argument("comment", "")
@ -199,9 +236,15 @@ class EditHandler(APIHandler):
"""update model
set name=:name, model_type=:model_type, classification=:classification, comment=:comment,
update_time=NOW()
where id=:id"""),
{"name": name, "model_type": model_type, "classification": classification, "comment": comment,
"id": mid}
where id=:id"""
),
{
"name": name,
"model_type": model_type,
"classification": classification,
"comment": comment,
"id": mid,
},
)
conn.commit()
@ -222,14 +265,16 @@ class InfoHandler(APIHandler):
result = {}
with self.app_mysql.connect() as conn:
cur = conn.execute(
text("""
select
m.name, m.model_type, m.comment, m.update_time,
mc.id as classification_id, mc.name as classification_name
from model m, model_classification mc
where m.id=:id and m.classification=mc.id
"""),
{"id": mid}
text(
"""
select
m.name, m.model_type, m.comment, m.update_time,
mc.id as classification_id, mc.name as classification_name
from model m, model_classification mc
where m.id=:id and m.classification=mc.id
"""
),
{"id": mid},
)
result = db_mysql.to_json(cur)
@ -242,7 +287,7 @@ class InfoHandler(APIHandler):
"classification_id": result["classification_id"],
"classification_name": result["classification_name"],
"comment": result["comment"],
"update_time": str(result["update_time"])
"update_time": str(result["update_time"]),
}
self.finish(data)
@ -306,9 +351,15 @@ class VersionAddHandler(APIHandler):
(model_id, version, comment,model_file, config_file, config_str, create_time, update_time)
values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"""
),
{"model_id": mid, "version": version, "comment": comment, "model_file": model_file,
"config_file": config_file, "config_str": config_str})
{
"model_id": mid,
"version": version,
"comment": comment,
"model_file": model_file,
"config_file": config_file,
"config_str": config_str,
},
)
conn.commit()
self.finish()
@ -349,9 +400,16 @@ class VersionEditHandler(APIHandler):
text(
"update model_version "
"set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
"config_str=:config_str, update_time=NOW() where id=:id"),
{"version": version, "comment": comment, "model_file": model_file, "config_file": config_file,
"config_str": config_str, "id": version_id}
"config_str=:config_str, update_time=NOW() where id=:id"
),
{
"version": version,
"comment": comment,
"model_file": model_file,
"config_file": config_file,
"config_str": config_str,
"id": version_id,
},
)
self.finish()
@ -367,7 +425,7 @@ class VersionListHandler(APIHandler):
> - pageSize, int
- 返回值
```
{
{
"count": 123,
"data": [
{
@ -407,10 +465,10 @@ class VersionListHandler(APIHandler):
order by mv.id desc limit :offset, :limit
"""
),
{"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize}
{"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize},
)
result = db_mysql.to_json(cur)
result = db_mysql.to_json_list(cur)
# 获取记录数量
count = conn.execute(
@ -418,21 +476,23 @@ class VersionListHandler(APIHandler):
"""
select count(*)
from model_version mv
where mv.del=0
where mv.del=0 and model_id=:mid
"""
),
{"mid": model_id}
{"mid": model_id},
).scalar()
for item in result:
data.append({
"version_id": item["version_id"], # 版本id
"version": item["version"], # 版本号
"path": item["filepath"], # 文件路径
"size": item["filesize"],
"update_time": str(item["update_time"]),
"is_default": item["is_default"] # 是否默认1/默认0/非默认
})
data.append(
{
"version_id": item["version_id"], # 版本id
"version": item["version"], # 版本号
"path": item["filepath"], # 文件路径
"size": item["filesize"],
"update_time": str(item["update_time"]),
"is_default": item["is_default"], # 是否默认1/默认0/非默认
}
)
self.finish({"count": count, "data": data})
@ -475,15 +535,16 @@ class VersionInfoHandler(APIHandler):
"config_file_name": "",
"config_file_size": 0,
"config_file_md5": "",
"config_str": ""
"config_str": "",
}
with self.app_mysql.connect() as conn:
cur = conn.execute(
text(
"""select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str
from model_version mv, model m where mv.id=:id and mv.model_id=m.id"""),
{"id": version_id}
from model_version mv, model m where mv.id=:id and mv.model_id=m.id"""
),
{"id": version_id},
)
result = db_mysql.to_json(cur)
@ -502,21 +563,31 @@ class VersionInfoHandler(APIHandler):
# 获取文件信息
if model_file:
cur_model_file = conn.execute(
text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": model_file}
text("select filename, filesize from files where md5_str=:md5_str"),
{"md5_str": model_file},
)
model_file_info = db_mysql.to_json(cur_model_file)
response["model_file_name"] = model_file_info["filename"]
response["model_file_size"] = model_file_info["filesize"]
response["model_file_name"] = (
model_file_info["filename"] if model_file_info else ""
)
response["model_file_size"] = (
model_file_info["filesize"] if model_file_info else 0
)
if config_file:
cur_config_file = conn.execute(
text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": config_file}
text("select filename, filesize from files where md5_str=:md5_str"),
{"md5_str": config_file},
)
config_file_info = db_mysql.to_json(cur_config_file)
response["config_file_name"] = config_file_info["filename"]
response["config_file_size"] = config_file_info["filesize"]
response["config_file_name"] = (
config_file_info["filename"] if config_file_info else ""
)
response["config_file_size"] = (
config_file_info["filesize"] if config_file_info else 0
)
response["config_str"] = result["config_str"]
response["config_str"] = self.unescape_string(result["config_str"])
self.finish(response)
@ -543,11 +614,13 @@ class VersionSetDefaultHandler(APIHandler):
with self.app_mysql.connect() as conn:
conn.execute(
text("update model_version set is_default=0 where model_id=:model_id"),
{"model_id": model_id})
{"model_id": model_id},
)
conn.execute(
text("update model_version set is_default=1 where id=:id"),
{"id": version_id})
{"id": version_id},
)
conn.commit()
self.finish()
@ -573,8 +646,12 @@ class VersionDeleteHandler(APIHandler):
row = {}
# 获取模型对应的model_file, config_file使用model_file, config_file删除对应的存储文件
with self.app_mysql.connect() as conn:
cur = conn.execute(text("select model_id, model_file, config_file from model_version where id=:id"),
{"id": version_id})
cur = conn.execute(
text(
"select model_id, model_file, config_file from model_version where id=:id"
),
{"id": version_id},
)
row = db_mysql.to_json(cur)
if not row:
@ -585,20 +662,29 @@ class VersionDeleteHandler(APIHandler):
# 清空模型默认版本
conn.execute(
text("update model set default_version='' where id=:id"), {"id": model_id}
text("update model set default_version='' where id=:id"),
{"id": model_id},
)
# 删除文件
try:
conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": model_file})
conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": config_file})
conn.execute(
text("delete from files where md5_str=:md5_str"),
{"md5_str": model_file},
)
conn.execute(
text("delete from files where md5_str=:md5_str"),
{"md5_str": config_file},
)
os.remove(settings.file_upload_dir + "model/" + model_file)
os.remove(settings.file_upload_dir + "model/" + config_file)
except Exception as e:
logging.info(e)
conn.execute(text("delete from model_version where id=:id"), {"id": version_id})
conn.execute(
text("delete from model_version where id=:id"), {"id": version_id}
)
conn.commit()

@ -7,6 +7,7 @@ handlers = [
("/model/classification/delete", handler.ClassificationDeleteHandler),
("/model/list", handler.ListHandler),
("/model/list/simple", handler.ListSimpleHandler),
("/model/add", handler.AddHandler),
("/model/edit", handler.EditHandler),
("/model/info", handler.InfoHandler),

@ -11,6 +11,7 @@ from website import errors
from website import settings
from website.handler import APIHandler, authenticated
class ListHandler(APIHandler):
"""
- 描述 模型运行库列表
@ -35,30 +36,31 @@ class ListHandler(APIHandler):
}
```
"""
@authenticated
def post(self):
pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_int_argument("pageSize", consts.PAGE_SIZE)
name = self.get_escaped_argument("name", "")
result = []
count = 0
with self.app_mysql.connect() as conn:
sql = "select id, name, create_time, update_time from model_hub where 1=1"
sql = "select id, name, path, create_time, update_time from model_hub where 1=1"
param = {}
sql_count = "select count(id) from model where 1=1"
sql_count = "select count(id) from model_hub where 1=1"
param_count = {}
if name:
sql += "and m.name like :name"
sql += " and name like :name"
param["name"] = "%{}%".format(name)
sql_count += "and m.name like :name"
sql_count += " and name like :name"
param_count["name"] = "%{}%".format(name)
sql += " order by m.id desc limit :pageSize offset :offset"
sql += " order by id desc limit :pageSize offset :offset"
param["pageSize"] = pageSize
param["offset"] = (pageNo - 1) * pageSize
@ -69,55 +71,63 @@ class ListHandler(APIHandler):
data = []
for item in result:
data.append({
"id": item["id"],
"name": item["name"],
"create_time": item["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
"update_time": item["update_time"].strftime("%Y-%m-%d %H:%M:%S")
})
data.append(
{
"id": item["id"],
"name": item["name"],
"path": item["path"],
"create_time": item["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
"update_time": item["update_time"].strftime("%Y-%m-%d %H:%M:%S"),
}
)
self.finish({"count": count, "data": data})
class SyncHandler(APIHandler):
"""
- 描述 查询docker registry中的镜像
- 请求方式post
- 请求参数
> - host, string, ip地址
> - port, int, 端口
- 返回值
- 描述 查询docker registry中的镜像
- 请求方式post
- 请求参数
> - host, string, ip地址
> - port, int, 端口
- 返回值
```
{
"data": [
"xxx", # docker registry中docker images的地址
"xxx",
...
]
}
```
{
"data": [
"xxx", # docker registry中docker images的地址
"xxx",
...
]
}
```
"""
@authenticated
def post(self):
host = self.get_escaped_argument("host", "")
port = self.get_int_argument("port")
if not host or not port:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "host and port must be provided.")
raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "host and port must be provided."
)
images = []
# 查询docker registry中的镜像
repositories = requests.get("http://{}:{}/v2/_catalog".format(host, port)).json()["repositories"]
repositories = requests.get(
"http://{}:{}/v2/_catalog".format(host, port)
).json()["repositories"]
for repository in repositories:
# 查询docker registry中的镜像的tag
tags = requests.get("http://{}:{}/v2/{}/tags/list".format(host, port, repository)).json()["tags"]
tags = requests.get(
"http://{}:{}/v2/{}/tags/list".format(host, port, repository)
).json()["tags"]
for tag in tags:
image_name = "{}:{}/{}:{}".format(host, port, repository, tag)
images.append(image_name)
self.finish({"data": images})
class AddHandler(APIHandler):
"""
- 描述 新建模型运行库
@ -130,6 +140,7 @@ class AddHandler(APIHandler):
> - comment, string, 备注
- 返回值
"""
@authenticated
def post(self):
name = self.get_escaped_argument("name", "")
@ -138,12 +149,24 @@ class AddHandler(APIHandler):
path = self.get_escaped_argument("path", "")
comment = self.get_escaped_argument("comment", "")
if not name or not host or not port:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "name and host and port must be provided.")
raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "name and host and port must be provided."
)
with self.app_mysql.connect() as conn:
conn.execute(text("""insert into model_hub (name, host, port, path, comment, create_time, update_time)
values (:name, :host, :port, :path, :comment, NOW(), NOW())"""),
{"name": name, "host": host, "port": port, "path": path, "comment": comment})
conn.execute(
text(
"""insert into model_hub (name, host, port, path, comment, create_time, update_time)
values (:name, :host, :port, :path, :comment, NOW(), NOW())"""
),
{
"name": name,
"host": host,
"port": port,
"path": path,
"comment": comment,
},
)
conn.commit()
@ -163,6 +186,7 @@ class EditHandler(APIHandler):
> - comment, string, 备注
- 返回值
"""
@authenticated
def post(self):
id = self.get_int_argument("id")
@ -174,8 +198,20 @@ class EditHandler(APIHandler):
if not id or not name or not host or not port or path:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "parameter error")
with self.app_mysql.connect() as conn:
conn.execute(text("""update model_hub set name=:name, host=:host, port=:port, path=:path, comment=:comment, update_time=NOW()
where id=:id"""), {"id": id, "name": name, "host": host, "port": port, "path": path, "comment": comment})
conn.execute(
text(
"""update model_hub set name=:name, host=:host, port=:port, path=:path, comment=:comment, update_time=NOW()
where id=:id"""
),
{
"id": id,
"name": name,
"host": host,
"port": port,
"path": path,
"comment": comment,
},
)
conn.commit()
self.finish()
@ -189,14 +225,15 @@ class InfoHandler(APIHandler):
- 返回值
```
{
"name": "xxx",
"host": "xxx",
"port": 123,
"path": "xxx",
"comment": "xxx",
"name": "xxx",
"host": "xxx",
"port": 123,
"path": "xxx",
"comment": "xxx",
}
```
"""
@authenticated
def post(self):
hid = self.get_int_argument("id")
@ -205,11 +242,18 @@ class InfoHandler(APIHandler):
result = {}
with self.app_mysql.connect() as conn:
cur = conn.execute(text("""select name, host, port, path, comment from model_hub where id=:id"""), {"id": hid})
cur = conn.execute(
text(
"""select name, host, port, path, comment from model_hub where id=:id"""
),
{"id": hid},
)
result = db_mysql.to_json(cur)
if not result:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "model hub not found")
self.finish({"data": result})
raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "model hub not found"
)
self.finish(result)
class DeleteHandler(APIHandler):
@ -220,6 +264,7 @@ class DeleteHandler(APIHandler):
> - id, int
- 返回值
"""
@authenticated
def post(self):
hid = self.get_int_argument("id")

@ -35,14 +35,14 @@ class ListHandler(APIHandler):
"""
@authenticated
def post(self):
pageNo = self.get_argument("pageNo", 1)
pageSize = self.get_argument("pageSize", 10)
entity_id = self.get_argument("entity_id")
pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_int_argument("pageSize", 10)
entity_id = self.get_int_argument("entity_id")
if not entity_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数错误")
db_busimodel = DB_BusiModel.EnterpriseBusiModelRepository()
result = db_busimodel.list_enterprise_busi_model(entity_id, pageNo, pageSize)
result = db_busimodel.list_enterprise_busi_model(entity_id=entity_id, page_no=pageNo, page_size=pageSize)
self.finish({"count": result["count"], "data": result["data"]})
@ -63,7 +63,7 @@ class AddHandler(APIHandler):
"""
@authenticated
def post(self):
entity_id = self.get_argument("entity_id")
entity_id = self.get_int_argument("entity_id")
name = self.get_escaped_argument("name", "")
comment = self.get_escaped_argument("comment", "")
basemodel_ids = self.get_escaped_argument("basemodel_ids", "")
@ -131,7 +131,7 @@ class InfoHandler(APIHandler):
"""
@authenticated
def post(self):
busimodel_id = self.get_argument("id")
busimodel_id = self.get_int_argument("id")
if not busimodel_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数错误")

@ -2,17 +2,21 @@
import json
import logging
import random
from sqlalchemy import text
from website import consts
from website import db_mysql, errors
from website.db.enterprise_device import enterprise_device as DB_Device
from website.db.alg_model.alg_model import ModelRepositry as DB_AlgModel
from website.db.device_classification import (
device_classification as DB_DeviceClassification,
)
from website.db.enterprise_busi_model import enterprise_busi_model_node_device as DB_BusiModelNodeDevice
from website.db.enterprise_busi_model import (
enterprise_busi_model as DB_BusiModel,
)
from website.db.enterprise_device import enterprise_device as DB_Device
from website.handler import APIHandler, authenticated
from website.util import shortuuid
from website import consts
class DeviceClassificationAddHandler(APIHandler):
@ -228,8 +232,6 @@ class DeviceEditHandler(APIHandler):
class DeviceDeleteHandler(APIHandler):
"""
### /enterprise/entity/nodes/device/delete
- 描述企业节点删除设备
- 请求方式post
- 请求参数
@ -338,6 +340,115 @@ class DeviceInfoHandler(APIHandler):
self.finish(device)
class DeviceBasemodelListHandler(APIHandler):
"""
- 描述企业节点节点信息 -> 设备列表 -> 基础模型配置 -> 模型列表
- 请求方式post
- 请求参数
> - pageNo
> - pageSize
> - device_id, int, 设备id
- 返回值
```
{
"count": 123,
"data": [
{
"busi_model": "xxx", # 业务模型的名称
"base_model": [
{
"model_id": 123, # 基础模型id
"model_name": "xxx" # 基础模型name
"model_version": "xxx", # 基础模型的版本
"model_hub_image": "xxx", # 运行库镜像
},
...
]
},
...
]
}
```
"""
@authenticated
def post(self):
device_id = self.get_int_argument("device_id")
if not device_id:
raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "企业节点或设备不能为空"
)
pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_int_argument("pageSize", 10)
db_busi_model = DB_BusiModelNodeDevice.EnterpriseBusiModelNodeDeviceRepository()
busi_models, count = db_busi_model.get_busi_model_by_device(device_id=device_id, pagination=True,
page_no=pageNo,
page_size=pageSize)
for item in busi_models:
busi_model_id = item["busi_model_id"]
busi_model_name = item["name"]
base_model_list = json.loads(item["base_models"])
for base_model in base_model_list:
base_model_id = base_model["id"]
base_model_suid = base_model["suid"]
base_model_name = base_model["name"]
self.finish()
class DeviceBaseModelCustomConfigHandler(APIHandler):
"""
- 描述企业节点节点信息 -> 设备列表 -> 基础模型配置 -> 基础模型参数配置
- 请求方式post
- 请求参数
> - device_id, int, 设备id
> - node_id, int, 节点id
> - base_model_id, int, 基础模型id
> - busi_conf_file, string, 业务参数配置文件md5
> - busi_conf_str, string, 业务参数配置json字符串eg:'{"key1": "value1", "key2": "value2"}'
> - model_hub_image, string, 运行库地址
> - model_conf_file, string, 模型参数配置文件md5
> - model_conf_str, string, 模型参数配置json字符串eg:'{"key1": "value1", "key2": "value2"}'
- 返回值
"""
@authenticated
def post(self):
device_id = self.get_int_argument("device_id")
node_id = self.get_int_argument("node_id")
busi_model_id = self.get_int_argument("busi_model_id")
base_model_id = self.get_int_argument("base_model_id")
busi_conf_file = self.get_escaped_argument("busi_conf_file")
busi_conf_str = self.get_escaped_argument("busi_conf_str")
model_hub_image = self.get_escaped_argument("model_hub_image")
model_conf_file = self.get_escaped_argument("model_conf_file")
model_conf_str = self.get_escaped_argument("model_conf_str")
db_device = DB_Device.EnterpriseDeviceRepository()
device = db_device.get_device(device_id=device_id)
if not device:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "device not exist")
if device["node_id"] != node_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "device not belong to this node")
db_alg_model = DB_AlgModel()
base_model = db_alg_model.get_model_by_id(base_model_id)
if not base_model:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "base model not exist")
device_suid = device["suid"]
node_suid = device["node_suid"]
entity_id = device["entity_id"]
entity_suid = device["entity_suid"]
db_busi_model = DB_BusiModel.EnterpriseBusiModelRepository()
busi_model = db_busi_model.get_busi_model_by_id(busi_model_id)
busi_model_suid = busi_model.suid
self.finish()
class StatusListHandler(APIHandler):
"""
- 描述设备状态列表
@ -379,26 +490,32 @@ class StatusListHandler(APIHandler):
if not entity_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "企业节点不能为空")
if status not in consts.device_status_map:
raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "状态参数错误")
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "状态参数错误")
db_device = DB_Device.EnterpriseDeviceRepository()
res = db_device.list_entity_devices(entity_id=entity_id, pageNo=pageNo, pageSize=pageSize,
classification=classification, status=status)
res = db_device.list_entity_devices(
entity_id=entity_id,
pageNo=pageNo,
pageSize=pageSize,
classification=classification,
status=status,
)
count = res["count"]
devices = res["devices"]
data = []
for item in devices:
data.append({
"id": item.id,
"name": item.name,
"status": item.status,
"cpu": random.randint(20, 30),
"mem": random.randint(20, 30),
"storage": random.randint(20, 30),
"gpu": random.randint(20, 30),
})
data.append(
{
"id": item.id,
"name": item.name,
"status": item.status,
"cpu": random.randint(20, 30),
"mem": random.randint(20, 30),
"storage": random.randint(20, 30),
"gpu": random.randint(20, 30),
}
)
self.finish({"count": count, "data": data})
@ -418,11 +535,10 @@ class StatusInfoHandler(APIHandler):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "设备不存在")
res = res[0]
db_busi_model = DB_BusiModelNodeDevice.EnterpriseBusiModelNodeDeviceRepository()
busi_models = db_busi_model.get_busi_model_by_device(device_id=device_id)
busi_models, _ = db_busi_model.get_busi_model_by_device(device_id=device_id)
for busi_model in busi_models:
base_model_ids = json.loads(busi_model["base_models"])
self.finish()
@ -431,5 +547,4 @@ class StatusLogHandler(APIHandler):
@authenticated
def post(self):
self.finish()

@ -15,7 +15,11 @@ handlers = [
("/enterprise/entity/nodes/device/list", handler.DeviceListHandler),
("/enterprise/entity/nodes/device/list/simple", handler.DeviceListSimpleHandler),
("/enterprise/entity/nodes/device/info", handler.DeviceInfoHandler),
("/enterprise/entity/nodes/device/basemodel/list", handler.DeviceBasemodelListHandler),
(
"/enterprise/entity/nodes/device/basemodel/custom/config",
handler.DeviceBaseModelCustomConfigHandler,
),
("/enterprise/device/status/list", handler.StatusListHandler),
("/enterprise/device/status/info", handler.StatusInfoHandler),
("/enterprise/device/status/log", handler.StatusLogHandler),

Loading…
Cancel
Save