You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

395 lines
13 KiB
Python

# -*- coding: utf-8 -*-
import logging
from sqlalchemy import text
from website import errors
from website import settings
from website import consts
from website import db
from website.util import md5
from website.handler import APIHandler, authenticated
class ClassificationAddHandler(APIHandler):
    """
    Add a model classification.

    Request arguments:
    - name, string, classification name (required, at most 128 characters)

    Raises HTTPAPIError(ERROR_BAD_REQUEST) when the name is missing, too long,
    or already used by another classification.
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        if not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select id from model_classification where name=:name"),
                {"name": name},
            )
            # fetchone() returns None when no row matches; indexing it directly
            # would raise TypeError for every brand-new name.
            if cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("""insert into model_classification (name, created_at) values (:name, NOW())"""),
                {"name": name},
            )
            # Commit explicitly like the other write handlers: a 2.0-style
            # SQLAlchemy connection rolls back uncommitted work when it closes.
            conn.commit()
        self.finish()
class ClassificationEditHandler(APIHandler):
    """
    Edit (rename) a model classification.

    Request arguments:
    - id, int, classification id
    - name, string, new classification name (at most 128 characters)

    Raises HTTPAPIError(ERROR_BAD_REQUEST) when an argument is missing, the name
    is too long, or another classification already uses the name.
    """

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not classification_id or not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # Same length limit that ClassificationAddHandler enforces on creation.
        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")
        with self.app_mysql.connect() as conn:
            # Reject a rename that collides with a *different* classification,
            # mirroring the uniqueness check performed when adding one.
            cur = conn.execute(
                text("select id from model_classification where name=:name and id != :id"),
                {"name": name, "id": classification_id},
            )
            if cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("""update model_classification set name=:name where id=:id"""),
                {"name": name, "id": classification_id},
            )
            conn.commit()
        self.finish()
class ClassificationListHandler(APIHandler):
    """
    List every model classification.

    Responds with ``{"data": rows}``, where ``rows`` is the ``id``/``name``
    selection from ``model_classification`` converted by ``db.to_json_list``.
    """

    @authenticated
    def post(self):
        query = text("""select id, name from model_classification""")
        with self.app_mysql.connect() as conn:
            rows = db.to_json_list(conn.execute(query))
        self.finish(dict(data=rows))
class ClassificationDeleteHandler(APIHandler):
    """
    Delete a model classification by id.

    Request arguments:
    - id, int, classification id

    The row is removed from ``model_classification`` with a hard DELETE.
    NOTE(review): rows in ``model`` whose ``classification`` column points at
    this id are not touched here — confirm whether that is intended.
    """

    @authenticated
    def post(self):
        target_id = self.get_int_argument("id")
        if not target_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        params = {"id": target_id}
        with self.app_mysql.connect() as conn:
            conn.execute(text("""DELETE FROM model_classification WHERE id=:id"""), params)
            conn.commit()
        self.finish()
class ListHandler(APIHandler):
    """
    Paginated list of models that are not soft-deleted (``del=0``).

    Request arguments:
    - pageNo, int, 1-based page number (default 1)
    - pageSize, int, rows per page (default 10)
    - name, string, optional substring filter on the model name

    Responds with ``{"data": [...], "count": total}`` where ``count`` is the
    number of matching models across all pages.
    """

    @authenticated
    def post(self):
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", 10)
        name = self.get_escaped_argument("name", "")
        # Clamp paging values so a bad request cannot yield a negative
        # LIMIT/OFFSET, which the database rejects.
        pageNo = max(pageNo or 1, 1)
        pageSize = max(pageSize or 10, 1)

        # One WHERE clause shared by the page query and the count query; both
        # alias ``model`` as ``m`` so the ``m.``-qualified columns resolve.
        where = "where m.del=0 "
        params = {}
        if name:
            where += "and m.name like :name "
            params["name"] = "%{}%".format(name)

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    "select m.id, m.name, m.model_type, m.default_version, "
                    "mc.name classification_name, m.update_at "
                    "from model m left join model_classification mc on m.classification=mc.id "
                    + where
                    + "order by m.id desc limit :pageSize offset :offset"
                ),
                dict(params, pageSize=pageSize, offset=(pageNo - 1) * pageSize),
            )
            result = db.to_json_list(cur)
            count = conn.execute(
                text("select count(m.id) from model m " + where), params
            ).fetchone()[0]

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                "model_type": consts.model_type_map[item["model_type"]],
                "classification_name": item["classification_name"],
                "default_version": item["default_version"],
                "update_time": str(item["update_at"]),
            }
            for item in result
        ]
        self.finish({"data": data, "count": count})
class AddHandler(APIHandler):
    """
    Create a model.

    Request arguments:
    - name, string, model name
    - model_type, int, defaults to consts.model_type_machine and must be in
      consts.MODEL_TYPE (1001 = classical algorithm, 1002 = deep learning)
    - classification, int, id of an existing model classification
    - comment, string, optional remark

    Raises HTTPAPIError(ERROR_BAD_REQUEST) on missing/invalid arguments or when
    the classification does not exist.
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)  # 1001: classical algorithm / 1002: deep learning
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        if not name or not classification or model_type not in consts.MODEL_TYPE:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # Bind values belong to execute(); text() only takes the SQL string,
            # so passing the dict to text() left :id unbound.
            cur = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")
            conn.execute(
                text("""insert into model (name, model_type, classification, comment, created_at) values (:name, :model_type, :classification, :comment, NOW())"""),
                {"name": name, "model_type": model_type, "classification": classification, "comment": comment},
            )
            conn.commit()
        self.finish()
class EditHandler(APIHandler):
    """
    Edit a model.

    Request arguments:
    - id, int, model id (required)
    - name, string, model name
    - model_type, int, defaults to consts.model_type_machine and must be in
      consts.MODEL_TYPE (1001 = classical algorithm, 1002 = deep learning)
    - classification, int, id of an existing model classification
    - comment, string, optional remark

    Raises HTTPAPIError(ERROR_BAD_REQUEST) on missing/invalid arguments or when
    the classification does not exist.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)  # 1001: classical algorithm / 1002: deep learning
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        # Without an id the UPDATE would match nothing and silently succeed.
        if not mid or not name or not classification or model_type not in consts.MODEL_TYPE:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # Same classification existence check as AddHandler performs.
            cur = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")
            conn.execute(
                text("""update model set name=:name, model_type=:model_type, classification=:classification, comment=:comment, update_at=NOW() where id=:id"""),
                {"name": name, "model_type": model_type, "classification": classification, "comment": comment, "id": mid},
            )
            conn.commit()
        self.finish()
class InfoHandler(APIHandler):
    """
    Return details of a single model.

    Request arguments:
    - id, int, model id

    Responds with name, classification_name, comment and update_time.
    Raises HTTPAPIError(ERROR_BAD_REQUEST) when the id is missing or no
    non-deleted model has that id.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # Join on the ``mc`` alias declared in FROM. LEFT JOIN and ``del=0``
            # match ListHandler: models whose classification was removed are
            # still found, soft-deleted models are not.
            cur = conn.execute(
                text(
                    "select m.name, mc.name as classification_name, m.comment, m.update_at "
                    "from model m left join model_classification mc on m.classification=mc.id "
                    "where m.id=:id and m.del=0"
                ),
                {"id": mid},
            )
            result = db.to_json(cur)
        if not result:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
        data = {
            "name": result["name"],
            "classification_name": result["classification_name"],
            "comment": result["comment"],
            "update_time": str(result["update_at"]),
        }
        self.finish(data)
class DeleteHandler(APIHandler):
    """
    Soft-delete a model.

    Request arguments:
    - id, int, model id

    Sets ``model.del`` to 1 instead of removing the row.
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("id")
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        statement = text("update model set del=1 where id=:id")
        with self.app_mysql.connect() as conn:
            conn.execute(statement, {"id": model_id})
            conn.commit()
        self.finish()
class VersionAddHandler(APIHandler):
    """
    Add a model version.

    - Description: add a new version to a model
    - Method: POST
    - Request arguments
      > - model_id, int, model id
      > - version, string, version label
      > - comment, string, remark
      > - model_file, string, md5 of the model file (only one file may be uploaded)
      > - config_file, string, md5 of the model config file (optional)
      > - config_str, string, model config parameters as a JSON string, e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing (``self.finish()`` is called without a payload)
    """

    @authenticated
    def post(self):
        # Argument names double as the SQL bind-parameter names below.
        fields = {
            "model_id": self.get_int_argument("model_id"),
            "version": self.get_escaped_argument("version", ""),
            "comment": self.get_escaped_argument("comment", ""),
            "model_file": self.get_escaped_argument("model_file", ""),
            "config_file": self.get_escaped_argument("config_file", ""),
            "config_str": self.get_escaped_argument("config_str", ""),
        }
        if not (fields["model_id"] and fields["version"] and fields["model_file"]):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(fields["model_file"]):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if fields["config_file"] and not md5.md5_validate(fields["config_file"]):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")
        insert_sql = (
            "\n"
            "insert into model_version (model_id, version, comment,model_file, config_file, config_str, created_at, update_at)\n"
            "values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"
        )
        with self.app_mysql.connect() as conn:
            conn.execute(text(insert_sql), fields)
            conn.commit()
        self.finish()
class VersionEditHandler(APIHandler):
    """
    Edit a model version.

    - Description: edit an existing model version
    - Method: POST
    - Request arguments
      > - version_id, int, model version id
      > - version, string, version label
      > - comment, string, remark
      > - model_file, string, md5 of the model file (only one file may be uploaded)
      > - config_file, string, md5 of the model config file (optional)
      > - config_str, string, model config parameters as a JSON string, e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing (``self.finish()`` is called without a payload)
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not version_id or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")
        with self.app_mysql.connect() as conn:
            conn.execute(
                text("update model_version set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, config_str=:config_str, update_at=NOW() where id=:id"),
                {"version": version, "comment": comment, "model_file": model_file, "config_file": config_file, "config_str": config_str, "id": version_id},
            )
            # Commit like the other write handlers; otherwise the UPDATE is
            # rolled back when the connection closes and the edit is lost.
            conn.commit()
        self.finish()
class VersionListHandler(APIHandler):
    """
    Model version list.

    NOTE(review): not implemented yet — the handler only finishes the request
    with an empty response and performs no query.
    """
    @authenticated
    def post(self):
        # TODO: implement listing of model versions.
        self.finish()
class VersionInfoHandler(APIHandler):
    """
    Model version details.

    NOTE(review): not implemented yet — the handler only finishes the request
    with an empty response and performs no query.
    """
    @authenticated
    def post(self):
        # TODO: implement lookup of a single model version.
        self.finish()
class VersionSetDefaultHandler(APIHandler):
    """
    Mark a model version as the model's default version.

    - Description: set the default version of a model
    - Method: POST
    - Request arguments
      > - version_id, int, model version id
      > - model_id, int, model id
    - Returns: nothing (``self.finish()`` is called without a payload)

    Raises HTTPAPIError(ERROR_BAD_REQUEST) when an argument is missing or the
    version does not belong to the given model.

    NOTE(review): only ``model_version.is_default`` is updated; ListHandler
    reads ``model.default_version``, which is not touched here — confirm
    whether it should be kept in sync.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        model_id = self.get_int_argument("model_id")
        if not version_id or not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # Verify the pair before changing anything: a version id from another
            # model would otherwise clear this model's default and flag a foreign
            # version as default.
            cur = conn.execute(
                text("select id from model_version where id=:id and model_id=:model_id"),
                {"id": version_id, "model_id": model_id},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
            # Clear the previous default, then set the new one; both updates are
            # committed together.
            conn.execute(
                text("update model_version set is_default=0 where model_id=:model_id"),
                {"model_id": model_id},
            )
            conn.execute(
                text("update model_version set is_default=1 where id=:id"),
                {"id": version_id},
            )
            conn.commit()
        self.finish()
class VersionDeleteHandler(APIHandler):
    """
    Delete a model version.

    NOTE(review): not implemented yet — the handler only finishes the request
    with an empty response and deletes nothing.
    """
    @authenticated
    def post(self):
        # TODO: implement deletion of a model version.
        self.finish()