diff --git a/website/db/alg_model/alg_model.py b/website/db/alg_model/alg_model.py index 267ced4..3eeb20f 100644 --- a/website/db/alg_model/alg_model.py +++ b/website/db/alg_model/alg_model.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- -from sqlalchemy.ext.declarative import declarative_base +import logging +from typing import Union + from sqlalchemy import Column, Integer, String, DateTime, func -from typing import Any, Dict, List, Optional, Tuple, Union +from sqlalchemy.ext.declarative import declarative_base from website.db_mysql import get_session -from website.util import shortuuid Base = declarative_base() @@ -63,6 +64,29 @@ class ModelRepositry(object): return model + def get_model_dict_by_id(self, model_id: int) -> dict: + with get_session() as session: + logging.info(f"model id is : {model_id}") + model = session.query(Model).filter(Model.id == model_id).first() + # if not model: + # return {} + + model_dict = { + 'id': model.id, + 'suid': model.suid, + 'name': model.name, + 'model_type': model.model_type, + 'classification': model.classification, + 'comment': model.comment, + 'default_version': model.default_version, + 'delete': model.delete, + 'create_time': model.create_time, + 'update_time': model.update_time + } + + logging.info(f"model dict is : {model_dict}") + return model_dict + def get_model_by_ids(self, model_ids: list) -> list: with get_session() as session: models = session.query(Model).filter(Model.id.in_(model_ids)).all() @@ -71,7 +95,6 @@ class ModelRepositry(object): return models - def get_model_count(self) -> int: with get_session() as session: return session.query(Model).count() diff --git a/website/db/enterprise_busi_model/enterprise_busi_model.py b/website/db/enterprise_busi_model/enterprise_busi_model.py index 0d76c65..a8b0811 100644 --- a/website/db/enterprise_busi_model/enterprise_busi_model.py +++ b/website/db/enterprise_busi_model/enterprise_busi_model.py @@ -67,7 +67,10 @@ class EnterpriseBusiModel(Base): def __repr__(self): return f"EnterpriseBusiModel(id={self.id}, suid='{self.suid}', name='{self.name}')" - + def __init__(self, **kwargs): + valid_columns = {col.name for col in self.__table__.columns} + filtered_data = {key: value for key, value in kwargs.items() if key in valid_columns} + super().__init__(**filtered_data) class EnterpriseBusiModelRepository(object): @@ -82,15 +85,19 @@ class EnterpriseBusiModelRepository(object): data['entity_suid'] = entity_suid base_model_ids = [int(model_id) for model_id in data['basemodel_ids'].split(',')] - base_model_db = DB_alg_model.ModelRepositry() + logging.info("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@") + logging.info(base_model_ids) base_model = [] for base_model_id in base_model_ids: + base_model_db = DB_alg_model.ModelRepositry() # base_model_suid = base_model_db.get_suid(base_model_id) - base_model_info = base_model_db.get_model_by_id(base_model_id) + base_model_info = base_model_db.get_model_dict_by_id(base_model_id) + logging.info("#####################") + logging.info(base_model_info) base_model.append({ 'id': base_model_id, - 'suid': base_model_info.suid, - 'name': base_model_info.name, + 'suid': base_model_info["suid"], + 'name': base_model_info["name"], }) data['base_models'] = json.dumps(base_model) new_data = copy.copy(data) @@ -104,13 +111,16 @@ class EnterpriseBusiModelRepository(object): def edit_busi_model(self, data: Dict): base_model_ids = [int(model_id) for model_id in data['basemodel_ids'].split(',')] - base_model_db = DB_alg_model.ModelRepositry() + base_model = [] for base_model_id in base_model_ids: - base_model_suid = base_model_db.get_suid(base_model_id) + base_model_db = DB_alg_model.ModelRepositry() + # base_model_suid = base_model_db.get_suid(base_model_id) + base_model_info = base_model_db.get_model_dict_by_id(base_model_id) base_model.append({ - 'id': base_model_id, - 'suid': base_model_suid + "id": base_model_id, + "suid": base_model_info["suid"], + "name": base_model_info["name"], }) data['base_models'] = json.dumps(base_model) @@ -197,9 +207,10 @@ class EnterpriseBusiModelNodeRepository(object): node_db = DB_Node.EnterpriseNodeRepository() node = node_db.get_node_by_id(node_id) node_suid = node["suid"] + entity_suid = node["entity_suid"] model_node = EnterpriseBusiModelNode( suid=shortuuid.ShortUUID().random(10), - entity_suid=data['entity_suid'], + entity_suid=entity_suid, busi_model_id=data['busi_model_id'], busi_model_suid=data['busi_model_suid'], node_id=node_id, diff --git a/website/db/enterprise_busi_model/enterprise_busi_model_node_device.py b/website/db/enterprise_busi_model/enterprise_busi_model_node_device.py index c1bfe19..8812c63 100644 --- a/website/db/enterprise_busi_model/enterprise_busi_model_node_device.py +++ b/website/db/enterprise_busi_model/enterprise_busi_model_node_device.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- import logging -from typing import List +from typing import List, Union, Optional, Tuple from sqlalchemy import Column, Integer, String, DateTime, func, text from sqlalchemy.ext.declarative import declarative_base -from website.db_mysql import get_session, to_json_list +from website.db_mysql import get_session, to_json_list, Row """ CREATE TABLE `enterprise_busi_model_node_device` ( @@ -97,7 +97,15 @@ class EnterpriseBusiModelNodeDeviceRepository(object): ) return to_json_list(res) - def get_busi_model_by_device(self, device_id: int = 0, device_suid: str = "") -> list: + def get_busi_model_by_device(self, device_id: int = 0, device_suid: str = "", pagination: bool = False, + page_no: int = 0, + page_size: int = 0) -> Union[list, Tuple[Optional[List[Row]], int]]: + if not device_id and not device_suid: + logging.error("get_busi_model_by_device error: device_id and device_suid is null") + return [] + + res = [] + count = 0 with get_session() as session: sql = """ select d.busi_model_id, m.name, m.base_models @@ -105,13 +113,34 @@ class EnterpriseBusiModelNodeDeviceRepository(object): where d.busi_model_id=m.id""" p = {} + sql_count = """ + select count(1) from enterprise_busi_model_node_device d, enterprise_busi_model m where d.busi_model_id=m.id + """ + p_count = {} + if device_id: sql += " and d.device_id=:device_id" p.update({"device_id": device_id}) + sql_count += " and d.device_id=:device_id" + p_count.update({"device_id": device_id}) + if device_suid: sql += " and d.device_suid=:device_suid" p.update({"device_suid": device_suid}) + sql_count += " and d.device_suid=:device_suid" + p_count.update({"device_suid": device_suid}) + + if pagination: + sql += " order by d.id desc" + + if page_no > 0: + sql += " limit :pageno, :pagesize" + p.update({"pageno": (page_no - 1) * page_size, "pagesize": page_size}) + + count = session.execute(text(sql_count), p_count).scalar() + res = session.execute(text(sql), p) - return to_json_list(res) + res = to_json_list(res) + return res, count diff --git a/website/db/enterprise_device/enterprise_device.py b/website/db/enterprise_device/enterprise_device.py index fa0eb6e..9bab2ec 100644 --- a/website/db/enterprise_device/enterprise_device.py +++ b/website/db/enterprise_device/enterprise_device.py @@ -245,6 +245,11 @@ class EnterpriseDeviceRepository(object): device_dict = {} if device: device_dict = { + "id": device.id, + "entity_id": device.entity_id, + "entity_suid": device.entity_suid, + "node_id": device.node_id, + "node_suid": device.node_suid, "suid": device.suid, "name": device.name, "addr": device.addr, @@ -277,12 +282,17 @@ class EnterpriseDeviceRepository(object): for device in devices: device_dict = { "id": device.id, + "entity_id": device.entity_id, + "entity_suid": device.entity_suid, + "node_id": device.node_id, + "node_suid": device.node_suid, "suid": device.suid, "name": device.name, "addr": device.addr, "device_model": device.device_model, "param": device.param, "comment": device.comment, + "classification": device.classification, } device_dicts.append(device_dict) return device_dicts diff --git a/website/handlers/alg_model/handler.py b/website/handlers/alg_model/handler.py index 429f0fe..553d3a7 100644 --- a/website/handlers/alg_model/handler.py +++ b/website/handlers/alg_model/handler.py @@ -9,13 +9,13 @@ from website import db_mysql from website import errors from website import settings from website.handler import APIHandler, authenticated -from website.util import md5 +from website.util import md5, shortuuid class ClassificationAddHandler(APIHandler): """ 添加模型分类 - + """ @authenticated @@ -28,13 +28,19 @@ class ClassificationAddHandler(APIHandler): raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长") with self.app_mysql.connect() as conn: - cur = conn.execute(text("select id from model_classification where name=:name"), {"name": name}) + cur = conn.execute( + text("select id from model_classification where name=:name"), + {"name": name}, + ) row = db_mysql.to_json(cur) if row: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复") conn.execute( - text("""insert into model_classification (name, create_time) values (:name, NOW())"""), {"name": name} + text( + """insert into model_classification (name, create_time) values (:name, NOW())""" + ), + {"name": name}, ) conn.commit() @@ -53,8 +59,10 @@ class ClassificationEditHandler(APIHandler): if not classification_id or not name: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") with self.app_mysql.connect() as conn: - conn.execute(text("""update model_classification set name=:name where id=:id"""), - {"name": name, "id": classification_id}) + conn.execute( + text("""update model_classification set name=:name where id=:id"""), + {"name": name, "id": classification_id}, + ) conn.commit() @@ -87,7 +95,10 @@ class ClassificationDeleteHandler(APIHandler): raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") with self.app_mysql.connect() as conn: - conn.execute(text("""DELETE FROM model_classification WHERE id=:id"""), {"id": classification_id}) + conn.execute( + text("""DELETE FROM model_classification WHERE id=:id"""), + {"id": classification_id}, + ) conn.commit() self.finish() @@ -106,8 +117,10 @@ class ListHandler(APIHandler): result = [] with self.app_mysql.connect() as conn: - sql = "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time " \ - "from model m left join model_classification mc on m.classification=mc.id where m.del=0 " + sql = ( + "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time " + "from model m left join model_classification mc on m.classification=mc.id where m.del=0 " + ) param = {} @@ -132,18 +145,31 @@ class ListHandler(APIHandler): data = [] for item in result: - data.append({ - "id": item["id"], - "name": item["name"], - "model_type": consts.model_type_map[item["model_type"]], - "classification_name": item["classification_name"], - "default_version": item["default_version"], - "update_time": str(item["update_time"]) - }) + data.append( + { + "id": item["id"], + "name": item["name"], + "model_type": consts.model_type_map[item["model_type"]], + "classification_name": item["classification_name"], + "default_version": item["default_version"], + "update_time": str(item["update_time"]), + } + ) self.finish({"data": data, "count": count}) +class ListSimpleHandler(APIHandler): + @authenticated + def post(self): + with self.app_mysql.connect() as conn: + sql = "select id, name from model where del=0" + cur = conn.execute(text(sql)) + res = db_mysql.to_json_list(cur) + + self.finish({"result": res}) + + class AddHandler(APIHandler): """ 添加模型 @@ -152,7 +178,9 @@ class AddHandler(APIHandler): @authenticated def post(self): name = self.get_escaped_argument("name", "") - model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法,1002/深度学习 + model_type = self.get_int_argument( + "model_type", consts.model_type_machine + ) # 1001/经典算法,1002/深度学习 classification = self.get_int_argument("classification") comment = self.get_escaped_argument("comment", "") @@ -167,9 +195,16 @@ class AddHandler(APIHandler): conn.execute( text( - """insert into model (name, model_type, classification, comment, create_time, update_time) - values (:name, :model_type, :classification, :comment, NOW(), NOW())"""), - {"name": name, "model_type": model_type, "classification": classification, "comment": comment} + """insert into model (suid, name, model_type, classification, comment, create_time, update_time) + values (:suid, :name, :model_type, :classification, :comment, NOW(), NOW())""" + ), + { + "suid": shortuuid.ShortUUID().random(10), + "name": name, + "model_type": model_type, + "classification": classification, + "comment": comment, + }, ) conn.commit() @@ -186,7 +221,9 @@ class EditHandler(APIHandler): def post(self): mid = self.get_int_argument("id") name = self.get_escaped_argument("name", "") - model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法,1002/深度学习 + model_type = self.get_int_argument( + "model_type", consts.model_type_machine + ) # 1001/经典算法,1002/深度学习 classification = self.get_int_argument("classification") comment = self.get_escaped_argument("comment", "") @@ -199,9 +236,15 @@ class EditHandler(APIHandler): """update model set name=:name, model_type=:model_type, classification=:classification, comment=:comment, update_time=NOW() - where id=:id"""), - {"name": name, "model_type": model_type, "classification": classification, "comment": comment, - "id": mid} + where id=:id""" + ), + { + "name": name, + "model_type": model_type, + "classification": classification, + "comment": comment, + "id": mid, + }, ) conn.commit() @@ -222,14 +265,16 @@ class InfoHandler(APIHandler): result = {} with self.app_mysql.connect() as conn: cur = conn.execute( - text(""" - select - m.name, m.model_type, m.comment, m.update_time, - mc.id as classification_id, mc.name as classification_name - from model m, model_classification mc - where m.id=:id and m.classification=mc.id - """), - {"id": mid} + text( + """ + select + m.name, m.model_type, m.comment, m.update_time, + mc.id as classification_id, mc.name as classification_name + from model m, model_classification mc + where m.id=:id and m.classification=mc.id + """ + ), + {"id": mid}, ) result = db_mysql.to_json(cur) @@ -242,7 +287,7 @@ class InfoHandler(APIHandler): "classification_id": result["classification_id"], "classification_name": result["classification_name"], "comment": result["comment"], - "update_time": str(result["update_time"]) + "update_time": str(result["update_time"]), } self.finish(data) @@ -306,9 +351,15 @@ class VersionAddHandler(APIHandler): (model_id, version, comment,model_file, config_file, config_str, create_time, update_time) values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())""" ), - - {"model_id": mid, "version": version, "comment": comment, "model_file": model_file, - "config_file": config_file, "config_str": config_str}) + { + "model_id": mid, + "version": version, + "comment": comment, + "model_file": model_file, + "config_file": config_file, + "config_str": config_str, + }, + ) conn.commit() self.finish() @@ -349,9 +400,16 @@ class VersionEditHandler(APIHandler): text( "update model_version " "set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, " - "config_str=:config_str, update_time=NOW() where id=:id"), - {"version": version, "comment": comment, "model_file": model_file, "config_file": config_file, - "config_str": config_str, "id": version_id} + "config_str=:config_str, update_time=NOW() where id=:id" + ), + { + "version": version, + "comment": comment, + "model_file": model_file, + "config_file": config_file, + "config_str": config_str, + "id": version_id, + }, ) self.finish() @@ -367,7 +425,7 @@ class VersionListHandler(APIHandler): > - pageSize, int - 返回值: ``` - { + { "count": 123, "data": [ { @@ -407,10 +465,10 @@ class VersionListHandler(APIHandler): order by mv.id desc limit :offset, :limit """ ), - {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize} + {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize}, ) - result = db_mysql.to_json(cur) + result = db_mysql.to_json_list(cur) # 获取记录数量 count = conn.execute( @@ -418,21 +476,23 @@ class VersionListHandler(APIHandler): """ select count(*) from model_version mv - where mv.del=0 + where mv.del=0 and model_id=:mid """ ), - {"mid": model_id} + {"mid": model_id}, ).scalar() for item in result: - data.append({ - "version_id": item["version_id"], # 版本id - "version": item["version"], # 版本号 - "path": item["filepath"], # 文件路径 - "size": item["filesize"], - "update_time": str(item["update_time"]), - "is_default": item["is_default"] # 是否默认,1/默认,0/非默认 - }) + data.append( + { + "version_id": item["version_id"], # 版本id + "version": item["version"], # 版本号 + "path": item["filepath"], # 文件路径 + "size": item["filesize"], + "update_time": str(item["update_time"]), + "is_default": item["is_default"], # 是否默认,1/默认,0/非默认 + } + ) self.finish({"count": count, "data": data}) @@ -475,15 +535,16 @@ class VersionInfoHandler(APIHandler): "config_file_name": "", "config_file_size": 0, "config_file_md5": "", - "config_str": "" + "config_str": "", } with self.app_mysql.connect() as conn: cur = conn.execute( text( """select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str - from model_version mv, model m where mv.id=:id and mv.model_id=m.id"""), - {"id": version_id} + from model_version mv, model m where mv.id=:id and mv.model_id=m.id""" + ), + {"id": version_id}, ) result = db_mysql.to_json(cur) @@ -502,21 +563,31 @@ class VersionInfoHandler(APIHandler): # 获取文件信息 if model_file: cur_model_file = conn.execute( - text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": model_file} + text("select filename, filesize from files where md5_str=:md5_str"), + {"md5_str": model_file}, ) model_file_info = db_mysql.to_json(cur_model_file) - response["model_file_name"] = model_file_info["filename"] - response["model_file_size"] = model_file_info["filesize"] + response["model_file_name"] = ( + model_file_info["filename"] if model_file_info else "" + ) + response["model_file_size"] = ( + model_file_info["filesize"] if model_file_info else 0 + ) if config_file: cur_config_file = conn.execute( - text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": config_file} + text("select filename, filesize from files where md5_str=:md5_str"), + {"md5_str": config_file}, ) config_file_info = db_mysql.to_json(cur_config_file) - response["config_file_name"] = config_file_info["filename"] - response["config_file_size"] = config_file_info["filesize"] + response["config_file_name"] = ( + config_file_info["filename"] if config_file_info else "" + ) + response["config_file_size"] = ( + config_file_info["filesize"] if config_file_info else 0 + ) - response["config_str"] = result["config_str"] + response["config_str"] = self.unescape_string(result["config_str"]) self.finish(response) @@ -543,11 +614,13 @@ class VersionSetDefaultHandler(APIHandler): with self.app_mysql.connect() as conn: conn.execute( text("update model_version set is_default=0 where model_id=:model_id"), - {"model_id": model_id}) + {"model_id": model_id}, + ) conn.execute( text("update model_version set is_default=1 where id=:id"), - {"id": version_id}) + {"id": version_id}, + ) conn.commit() self.finish() @@ -573,8 +646,12 @@ class VersionDeleteHandler(APIHandler): row = {} # 获取模型对应的model_file, config_file,使用model_file, config_file删除对应的存储文件 with self.app_mysql.connect() as conn: - cur = conn.execute(text("select model_id, model_file, config_file from model_version where id=:id"), - {"id": version_id}) + cur = conn.execute( + text( + "select model_id, model_file, config_file from model_version where id=:id" + ), + {"id": version_id}, + ) row = db_mysql.to_json(cur) if not row: @@ -585,20 +662,29 @@ class VersionDeleteHandler(APIHandler): # 清空模型默认版本 conn.execute( - text("update model set default_version='' where id=:id"), {"id": model_id} + text("update model set default_version='' where id=:id"), + {"id": model_id}, ) # 删除文件 try: - conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": model_file}) - conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": config_file}) + conn.execute( + text("delete from files where md5_str=:md5_str"), + {"md5_str": model_file}, + ) + conn.execute( + text("delete from files where md5_str=:md5_str"), + {"md5_str": config_file}, + ) os.remove(settings.file_upload_dir + "model/" + model_file) os.remove(settings.file_upload_dir + "model/" + config_file) except Exception as e: logging.info(e) - conn.execute(text("delete from model_version where id=:id"), {"id": version_id}) + conn.execute( + text("delete from model_version where id=:id"), {"id": version_id} + ) conn.commit() diff --git a/website/handlers/alg_model/url.py b/website/handlers/alg_model/url.py index 6c2ecfe..72bf7ea 100644 --- a/website/handlers/alg_model/url.py +++ b/website/handlers/alg_model/url.py @@ -7,6 +7,7 @@ handlers = [ ("/model/classification/delete", handler.ClassificationDeleteHandler), ("/model/list", handler.ListHandler), + ("/model/list/simple", handler.ListSimpleHandler), ("/model/add", handler.AddHandler), ("/model/edit", handler.EditHandler), ("/model/info", handler.InfoHandler), diff --git a/website/handlers/alg_model_hub/handler.py b/website/handlers/alg_model_hub/handler.py index 129a00e..72a68cd 100644 --- a/website/handlers/alg_model_hub/handler.py +++ b/website/handlers/alg_model_hub/handler.py @@ -11,6 +11,7 @@ from website import errors from website import settings from website.handler import APIHandler, authenticated + class ListHandler(APIHandler): """ - 描述: 模型运行库列表 @@ -35,30 +36,31 @@ class ListHandler(APIHandler): } ``` """ + @authenticated def post(self): pageNo = self.get_int_argument("pageNo", 1) pageSize = self.get_int_argument("pageSize", consts.PAGE_SIZE) name = self.get_escaped_argument("name", "") - + result = [] count = 0 - + with self.app_mysql.connect() as conn: - sql = "select id, name, create_time, update_time from model_hub where 1=1" + sql = "select id, name, path, create_time, update_time from model_hub where 1=1" param = {} - sql_count = "select count(id) from model where 1=1" + sql_count = "select count(id) from model_hub where 1=1" param_count = {} if name: - sql += "and m.name like :name" + sql += " and name like :name" param["name"] = "%{}%".format(name) - sql_count += "and m.name like :name" + sql_count += " and name like :name" param_count["name"] = "%{}%".format(name) - sql += " order by m.id desc limit :pageSize offset :offset" + sql += " order by id desc limit :pageSize offset :offset" param["pageSize"] = pageSize param["offset"] = (pageNo - 1) * pageSize @@ -69,55 +71,63 @@ class ListHandler(APIHandler): data = [] for item in result: - data.append({ - "id": item["id"], - "name": item["name"], - "create_time": item["create_time"].strftime("%Y-%m-%d %H:%M:%S"), - "update_time": item["update_time"].strftime("%Y-%m-%d %H:%M:%S") - }) + data.append( + { + "id": item["id"], + "name": item["name"], + "path": item["path"], + "create_time": item["create_time"].strftime("%Y-%m-%d %H:%M:%S"), + "update_time": item["update_time"].strftime("%Y-%m-%d %H:%M:%S"), + } + ) self.finish({"count": count, "data": data}) - class SyncHandler(APIHandler): """ - - 描述: 查询docker registry中的镜像 - - 请求方式:post - - 请求参数: - > - host, string, ip地址 - > - port, int, 端口 - - 返回值: + - 描述: 查询docker registry中的镜像 + - 请求方式:post + - 请求参数: + > - host, string, ip地址 + > - port, int, 端口 + - 返回值: + ``` + { + "data": [ + "xxx", # docker registry中docker images的地址 + "xxx", + ... + ] + } ``` - { - "data": [ - "xxx", # docker registry中docker images的地址 - "xxx", - ... - ] - } -``` """ + @authenticated def post(self): host = self.get_escaped_argument("host", "") port = self.get_int_argument("port") if not host or not port: - raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "host and port must be provided.") - + raise errors.HTTPAPIError( + errors.ERROR_BAD_REQUEST, "host and port must be provided." + ) + images = [] # 查询docker registry中的镜像 - repositories = requests.get("http://{}:{}/v2/_catalog".format(host, port)).json()["repositories"] + repositories = requests.get( + "http://{}:{}/v2/_catalog".format(host, port) + ).json()["repositories"] for repository in repositories: # 查询docker registry中的镜像的tag - tags = requests.get("http://{}:{}/v2/{}/tags/list".format(host, port, repository)).json()["tags"] + tags = requests.get( + "http://{}:{}/v2/{}/tags/list".format(host, port, repository) + ).json()["tags"] for tag in tags: image_name = "{}:{}/{}:{}".format(host, port, repository, tag) images.append(image_name) self.finish({"data": images}) - class AddHandler(APIHandler): """ - 描述: 新建模型运行库 @@ -130,6 +140,7 @@ class AddHandler(APIHandler): > - comment, string, 备注 - 返回值:无 """ + @authenticated def post(self): name = self.get_escaped_argument("name", "") @@ -138,12 +149,24 @@ class AddHandler(APIHandler): path = self.get_escaped_argument("path", "") comment = self.get_escaped_argument("comment", "") if not name or not host or not port: - raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "name and host and port must be provided.") - + raise errors.HTTPAPIError( + errors.ERROR_BAD_REQUEST, "name and host and port must be provided." + ) + with self.app_mysql.connect() as conn: - conn.execute(text("""insert into model_hub (name, host, port, path, comment, create_time, update_time) - values (:name, :host, :port, :path, :comment, NOW(), NOW())"""), - {"name": name, "host": host, "port": port, "path": path, "comment": comment}) + conn.execute( + text( + """insert into model_hub (name, host, port, path, comment, create_time, update_time) + values (:name, :host, :port, :path, :comment, NOW(), NOW())""" + ), + { + "name": name, + "host": host, + "port": port, + "path": path, + "comment": comment, + }, + ) conn.commit() @@ -163,6 +186,7 @@ class EditHandler(APIHandler): > - comment, string, 备注 - 返回值:无 """ + @authenticated def post(self): id = self.get_int_argument("id") @@ -174,8 +198,20 @@ class EditHandler(APIHandler): if not id or not name or not host or not port or path: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "parameter error") with self.app_mysql.connect() as conn: - conn.execute(text("""update model_hub set name=:name, host=:host, port=:port, path=:path, comment=:comment, update_time=NOW() - where id=:id"""), {"id": id, "name": name, "host": host, "port": port, "path": path, "comment": comment}) + conn.execute( + text( + """update model_hub set name=:name, host=:host, port=:port, path=:path, comment=:comment, update_time=NOW() + where id=:id""" + ), + { + "id": id, + "name": name, + "host": host, + "port": port, + "path": path, + "comment": comment, + }, + ) conn.commit() self.finish() @@ -189,14 +225,15 @@ class InfoHandler(APIHandler): - 返回值: ``` { - "name": "xxx", - "host": "xxx", - "port": 123, - "path": "xxx", - "comment": "xxx", + "name": "xxx", + "host": "xxx", + "port": 123, + "path": "xxx", + "comment": "xxx", } ``` """ + @authenticated def post(self): hid = self.get_int_argument("id") @@ -205,11 +242,18 @@ class InfoHandler(APIHandler): result = {} with self.app_mysql.connect() as conn: - cur = conn.execute(text("""select name, host, port, path, comment from model_hub where id=:id"""), {"id": hid}) + cur = conn.execute( + text( + """select name, host, port, path, comment from model_hub where id=:id""" + ), + {"id": hid}, + ) result = db_mysql.to_json(cur) if not result: - raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "model hub not found") - self.finish({"data": result}) + raise errors.HTTPAPIError( + errors.ERROR_BAD_REQUEST, "model hub not found" + ) + self.finish(result) class DeleteHandler(APIHandler): @@ -220,6 +264,7 @@ class DeleteHandler(APIHandler): > - id, int - 返回值:无 """ + @authenticated def post(self): hid = self.get_int_argument("id") diff --git a/website/handlers/enterprise_busi_model/handler.py b/website/handlers/enterprise_busi_model/handler.py index 72d2083..628f77f 100644 --- a/website/handlers/enterprise_busi_model/handler.py +++ b/website/handlers/enterprise_busi_model/handler.py @@ -35,14 +35,14 @@ class ListHandler(APIHandler): """ @authenticated def post(self): - pageNo = self.get_argument("pageNo", 1) - pageSize = self.get_argument("pageSize", 10) - entity_id = self.get_argument("entity_id") + pageNo = self.get_int_argument("pageNo", 1) + pageSize = self.get_int_argument("pageSize", 10) + entity_id = self.get_int_argument("entity_id") if not entity_id: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数错误") db_busimodel = DB_BusiModel.EnterpriseBusiModelRepository() - result = db_busimodel.list_enterprise_busi_model(entity_id, pageNo, pageSize) + result = db_busimodel.list_enterprise_busi_model(entity_id=entity_id, page_no=pageNo, page_size=pageSize) self.finish({"count": result["count"], "data": result["data"]}) @@ -63,7 +63,7 @@ class AddHandler(APIHandler): """ @authenticated def post(self): - entity_id = self.get_argument("entity_id") + entity_id = self.get_int_argument("entity_id") name = self.get_escaped_argument("name", "") comment = self.get_escaped_argument("comment", "") basemodel_ids = self.get_escaped_argument("basemodel_ids", "") @@ -131,7 +131,7 @@ class InfoHandler(APIHandler): """ @authenticated def post(self): - busimodel_id = self.get_argument("id") + busimodel_id = self.get_int_argument("id") if not busimodel_id: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数错误") diff --git a/website/handlers/enterprise_device/handler.py b/website/handlers/enterprise_device/handler.py index 0108bfd..d215924 100644 --- a/website/handlers/enterprise_device/handler.py +++ b/website/handlers/enterprise_device/handler.py @@ -2,17 +2,21 @@ import json import logging import random + from sqlalchemy import text +from website import consts from website import db_mysql, errors -from website.db.enterprise_device import enterprise_device as DB_Device +from website.db.alg_model.alg_model import ModelRepositry as DB_AlgModel from website.db.device_classification import ( device_classification as DB_DeviceClassification, ) -from website.db.enterprise_busi_model import enterprise_busi_model_node_device as DB_BusiModelNodeDevice +from website.db.enterprise_busi_model import ( + enterprise_busi_model as DB_BusiModel, +) +from website.db.enterprise_device import enterprise_device as DB_Device from website.handler import APIHandler, authenticated from website.util import shortuuid -from website import consts class DeviceClassificationAddHandler(APIHandler): @@ -228,8 +232,6 @@ class DeviceEditHandler(APIHandler): class DeviceDeleteHandler(APIHandler): """ - ### /enterprise/entity/nodes/device/delete - - 描述:企业节点,删除设备 - 请求方式:post - 请求参数: @@ -338,6 +340,115 @@ class DeviceInfoHandler(APIHandler): self.finish(device) +class DeviceBasemodelListHandler(APIHandler): + """ + - 描述:企业节点,节点信息 -> 设备列表 -> 基础模型配置 -> 模型列表 + - 请求方式:post + - 请求参数: + > - pageNo + > - pageSize + > - device_id, int, 设备id + - 返回值: + ``` + { + "count": 123, + "data": [ + { + "busi_model": "xxx", # 业务模型的名称 + "base_model": [ + { + "model_id": 123, # 基础模型id + "model_name": "xxx" # 基础模型name + "model_version": "xxx", # 基础模型的版本 + "model_hub_image": "xxx", # 运行库镜像 + }, + ... + ] + }, + ... + ] + } + ``` + """ + + @authenticated + def post(self): + device_id = self.get_int_argument("device_id") + if not device_id: + raise errors.HTTPAPIError( + errors.ERROR_BAD_REQUEST, "企业节点或设备不能为空" + ) + pageNo = self.get_int_argument("pageNo", 1) + pageSize = self.get_int_argument("pageSize", 10) + db_busi_model = DB_BusiModelNodeDevice.EnterpriseBusiModelNodeDeviceRepository() + busi_models, count = db_busi_model.get_busi_model_by_device(device_id=device_id, pagination=True, + page_no=pageNo, + page_size=pageSize) + + for item in busi_models: + busi_model_id = item["busi_model_id"] + busi_model_name = item["name"] + base_model_list = json.loads(item["base_models"]) + for base_model in base_model_list: + base_model_id = base_model["id"] + base_model_suid = base_model["suid"] + base_model_name = base_model["name"] + + self.finish() + + +class DeviceBaseModelCustomConfigHandler(APIHandler): + """ + - 描述:企业节点,节点信息 -> 设备列表 -> 基础模型配置 -> 基础模型参数配置 + - 请求方式:post + - 请求参数: + > - device_id, int, 设备id + > - node_id, int, 节点id + > - base_model_id, int, 基础模型id + > - busi_conf_file, string, 业务参数配置文件md5 + > - busi_conf_str, string, 业务参数配置,json字符串,eg:'{"key1": "value1", "key2": "value2"}' + > - model_hub_image, string, 运行库地址 + > - model_conf_file, string, 模型参数配置文件md5 + > - model_conf_str, string, 模型参数配置,json字符串,eg:'{"key1": "value1", "key2": "value2"}' + - 返回值:无 + """ + + @authenticated + def post(self): + device_id = self.get_int_argument("device_id") + node_id = self.get_int_argument("node_id") + busi_model_id = self.get_int_argument("busi_model_id") + base_model_id = self.get_int_argument("base_model_id") + busi_conf_file = self.get_escaped_argument("busi_conf_file") + busi_conf_str = self.get_escaped_argument("busi_conf_str") + model_hub_image = self.get_escaped_argument("model_hub_image") + model_conf_file = self.get_escaped_argument("model_conf_file") + model_conf_str = self.get_escaped_argument("model_conf_str") + + db_device = DB_Device.EnterpriseDeviceRepository() + device = db_device.get_device(device_id=device_id) + if not device: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "device not exist") + if device["node_id"] != node_id: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "device not belong to this node") + db_alg_model = DB_AlgModel() + base_model = db_alg_model.get_model_by_id(base_model_id) + if not base_model: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "base model not exist") + + device_suid = device["suid"] + node_suid = device["node_suid"] + entity_id = device["entity_id"] + entity_suid = device["entity_suid"] + + db_busi_model = DB_BusiModel.EnterpriseBusiModelRepository() + busi_model = db_busi_model.get_busi_model_by_id(busi_model_id) + busi_model_suid = busi_model.suid + + + self.finish() + + class StatusListHandler(APIHandler): """ - 描述:设备状态列表 @@ -379,26 +490,32 @@ class StatusListHandler(APIHandler): if not entity_id: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "企业节点不能为空") if status not in consts.device_status_map: - raise errors.HTTPAPIError( - errors.ERROR_BAD_REQUEST, "状态参数错误") + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "状态参数错误") db_device = DB_Device.EnterpriseDeviceRepository() - res = db_device.list_entity_devices(entity_id=entity_id, pageNo=pageNo, pageSize=pageSize, - classification=classification, status=status) + res = db_device.list_entity_devices( + entity_id=entity_id, + pageNo=pageNo, + pageSize=pageSize, + classification=classification, + status=status, + ) count = res["count"] devices = res["devices"] data = [] for item in devices: - data.append({ - "id": item.id, - "name": item.name, - "status": item.status, - "cpu": random.randint(20, 30), - "mem": random.randint(20, 30), - "storage": random.randint(20, 30), - "gpu": random.randint(20, 30), - }) + data.append( + { + "id": item.id, + "name": item.name, + "status": item.status, + "cpu": random.randint(20, 30), + "mem": random.randint(20, 30), + "storage": random.randint(20, 30), + "gpu": random.randint(20, 30), + } + ) self.finish({"count": count, "data": data}) @@ -418,11 +535,10 @@ class StatusInfoHandler(APIHandler): raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "设备不存在") res = res[0] db_busi_model = DB_BusiModelNodeDevice.EnterpriseBusiModelNodeDeviceRepository() - busi_models = db_busi_model.get_busi_model_by_device(device_id=device_id) + busi_models, _ = db_busi_model.get_busi_model_by_device(device_id=device_id) for busi_model in busi_models: base_model_ids = json.loads(busi_model["base_models"]) - self.finish() @@ -431,5 +547,4 @@ class StatusLogHandler(APIHandler): @authenticated def post(self): - self.finish() diff --git a/website/handlers/enterprise_device/url.py b/website/handlers/enterprise_device/url.py index 883feaa..f7a3d1f 100644 --- a/website/handlers/enterprise_device/url.py +++ b/website/handlers/enterprise_device/url.py @@ -15,7 +15,11 @@ handlers = [ ("/enterprise/entity/nodes/device/list", handler.DeviceListHandler), ("/enterprise/entity/nodes/device/list/simple", handler.DeviceListSimpleHandler), ("/enterprise/entity/nodes/device/info", handler.DeviceInfoHandler), - + ("/enterprise/entity/nodes/device/basemodel/list", handler.DeviceBasemodelListHandler), + ( + "/enterprise/entity/nodes/device/basemodel/custom/config", + handler.DeviceBaseModelCustomConfigHandler, + ), ("/enterprise/device/status/list", handler.StatusListHandler), ("/enterprise/device/status/info", handler.StatusInfoHandler), ("/enterprise/device/status/log", handler.StatusLogHandler),