|
|
# -*- coding: utf-8 -*-
|
|
|
import logging
|
|
|
import os
|
|
|
|
|
|
from sqlalchemy import text
|
|
|
|
|
|
from website import consts
|
|
|
from website import db
|
|
|
from website import errors
|
|
|
from website import settings
|
|
|
from website.handler import APIHandler, authenticated
|
|
|
from website.util import md5
|
|
|
|
|
|
|
|
|
class ClassificationAddHandler(APIHandler):
    """
    Add a model classification.

    - Method: post
    - Request parameters:
      > - name, string, classification name (required, at most 128 characters, must be unique)
    - Returns: nothing
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        if not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        with self.app_mysql.connect() as conn:
            # fetchone() returns None when no row matches, so test the row itself
            # rather than indexing into it (indexing crashed for every new name).
            existing = conn.execute(
                text("select id from model_classification where name=:name"), {"name": name}
            ).fetchone()
            if existing is not None:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")

            conn.execute(
                text("""insert into model_classification (name, created_at) values (:name, NOW())"""), {"name": name}
            )
            # Commit explicitly like the other write handlers: an uncommitted
            # transaction on a SQLAlchemy connection is rolled back on close.
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class ClassificationEditHandler(APIHandler):
    """
    Rename a model classification.

    - Method: post
    - Request parameters:
      > - id, int, classification id (required)
      > - name, string, new classification name (required, at most 128 characters, must be unique)
    - Returns: nothing
    """

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not classification_id or not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        # Same rules as ClassificationAddHandler, so an edit cannot produce a
        # name that could not have been created in the first place.
        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        with self.app_mysql.connect() as conn:
            # Exclude the row being edited so re-saving the same name is allowed.
            duplicate = conn.execute(
                text("select id from model_classification where name=:name and id<>:id"),
                {"name": name, "id": classification_id},
            ).fetchone()
            if duplicate is not None:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")

            conn.execute(text("""update model_classification set name=:name where id=:id"""),
                         {"name": name, "id": classification_id})
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class ClassificationListHandler(APIHandler):
    """
    List every model classification (no paging, no filtering).

    - Method: post
    - Request parameters: none
    - Returns: {"data": [<one entry per model_classification row, columns id and name>]}
    """

    @authenticated
    def post(self):
        query = text("""select id, name from model_classification""")
        with self.app_mysql.connect() as conn:
            rows = db.to_json_list(conn.execute(query))

        self.finish({"data": rows})
|
|
|
|
|
|
|
|
|
class ClassificationDeleteHandler(APIHandler):
    """
    Hard-delete a model classification by id.

    Note: rows in `model` that point at this classification are not touched here.

    - Method: post
    - Request parameters:
      > - id, int, classification id (required)
    - Returns: nothing
    """

    @authenticated
    def post(self):
        target_id = self.get_int_argument("id")
        if not target_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        delete_stmt = text("""DELETE FROM model_classification WHERE id=:id""")
        bind = {"id": target_id}
        with self.app_mysql.connect() as conn:
            conn.execute(delete_stmt, bind)
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class ListHandler(APIHandler):
    """
    Paginated list of models that are not soft-deleted (del=0), newest first.

    - Method: post
    - Request parameters:
      > - pageNo, int, 1-based page number (default 1; values below 1 are treated as 1)
      > - pageSize, int, rows per page (default consts.PAGE_SIZE)
      > - name, string, optional substring filter on the model name
    - Returns: {"data": [{id, name, model_type, classification_name, default_version, update_time}, ...],
                "count": total number of models matching the filter}
    """

    @authenticated
    def post(self):
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", consts.PAGE_SIZE)
        name = self.get_escaped_argument("name", "")

        # A page number below 1 would produce a negative OFFSET and a SQL error.
        pageNo = max(pageNo, 1)

        conditions = ["m.del=0"]
        params = {}
        if name:
            conditions.append("m.name like :name")
            params["name"] = "%{}%".format(name)
        where_clause = " where " + " and ".join(conditions)

        sql = ("select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time "
               "from model m left join model_classification mc on m.classification=mc.id"
               + where_clause + " order by m.id desc limit :pageSize offset :offset")
        # The count query aliases `model` as `m` as well, so the shared filter
        # resolves; previously `m.name` was used against an un-aliased table.
        sql_count = "select count(m.id) from model m" + where_clause

        page_params = dict(params, pageSize=pageSize, offset=(pageNo - 1) * pageSize)

        with self.app_mysql.connect() as conn:
            result = db.to_json_list(conn.execute(text(sql), page_params))
            count = conn.execute(text(sql_count), params).scalar()

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                # Fall back to the raw code so one unexpected model_type value
                # does not fail the whole listing.
                "model_type": consts.model_type_map.get(item["model_type"], item["model_type"]),
                "classification_name": item["classification_name"],
                "default_version": item["default_version"],
                "update_time": str(item["update_time"]),
            }
            for item in result
        ]

        self.finish({"data": data, "count": count})
|
|
|
|
|
|
|
|
|
class AddHandler(APIHandler):
    """
    Create a model.

    - Method: post
    - Request parameters:
      > - name, string, model name (required)
      > - model_type, int, must be a key of consts.model_type_map (default consts.model_type_machine);
          per the original note, 1001 = classic algorithm, 1002 = deep learning
      > - classification, int, id of an existing model_classification row (required)
      > - comment, string, free-form remark
    - Returns: nothing
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")

        if not name or not classification or model_type not in consts.model_type_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Bind parameters belong to execute(); text() only takes the SQL string.
            cur = conn.execute(
                text("select id from model_classification where id=:id"), {"id": classification}
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")

            conn.execute(
                text(
                    """insert into model (name, model_type, classification, comment, created_at)
                    values (:name, :model_type, :classification, :comment, NOW())"""),
                {"name": name, "model_type": model_type, "classification": classification, "comment": comment}
            )
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class EditHandler(APIHandler):
    """
    Update a model's name, type, classification and comment; refreshes update_time.

    - Method: post
    - Request parameters:
      > - id, int, model id (required)
      > - name, string, model name (required)
      > - model_type, int, must be a key of consts.model_type_map (default consts.model_type_machine);
          per the original note, 1001 = classic algorithm, 1002 = deep learning
      > - classification, int, id of an existing model_classification row (required)
      > - comment, string, free-form remark
    - Returns: nothing
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")

        # The model id is required too; without it the UPDATE silently matches nothing.
        if not mid or not name or not classification or model_type not in consts.model_type_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Same classification check AddHandler performs.
            cur = conn.execute(
                text("select id from model_classification where id=:id"), {"id": classification}
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")

            conn.execute(
                text(
                    """update model
                    set name=:name, model_type=:model_type, classification=:classification, comment=:comment,
                    update_time=NOW()
                    where id=:id"""),
                {"name": name, "model_type": model_type, "classification": classification, "comment": comment,
                 "id": mid}
            )
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class InfoHandler(APIHandler):
    """
    Return basic information about one model that is not soft-deleted.

    - Method: post
    - Request parameters:
      > - id, int, model id (required)
    - Returns: {"name", "classification_name", "comment", "update_time"}
      (classification_name is None when the model's classification no longer exists)
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # The original joined on an undefined alias `c`, so the query always failed.
            # LEFT JOIN keeps models whose classification was deleted, and del=0 hides
            # soft-deleted models, matching ListHandler.
            cur = conn.execute(
                text(
                    """select m.name, mc.name as classification_name, m.comment, m.update_time
                    from model m left join model_classification mc on m.classification=mc.id
                    where m.id=:id and m.del=0"""),
                {"id": mid}
            )
            result = db.to_json(cur)

        if not result:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")

        data = {
            "name": result["name"],
            "classification_name": result["classification_name"],
            "comment": result["comment"],
            "update_time": str(result["update_time"])
        }

        self.finish(data)
|
|
|
|
|
|
|
|
|
class DeleteHandler(APIHandler):
    """
    Soft-delete a model by setting its `del` column to 1 (the row is kept).

    - Method: post
    - Request parameters:
      > - id, int, model id (required)
    - Returns: nothing
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("id")
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        soft_delete = text("update model set del=1 where id=:id")
        with self.app_mysql.connect() as conn:
            conn.execute(soft_delete, {"id": model_id})
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class VersionAddHandler(APIHandler):
    """
    Add a version to a model.

    - Method: post
    - Request parameters:
      > - model_id, int, model id (required; the model must exist and not be soft-deleted)
      > - version, string, version label (required)
      > - comment, string, remark
      > - model_file, string, md5 of the model file; only one file may be uploaded (required)
      > - config_file, string, md5 of the model config file (optional)
      > - config_str, string, model config parameters as a JSON string, e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("model_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not mid or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            # Refuse to attach versions to unknown or soft-deleted models, which
            # would otherwise leave orphan model_version rows.
            model_row = conn.execute(
                text("select id from model where id=:id and del=0"), {"id": mid}
            ).fetchone()
            if model_row is None:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")

            conn.execute(
                text(
                    """
                    insert into model_version
                    (model_id, version, comment, model_file, config_file, config_str, created_at, update_time)
                    values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"""
                ),
                {"model_id": mid, "version": version, "comment": comment, "model_file": model_file,
                 "config_file": config_file, "config_str": config_str})
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class VersionEditHandler(APIHandler):
    """
    Edit a model version.

    - Method: post
    - Request parameters:
      > - version_id, int, model version id (required)
      > - version, string, version label (required)
      > - comment, string, remark
      > - model_file, string, md5 of the model file; only one file may be uploaded (required)
      > - config_file, string, md5 of the model config file (optional)
      > - config_str, string, model config parameters as a JSON string, e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not version_id or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update model_version "
                    "set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
                    "config_str=:config_str, update_time=NOW() where id=:id"),
                {"version": version, "comment": comment, "model_file": model_file, "config_file": config_file,
                 "config_str": config_str, "id": version_id}
            )
            # Without an explicit commit the update was rolled back when the
            # connection closed; every other write handler commits here.
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class VersionListHandler(APIHandler):
    """
    Paginated list of a model's versions, newest first.

    - Method: post
    - Request parameters:
      > - model_id, int, model id (required)
      > - pageNo, int, 1-based page number (default 1; values below 1 are treated as 1)
      > - pageSize, int (default 20)
    - Returns:
      ```
      {
          "count": 123,                 # number of versions of this model
          "data": [
              {
                  "version_id": 213,    # version id
                  "version": "xx",      # version label
                  "path": "xxx",        # file path (files.filepath of the model file)
                  "size": 123,          # files.filesize of the model file
                  "update_time": "xxx",
                  "is_default": 1       # 1 = default version, 0 = not default
              },
              ...
          ]
      }
      ```
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("model_id")
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", 20)
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        # A page number below 1 would produce a negative offset and a SQL error.
        pageNo = max(pageNo, 1)

        with self.app_mysql.connect() as conn:
            # Version rows plus the matching uploaded-file metadata. Fixes relative to the original:
            # files are keyed by `md5_str` (as in the other handlers), and ordering uses the real
            # column `mv.id`, since `version_id` is only a select-list alias.
            cur = conn.execute(
                text(
                    """
                    select mv.id as version_id, mv.version, mv.model_file, mv.update_time, mv.is_default,
                           f.filepath, f.filesize
                    from model_version mv
                    left join files f on mv.model_file=f.md5_str
                    where mv.model_id=:mid and mv.del=0
                    order by mv.id desc limit :offset, :limit
                    """
                ),
                {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize}
            )
            # Multiple rows: to_json_list, not to_json (which yields a single row).
            result = db.to_json_list(cur)

            # Total for this model only; the original counted every model's versions.
            count = conn.execute(
                text(
                    """
                    select count(*)
                    from model_version mv
                    where mv.model_id=:mid and mv.del=0
                    """
                ),
                {"mid": model_id}
            ).scalar()

        data = [
            {
                "version_id": item["version_id"],
                "version": item["version"],
                "path": item["filepath"],
                "size": item["filesize"],
                "update_time": str(item["update_time"]),
                "is_default": item["is_default"],
            }
            for item in result
        ]

        self.finish({"count": count, "data": data})
|
|
|
|
|
|
|
|
|
class VersionInfoHandler(APIHandler):
    """
    Details of one model version, including the name and size of its files.

    - Method: post
    - Request parameters:
      > - version_id, int, model version id (required)
    - Returns:
      ```
      {
          "model_name": "xxx",
          "version": "xxx",          # version label
          "comment": "xxx",          # remark
          "model_file_name": "xxx",  # model file name
          "model_file_size": 123,    # model file size
          "model_file_md5": "xx",    # model file md5
          "config_file_name": "xx",  # config file name
          "config_file_size": "xx",  # config file size
          "config_file_md5": "xxx",  # config file md5
          "config_str": "xxx"        # config parameters
      }
      ```
      File name/size fields keep their empty defaults when no `files` row matches the md5.
    """

    @staticmethod
    def _file_info(conn, md5_str):
        """Return the `files` row (filename, filesize) whose md5_str matches, or a falsy value if none."""
        cur = conn.execute(
            text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": md5_str}
        )
        return db.to_json(cur)

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        # Validate like the other version handlers instead of querying with a missing id.
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        response = {
            "model_name": "",
            "version": "",
            "comment": "",
            "model_file_name": "",
            "model_file_size": 0,
            "model_file_md5": "",
            "config_file_name": "",
            "config_file_size": 0,
            "config_file_md5": "",
            "config_str": ""
        }

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    """select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str
                    from model_version mv, model m where mv.id=:id and mv.model_id=m.id"""),
                {"id": version_id}
            )
            result = db.to_json(cur)
            if not result:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            model_file = result["model_file"]
            config_file = result["config_file"]

            response["model_name"] = result["model_name"]
            response["version"] = result["version"]
            response["comment"] = result["comment"]
            response["model_file_md5"] = model_file
            response["config_file_md5"] = config_file
            response["config_str"] = result["config_str"]

            # Look up file metadata. A version can reference an md5 whose `files` row
            # is missing; keep the defaults rather than failing on None["filename"].
            for md5_str, prefix in ((model_file, "model_file"), (config_file, "config_file")):
                if not md5_str:
                    continue
                info = self._file_info(conn, md5_str)
                if info:
                    response[prefix + "_name"] = info["filename"]
                    response[prefix + "_size"] = info["filesize"]

        self.finish(response)
|
|
|
|
|
|
|
|
|
class VersionSetDefaultHandler(APIHandler):
    """
    Make one version the default version of its model.

    - Method: post
    - Request parameters:
      > - version_id, int, model version id (required; must belong to model_id)
      > - model_id, int, model id (required)
    - Returns: nothing
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        model_id = self.get_int_argument("model_id")
        if not version_id or not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Make sure the version belongs to this model. Otherwise we would clear
            # this model's defaults and flag a version of some other model instead.
            row = db.to_json(conn.execute(
                text("select version from model_version where id=:id and model_id=:model_id"),
                {"id": version_id, "model_id": model_id}))
            if not row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            conn.execute(
                text("update model_version set is_default=0 where model_id=:model_id"),
                {"model_id": model_id})
            conn.execute(
                text("update model_version set is_default=1 where id=:id"),
                {"id": version_id})
            # Keep model.default_version (shown by ListHandler, reset to '' by
            # VersionDeleteHandler) in sync with the flagged version.
            # NOTE(review): assumes the column stores the version label; confirm against the schema.
            conn.execute(
                text("update model set default_version=:version where id=:model_id"),
                {"version": row["version"], "model_id": model_id})
            conn.commit()

        self.finish()
|
|
|
|
|
|
|
|
|
class VersionDeleteHandler(APIHandler):
    """
    Delete a model version and any stored files that no other version still uses.

    - Method: post
    - Request parameters:
      > - version_id, int, model version id (required)
    - Returns: nothing
    """

    @staticmethod
    def _is_referenced(conn, md5_str):
        """Return True if some model_version row still uses md5_str as its model or config file."""
        cur = conn.execute(
            text("select count(*) from model_version where model_file=:md5_str or config_file=:md5_str"),
            {"md5_str": md5_str})
        return bool(cur.scalar())

    @staticmethod
    def _remove_stored_file(md5_str):
        """Best-effort removal of the uploaded file stored under md5_str; failures are logged, not raised."""
        path = settings.file_upload_dir + "model/" + md5_str
        try:
            os.remove(path)
        except OSError as e:
            logging.warning("failed to remove model file %s: %s", path, e)

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            row = db.to_json(conn.execute(
                text("select model_id, model_file, config_file, is_default from model_version where id=:id"),
                {"id": version_id}))
            if not row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            # Clear the model's default only if this version is the default one;
            # deleting any other version must not wipe it.
            if row["is_default"]:
                conn.execute(
                    text("update model set default_version='' where id=:id"), {"id": row["model_id"]}
                )

            conn.execute(text("delete from model_version where id=:id"), {"id": version_id})

            # Files are keyed by md5, so several versions may share one file. After this
            # version's row is gone, drop only the md5s nothing else references.
            # Empty/None md5s (e.g. no config file) are skipped.
            candidates = dict.fromkeys(md5_str for md5_str in (row["model_file"], row["config_file"]) if md5_str)
            orphaned = [md5_str for md5_str in candidates if not self._is_referenced(conn, md5_str)]
            for md5_str in orphaned:
                conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": md5_str})

            conn.commit()

        # Remove files from disk only after the DB change is committed. Each removal is
        # independent, so failing on one no longer skips the other.
        for md5_str in orphaned:
            self._remove_stored_file(md5_str)

        self.finish()
|