# -*- coding: utf-8 -*-
import logging

from sqlalchemy import text

from website import errors
from website import settings
from website import consts
from website import db
from website.util import md5
from website.handler import APIHandler, authenticated

# Maximum classification name length. Enforced by both the add and the edit
# endpoint so that the two accept the same inputs.
_CLASSIFICATION_NAME_MAX_LEN = 128


def _format_time(value):
    """Render a DB timestamp column for a JSON response.

    NULL (``None``) becomes an empty string instead of the literal "None".
    """
    return "" if value is None else str(value)


class ClassificationAddHandler(APIHandler):
    """Add a model classification.

    Request arguments: ``name`` (required, at most 128 characters, unique).
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        if not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if len(name) > _CLASSIFICATION_NAME_MAX_LEN:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        with self.app_mysql.connect() as conn:
            # fetchone() returns None when no row matches, so test the row
            # itself instead of indexing into it.
            cur = conn.execute(
                text("select id from model_classification where name=:name"),
                {"name": name},
            )
            if cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("insert into model_classification (name, created_at) "
                     "values (:name, NOW())"),
                {"name": name},
            )
            # Commit explicitly, as the other write handlers do; an
            # uncommitted transaction is rolled back when the connection
            # is closed.
            conn.commit()
        self.finish()


class ClassificationEditHandler(APIHandler):
    """Rename a model classification.

    Request arguments: ``id`` (required) and ``name`` (required, at most 128
    characters, unique among the other classifications).
    """

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not classification_id or not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if len(name) > _CLASSIFICATION_NAME_MAX_LEN:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        with self.app_mysql.connect() as conn:
            # Same uniqueness rule as ClassificationAddHandler, ignoring the
            # row that is being renamed.
            cur = conn.execute(
                text("select id from model_classification "
                     "where name=:name and id<>:id"),
                {"name": name, "id": classification_id},
            )
            if cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("update model_classification set name=:name where id=:id"),
                {"name": name, "id": classification_id},
            )
            conn.commit()
        self.finish()


class ClassificationListHandler(APIHandler):
    """List all model classifications as ``{"data": [{id, name}, ...]}``."""

    @authenticated
    def post(self):
        with self.app_mysql.connect() as conn:
            cur = conn.execute(text("select id, name from model_classification"))
            result = db.to_json_list(cur)
        self.finish({"data": result})


class ClassificationDeleteHandler(APIHandler):
    """Delete a model classification by ``id`` (hard delete)."""

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        if not classification_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            conn.execute(
                text("DELETE FROM model_classification WHERE id=:id"),
                {"id": classification_id},
            )
            conn.commit()
        self.finish()


class ListHandler(APIHandler):
    """Paginated list of models that are not soft-deleted (``del=0``).

    Request arguments: ``pageNo`` (1-based, default 1), ``pageSize``
    (default 10) and ``name`` (optional substring filter).
    Responds with ``{"data": [...], "count": <total matching models>}``.
    """

    @authenticated
    def post(self):
        page_no = self.get_int_argument("pageNo", 1)
        page_size = self.get_int_argument("pageSize", 10)
        name = self.get_escaped_argument("name", "")
        # Values below these bounds would make MySQL reject the query with a
        # negative OFFSET/LIMIT.
        page_no = max(page_no, 1)
        page_size = max(page_size, 0)

        # Both queries alias `model` as `m` so they can share one WHERE clause.
        where_sql = "where m.del=0"
        params = {}
        if name:
            where_sql += " and m.name like :name"
            params["name"] = "%{}%".format(name)

        list_sql = (
            "select m.id, m.name, m.model_type, m.default_version, "
            "mc.name classification_name, m.update_at "
            "from model m left join model_classification mc "
            "on m.classification=mc.id "
            + where_sql
            + " order by m.id desc limit :pageSize offset :offset"
        )
        list_params = dict(params, pageSize=page_size,
                           offset=(page_no - 1) * page_size)
        count_sql = "select count(m.id) from model m " + where_sql

        with self.app_mysql.connect() as conn:
            cur = conn.execute(text(list_sql), list_params)
            rows = db.to_json_list(cur)
            count = conn.execute(text(count_sql), params).fetchone()[0]

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                # Fall back to the raw code for a type missing from the map
                # rather than failing the whole page with KeyError.
                # NOTE(review): assumes model_type_map is a dict - confirm.
                "model_type": consts.model_type_map.get(
                    item["model_type"], item["model_type"]),
                "classification_name": item["classification_name"],
                "default_version": item["default_version"],
                "update_time": _format_time(item["update_at"]),
            }
            for item in rows
        ]
        self.finish({"data": data, "count": count})


class AddHandler(APIHandler):
    """Create a model.

    Request arguments: ``name`` (required), ``model_type`` (must be in
    ``consts.MODEL_TYPE``; 1001 = classic algorithm, 1002 = deep learning),
    ``classification`` (required, must exist) and ``comment`` (optional).
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        # 1001: classic algorithm, 1002: deep learning
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        if not name or not classification or model_type not in consts.MODEL_TYPE:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Bind values belong in execute(); text() takes only the SQL.
            cur = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")
            conn.execute(
                text("insert into model "
                     "(name, model_type, classification, comment, created_at) "
                     "values (:name, :model_type, :classification, :comment, NOW())"),
                {"name": name, "model_type": model_type,
                 "classification": classification, "comment": comment},
            )
            conn.commit()
        self.finish()


class EditHandler(APIHandler):
    """Edit a model's name, type, classification and comment.

    Request arguments: ``id`` (required) plus the same fields and rules as
    AddHandler.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        # 1001: classic algorithm, 1002: deep learning
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        if (not mid or not name or not classification
                or model_type not in consts.MODEL_TYPE):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Same classification existence check as AddHandler.
            cur = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")
            conn.execute(
                text("update model set name=:name, model_type=:model_type, "
                     "classification=:classification, comment=:comment, "
                     "update_at=NOW() where id=:id"),
                {"name": name, "model_type": model_type,
                 "classification": classification, "comment": comment,
                 "id": mid},
            )
            conn.commit()
        self.finish()


class InfoHandler(APIHandler):
    """Return the details of one model that is not soft-deleted.

    Request arguments: ``id`` (required).
    Responds with ``{name, classification_name, comment, update_time}``.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # The left join (as in ListHandler) still returns a model whose
            # classification was deleted. Soft-deleted models are reported as
            # missing.
            cur = conn.execute(
                text("select m.name, mc.name as classification_name, "
                     "m.comment, m.update_at "
                     "from model m left join model_classification mc "
                     "on m.classification=mc.id "
                     "where m.id=:id and m.del=0"),
                {"id": mid},
            )
            result = db.to_json(cur)
        if not result:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")

        self.finish({
            "name": result["name"],
            "classification_name": result["classification_name"],
            "comment": result["comment"],
            "update_time": _format_time(result["update_at"]),
        })


class DeleteHandler(APIHandler):
    """Soft-delete a model by ``id`` (sets ``del=1``)."""

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            conn.execute(text("update model set del=1 where id=:id"), {"id": mid})
            conn.commit()
        self.finish()


class VersionAddHandler(APIHandler):
    """Add a model version.

    - Description: add a version to an existing, non-deleted model
    - Method: POST
    - Parameters:
      > - model_id, int, model id
      > - version, string, version label
      > - comment, string, remark
      > - model_file, string, md5 of the model file; only one file may be uploaded
      > - config_file, string, md5 of the model configuration file
      > - config_str, string, model configuration as a JSON string,
        e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("model_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not mid or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            # Refuse to attach versions to a missing or soft-deleted model.
            cur = conn.execute(
                text("select id from model where id=:id and del=0"), {"id": mid})
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
            conn.execute(
                text("insert into model_version "
                     "(model_id, version, comment, model_file, config_file, "
                     "config_str, created_at, update_at) "
                     "values (:model_id, :version, :comment, :model_file, "
                     ":config_file, :config_str, NOW(), NOW())"),
                {"model_id": mid, "version": version, "comment": comment,
                 "model_file": model_file, "config_file": config_file,
                 "config_str": config_str},
            )
            conn.commit()
        self.finish()


class VersionEditHandler(APIHandler):
    """Edit a model version.

    - Description: edit a model version
    - Method: POST
    - Parameters:
      > - version_id, int, model version id
      > - version, string, version label
      > - comment, string, remark
      > - model_file, string, md5 of the model file; only one file may be uploaded
      > - config_file, string, md5 of the model configuration file
      > - config_str, string, model configuration as a JSON string,
        e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not version_id or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            conn.execute(
                text("update model_version set version=:version, comment=:comment, "
                     "model_file=:model_file, config_file=:config_file, "
                     "config_str=:config_str, update_at=NOW() where id=:id"),
                {"version": version, "comment": comment, "model_file": model_file,
                 "config_file": config_file, "config_str": config_str,
                 "id": version_id},
            )
            # Without a commit the update is discarded when the connection
            # closes.
            conn.commit()
        self.finish()


class VersionListHandler(APIHandler):
    """List model versions."""

    @authenticated
    def post(self):
        # TODO: not implemented yet; currently responds with an empty body.
        self.finish()


class VersionInfoHandler(APIHandler):
    """Return the details of a model version."""

    @authenticated
    def post(self):
        # TODO: not implemented yet; currently responds with an empty body.
        self.finish()


class VersionSetDefaultHandler(APIHandler):
    """Mark a model version as the model's default version.

    - Description: set the default version of a model
    - Method: POST
    - Parameters:
      > - version_id, int, model version id
      > - model_id, int, model id
    - Returns: nothing
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        model_id = self.get_int_argument("model_id")
        if not version_id or not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # The version must belong to this model. Otherwise the reset below
            # would clear this model's default and flag another model's
            # version instead.
            cur = conn.execute(
                text("select id from model_version "
                     "where id=:id and model_id=:model_id"),
                {"id": version_id, "model_id": model_id},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
            conn.execute(
                text("update model_version set is_default=0 where model_id=:model_id"),
                {"model_id": model_id},
            )
            conn.execute(
                text("update model_version set is_default=1 where id=:id"),
                {"id": version_id},
            )
            # Both updates are committed together.
            conn.commit()
        self.finish()


class VersionDeleteHandler(APIHandler):
    """Delete a model version."""

    @authenticated
    def post(self):
        # TODO: not implemented yet; currently responds with an empty body.
        self.finish()