# -*- coding: utf-8 -*-
import logging
import os

from sqlalchemy import text

from website import consts
from website import db_mysql
from website import errors
from website import settings
from website.handler import APIHandler, authenticated
from website.util import md5

# Maximum accepted length of a classification name (checked on add and edit).
_MAX_CLASSIFICATION_NAME_LEN = 128


def _page_window(page_no, page_size, default_size):
    """Return ``(limit, offset)`` for 1-based pagination.

    Missing or non-positive values fall back to page 1 / ``default_size`` so a
    malformed request can never produce a negative SQL offset or an empty limit.
    """
    if not page_no or page_no < 1:
        page_no = 1
    if not page_size or page_size < 1:
        page_size = default_size
    return page_size, (page_no - 1) * page_size


def _ensure_classification_exists(conn, classification_id):
    """Raise a bad-request error unless ``classification_id`` is a row of model_classification."""
    cur = conn.execute(text("select id from model_classification where id=:id"), {"id": classification_id})
    if not cur.fetchone():
        raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")


def _file_info(conn, md5_str):
    """Return ``(filename, filesize)`` of an uploaded file, or ``("", 0)`` if it is unknown."""
    if not md5_str:
        return "", 0
    cur = conn.execute(text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": md5_str})
    info = db_mysql.to_json(cur)
    if not info:
        # The files row can be missing (e.g. cleaned up separately); report it as unknown instead of crashing.
        return "", 0
    return info["filename"], info["filesize"]


def _md5_used_by_other_version(conn, md5_str, version_id):
    """Return True if a model version other than ``version_id`` references ``md5_str``."""
    cur = conn.execute(
        text(
            "select count(*) from model_version "
            "where id!=:id and (model_file=:md5_str or config_file=:md5_str)"
        ),
        {"id": version_id, "md5_str": md5_str},
    )
    return (cur.scalar() or 0) > 0


class ClassificationAddHandler(APIHandler):
    """Add a model classification.

    Request arguments:
        name: classification name, required, at most 128 characters, must be unique.
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        if not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if len(name) > _MAX_CLASSIFICATION_NAME_LEN:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        with self.app_mysql.connect() as conn:
            cur = conn.execute(text("select id from model_classification where name=:name"), {"name": name})
            if db_mysql.to_json(cur):
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("insert into model_classification (name, create_time) values (:name, NOW())"),
                {"name": name},
            )
            # Like every other write handler in this module, commit explicitly;
            # otherwise the insert is rolled back when the connection closes.
            conn.commit()
        self.finish()


class ClassificationEditHandler(APIHandler):
    """Rename a model classification.

    Request arguments:
        id: classification id, required.
        name: new name, required, at most 128 characters, unique among other classifications.
    """

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not classification_id or not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # Same constraints as ClassificationAddHandler so edits cannot bypass them.
        if len(name) > _MAX_CLASSIFICATION_NAME_LEN:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select id from model_classification where name=:name and id!=:id"),
                {"name": name, "id": classification_id},
            )
            if db_mysql.to_json(cur):
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("update model_classification set name=:name where id=:id"),
                {"name": name, "id": classification_id},
            )
            conn.commit()
        self.finish()


class ClassificationListHandler(APIHandler):
    """List all model classifications as ``{"data": [{"id", "name"}, ...]}``."""

    @authenticated
    def post(self):
        with self.app_mysql.connect() as conn:
            cur = conn.execute(text("select id, name from model_classification"))
            result = db_mysql.to_json_list(cur)
        self.finish({"data": result})


class ClassificationDeleteHandler(APIHandler):
    """Delete a model classification.

    Request arguments:
        id: classification id, required.
    """

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        if not classification_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # NOTE(review): models that still reference this classification are
            # neither checked nor updated; they keep a dangling classification id.
            conn.execute(text("DELETE FROM model_classification WHERE id=:id"), {"id": classification_id})
            conn.commit()
        self.finish()


class ListHandler(APIHandler):
    """Paginated list of (non-deleted) models.

    Request arguments:
        pageNo: 1-based page number (default 1).
        pageSize: page size (default consts.PAGE_SIZE).
        name: optional substring filter on the model name.

    Response: ``{"data": [...], "count": <total matching models>}``.
    """

    @authenticated
    def post(self):
        page_size, offset = _page_window(
            self.get_int_argument("pageNo", 1),
            self.get_int_argument("pageSize", consts.PAGE_SIZE),
            consts.PAGE_SIZE,
        )
        name = self.get_escaped_argument("name", "")

        # Both the page query and the count query alias `model` as `m`, so the
        # same filter clause (and parameters) apply to each.
        where = "where m.del=0 "
        params = {}
        if name:
            where += "and m.name like :name "
            params["name"] = "%{}%".format(name)

        sql = (
            "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time "
            "from model m left join model_classification mc on m.classification=mc.id "
            + where
            + "order by m.id desc limit :pageSize offset :offset"
        )
        sql_count = "select count(m.id) from model m " + where

        with self.app_mysql.connect() as conn:
            cur = conn.execute(text(sql), dict(params, pageSize=page_size, offset=offset))
            result = db_mysql.to_json_list(cur)
            count = conn.execute(text(sql_count), params).scalar()

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                # Unknown type codes yield None instead of failing the whole page.
                "model_type": consts.model_type_map.get(item["model_type"]),
                "classification_name": item["classification_name"],
                "default_version": item["default_version"],
                "update_time": str(item["update_time"]),
            }
            for item in result
        ]
        self.finish({"data": data, "count": count})


class AddHandler(APIHandler):
    """Add a model.

    Request arguments:
        name: model name, required.
        model_type: 1001 = classic algorithm, 1002 = deep learning (default consts.model_type_machine).
        classification: classification id, required; must exist.
        comment: optional comment.
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        if not name or not classification or model_type not in consts.model_type_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Bind parameters go to conn.execute(), not to text().
            _ensure_classification_exists(conn, classification)
            conn.execute(
                text(
                    "insert into model (name, model_type, classification, comment, create_time) "
                    "values (:name, :model_type, :classification, :comment, NOW())"
                ),
                {"name": name, "model_type": model_type, "classification": classification, "comment": comment},
            )
            conn.commit()
        self.finish()


class EditHandler(APIHandler):
    """Edit a model.

    Request arguments:
        id: model id, required.
        name: model name, required.
        model_type: 1001 = classic algorithm, 1002 = deep learning (default consts.model_type_machine).
        classification: classification id, required; must exist.
        comment: optional comment.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        if not mid or not name or not classification or model_type not in consts.model_type_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Same validation as AddHandler so an edit cannot point at a missing classification.
            _ensure_classification_exists(conn, classification)
            conn.execute(
                text(
                    "update model set name=:name, model_type=:model_type, classification=:classification, "
                    "comment=:comment, update_time=NOW() where id=:id"
                ),
                {
                    "name": name,
                    "model_type": model_type,
                    "classification": classification,
                    "comment": comment,
                    "id": mid,
                },
            )
            conn.commit()
        self.finish()


class InfoHandler(APIHandler):
    """Return basic information about one (non-deleted) model.

    Request arguments:
        id: model id, required.

    Response: ``{"name", "classification_name", "comment", "update_time"}``.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Left join (as in ListHandler) so a model whose classification was
            # deleted is still found; soft-deleted models are treated as missing.
            cur = conn.execute(
                text(
                    "select m.name, mc.name as classification_name, m.comment, m.update_time "
                    "from model m left join model_classification mc on m.classification=mc.id "
                    "where m.id=:id and m.del=0"
                ),
                {"id": mid},
            )
            result = db_mysql.to_json(cur)
        if not result:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")

        self.finish({
            "name": result["name"],
            "classification_name": result["classification_name"],
            "comment": result["comment"],
            "update_time": str(result["update_time"]),
        })


class DeleteHandler(APIHandler):
    """Soft-delete a model by setting ``model.del = 1``.

    Request arguments:
        id: model id, required.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            conn.execute(text("update model set del=1 where id=:id"), {"id": mid})
            conn.commit()
        self.finish()


class VersionAddHandler(APIHandler):
    """
    Add a model version

    - Description: add a model version
    - Method: post
    - Request arguments:
    > - model_id, int, model id
    > - version, string, version
    > - comment, string, comment
    > - model_file, string, md5 of the model file; only one file may be uploaded
    > - config_file, string, md5 of the model config file
    > - config_str, string, model config parameters as a JSON string, e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("model_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not mid or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            # Refuse to create versions for missing or soft-deleted models.
            cur = conn.execute(text("select id from model where id=:id and del=0"), {"id": mid})
            if not db_mysql.to_json(cur):
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
            conn.execute(
                text(
                    "insert into model_version "
                    "(model_id, version, comment, model_file, config_file, config_str, create_time, update_time) "
                    "values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"
                ),
                {
                    "model_id": mid,
                    "version": version,
                    "comment": comment,
                    "model_file": model_file,
                    "config_file": config_file,
                    "config_str": config_str,
                },
            )
            conn.commit()
        self.finish()


class VersionEditHandler(APIHandler):
    """
    Edit a model version

    - Description: edit a model version
    - Method: post
    - Request arguments:
    > - version_id, int, model version id
    > - version, string, version
    > - comment, string, comment
    > - model_file, string, md5 of the model file; only one file may be uploaded
    > - config_file, string, md5 of the model config file
    > - config_str, string, model config parameters as a JSON string, e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not version_id or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update model_version "
                    "set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
                    "config_str=:config_str, update_time=NOW() where id=:id"
                ),
                {
                    "version": version,
                    "comment": comment,
                    "model_file": model_file,
                    "config_file": config_file,
                    "config_str": config_str,
                    "id": version_id,
                },
            )
            # Without an explicit commit the update is rolled back on close.
            conn.commit()
        self.finish()


class VersionListHandler(APIHandler):
    """
    Model version list

    - Description: paginated list of a model's versions
    - Method: post
    - Request arguments:
    > - model_id, int, model id
    > - pageNo, int
    > - pageSize, int
    - Returns:
    ```
        {
            "count": 123,
            "data": [
                {
                    "version_id": 213,       # version id
                    "version": "xx",         # version number
                    "path": "xxx",           # file path
                    "size": 123,
                    "update_time": "xxx",
                    "is_default": 1          # whether this is the default version, 1 = default, 0 = not default
                },
                ...
            ]
        }
    ```
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("model_id")
        limit, offset = _page_window(
            self.get_int_argument("pageNo", 1),
            self.get_int_argument("pageSize", 20),
            20,
        )
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Version rows joined with their model file's metadata. The files
            # table is keyed by `md5_str` elsewhere in this module.
            cur = conn.execute(
                text(
                    "select mv.id as version_id, mv.version, mv.model_file, mv.update_time, mv.is_default, "
                    "f.filepath, f.filesize "
                    "from model_version mv left join files f on mv.model_file=f.md5_str "
                    "where mv.model_id=:mid and mv.del=0 "
                    "order by mv.id desc limit :offset, :limit"
                ),
                {"mid": model_id, "offset": offset, "limit": limit},
            )
            result = db_mysql.to_json_list(cur)

            # The total must be restricted to the same model as the page query.
            count = conn.execute(
                text("select count(*) from model_version mv where mv.model_id=:mid and mv.del=0"),
                {"mid": model_id},
            ).scalar()

        data = [
            {
                "version_id": item["version_id"],
                "version": item["version"],
                "path": item["filepath"],
                "size": item["filesize"],
                "update_time": str(item["update_time"]),
                "is_default": item["is_default"],
            }
            for item in result
        ]
        self.finish({"count": count, "data": data})


class VersionInfoHandler(APIHandler):
    """
    Model version info

    - Description: details of one model version
    - Method: post
    - Request arguments:
    > - version_id, int, model version id
    - Returns:
    ```
        {
            "model_name": "xxx",
            "version": "xxx",            # version
            "comment": "xxx",            # comment
            "model_file_name": "xxx",    # model file name
            "model_file_size": 123,      # model file size
            "model_file_md5": "xx",      # model file md5
            "config_file_name": "xx",    # config file name
            "config_file_size": "xx",    # config file size
            "config_file_md5": "xxx",    # config file md5
            "config_str": "xxx"          # config parameters
        }
    ```
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    "select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, "
                    "mv.config_str from model_version mv, model m where mv.id=:id and mv.model_id=m.id"
                ),
                {"id": version_id},
            )
            result = db_mysql.to_json(cur)
            if not result:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            model_file = result["model_file"]
            config_file = result["config_file"]
            # Missing files rows degrade to empty name / zero size.
            model_file_name, model_file_size = _file_info(conn, model_file)
            config_file_name, config_file_size = _file_info(conn, config_file)

        self.finish({
            "model_name": result["model_name"],
            "version": result["version"],
            "comment": result["comment"],
            "model_file_name": model_file_name,
            "model_file_size": model_file_size,
            "model_file_md5": model_file,
            "config_file_name": config_file_name,
            "config_file_size": config_file_size,
            "config_file_md5": config_file,
            "config_str": result["config_str"],
        })


class VersionSetDefaultHandler(APIHandler):
    """
    Set a model version as the default

    - Description: set the model's default version
    - Method: post
    - Request arguments:
    > - version_id, int, model version id
    > - model_id, int, model id
    - Returns: nothing
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        model_id = self.get_int_argument("model_id")
        if not version_id or not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Verify ownership first. Otherwise a mismatched pair would clear every
            # default of `model_id` without setting a new one.
            cur = conn.execute(
                text("select id from model_version where id=:id and model_id=:model_id"),
                {"id": version_id, "model_id": model_id},
            )
            if not db_mysql.to_json(cur):
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            conn.execute(
                text("update model_version set is_default=0 where model_id=:model_id"),
                {"model_id": model_id},
            )
            conn.execute(
                text("update model_version set is_default=1 where id=:id"),
                {"id": version_id},
            )
            # NOTE(review): model.default_version (shown by ListHandler) is not
            # updated here; confirm where it is meant to be maintained.
            conn.commit()
        self.finish()


class VersionDeleteHandler(APIHandler):
    """
    Delete a model version

    - Description: delete a model version together with its stored files
    - Method: post
    - Request arguments:
    > - version_id, int, model version id
    - Returns: nothing
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        removable = []  # md5s whose on-disk copies should be removed after the commit
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select model_id, model_file, config_file, is_default from model_version where id=:id"),
                {"id": version_id},
            )
            row = db_mysql.to_json(cur)
            if not row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            # Only deleting the default version should leave the model without one.
            # Assumes is_default and model.default_version are kept in sync.
            if row["is_default"]:
                conn.execute(
                    text("update model set default_version='' where id=:id"),
                    {"id": row["model_id"]},
                )

            # Each file is handled independently; empty md5s and duplicates are
            # skipped, as are files another version still references.
            for md5_str in dict.fromkeys((row["model_file"], row["config_file"])):
                if not md5_str or _md5_used_by_other_version(conn, md5_str, version_id):
                    continue
                conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": md5_str})
                removable.append(md5_str)

            conn.execute(text("delete from model_version where id=:id"), {"id": version_id})
            conn.commit()

        # Remove files from disk only after the rows are committed, so a failed
        # transaction never leaves rows pointing at already-deleted files.
        # Disk cleanup is best effort.
        for md5_str in removable:
            path = os.path.join(settings.file_upload_dir, "model", md5_str)
            try:
                os.remove(path)
            except OSError as e:
                logging.warning("failed to remove model file %s: %s", path, e)
        self.finish()