599 lines
20 KiB
Python

# -*- coding: utf-8 -*-
import logging
import os
from sqlalchemy import text
from website import consts
from website import db_mysql
from website import errors
from website import settings
from website.handler import APIHandler, authenticated
from website.util import md5
class ClassificationAddHandler(APIHandler):
    """
    Add a model classification.

    Request arguments:
        name (str): classification name; required, at most 128 characters,
            and must not already be used by another classification.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when the name is missing,
            too long, or already exists.
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        if not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select id from model_classification where name=:name"), {"name": name}
            )
            # fetchone() returns None when no row matches. Indexing it directly (as the
            # previous code did) raised TypeError for every new, unique name.
            if cur.fetchone() is not None:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("insert into model_classification (name, created_at) values (:name, NOW())"),
                {"name": name},
            )
            # The connection does not autocommit. Without an explicit commit the insert
            # is discarded when the block exits (the other write handlers commit too).
            conn.commit()
        self.finish()
class ClassificationEditHandler(APIHandler):
    """
    Rename a model classification.

    Request arguments:
        id (int): id of the classification to rename; required.
        name (str): new classification name; required, at most 128 characters.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when an argument is missing, the
            name is too long, or another classification already uses the name.
    """

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not classification_id or not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # Enforce the same constraints as ClassificationAddHandler, so a rename
        # cannot create an over-long or duplicate name.
        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")
        with self.app_mysql.connect() as conn:
            # Exclude the row being renamed, so re-saving an unchanged name is allowed.
            cur = conn.execute(
                text("select id from model_classification where name=:name and id<>:id"),
                {"name": name, "id": classification_id},
            )
            if cur.fetchone() is not None:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("update model_classification set name=:name where id=:id"),
                {"name": name, "id": classification_id},
            )
            conn.commit()
        self.finish()
class ClassificationListHandler(APIHandler):
    """
    List every model classification.

    Responds with ``{"data": [...]}``. Each entry is one row of
    ``model_classification`` and carries its ``id`` and ``name`` columns.
    """

    @authenticated
    def post(self):
        query = text("select id, name from model_classification")
        with self.app_mysql.connect() as connection:
            rows = db_mysql.to_json_list(connection.execute(query))
        self.finish({"data": rows})
class ClassificationDeleteHandler(APIHandler):
    """
    Permanently delete a model classification row by id.

    This handler does not touch the ``model`` table.

    Request arguments:
        id (int): id of the classification to delete; required.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when ``id`` is missing or 0.
    """

    @authenticated
    def post(self):
        target_id = self.get_int_argument("id")
        if not target_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        delete_sql = text("DELETE FROM model_classification WHERE id=:id")
        with self.app_mysql.connect() as connection:
            connection.execute(delete_sql, {"id": target_id})
            connection.commit()
        self.finish()
class ListHandler(APIHandler):
    """
    Paginated list of models that are not soft-deleted (``del=0``).

    Request arguments:
        pageNo (int): 1-based page number; defaults to 1.
        pageSize (int): rows per page; defaults to ``consts.PAGE_SIZE``.
        name (str): optional substring filter on the model name.

    Responds with ``{"data": [...], "count": <number of matching models>}``.
    Each item contains id, name, the display value of model_type, the
    classification name, default_version, and update_time as a string.
    """

    @authenticated
    def post(self):
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", consts.PAGE_SIZE)
        name = self.get_escaped_argument("name", "")
        # A page number below 1 or a non-positive page size would produce a
        # negative OFFSET / invalid LIMIT, which MySQL rejects.
        pageNo = max(pageNo or 1, 1)
        if not pageSize or pageSize <= 0:
            pageSize = consts.PAGE_SIZE

        # The list query and the count query share one WHERE clause. Both alias
        # ``model`` as ``m``. The old count query had no alias, so ``m.name`` failed
        # whenever the name filter was used.
        where = "where m.del=0 "
        params = {}
        if name:
            where += "and m.name like :name "
            params["name"] = "%{}%".format(name)

        sql = (
            "select m.id, m.name, m.model_type, m.default_version, "
            "mc.name classification_name, m.update_time "
            "from model m left join model_classification mc on m.classification=mc.id "
            + where
            + "order by m.id desc limit :pageSize offset :offset"
        )
        sql_count = "select count(m.id) from model m " + where

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(sql),
                dict(params, pageSize=pageSize, offset=(pageNo - 1) * pageSize),
            )
            result = db_mysql.to_json_list(cur)
            count = conn.execute(text(sql_count), params).fetchone()[0]

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                # An unknown model_type should not take down the whole listing.
                "model_type": consts.model_type_map.get(item["model_type"], ""),
                "classification_name": item["classification_name"],
                "default_version": item["default_version"],
                "update_time": str(item["update_time"]),
            }
            for item in result
        ]
        self.finish({"data": data, "count": count})
class AddHandler(APIHandler):
    """
    Create a model.

    Request arguments:
        name (str): model name; required.
        model_type (int): must be a key of ``consts.model_type_map`` (the
            original comment lists 1001 = classical algorithm, 1002 = deep
            learning); defaults to ``consts.model_type_machine``.
        classification (int): id of an existing model classification; required.
        comment (str): optional free-text comment.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when a required argument is
            missing or invalid, or when the classification does not exist.
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)  # 1001/经典算法1002/深度学习
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        if not name or not classification or model_type not in consts.model_type_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # Bind values belong to execute(); text() does not accept them. The old code
            # passed the dict to text(), which left :id unbound, so the lookup always failed.
            cur = conn.execute(
                text("select id from model_classification where id=:id"), {"id": classification}
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")
            conn.execute(
                text(
                    "insert into model (name, model_type, classification, comment, created_at) "
                    "values (:name, :model_type, :classification, :comment, NOW())"
                ),
                {"name": name, "model_type": model_type, "classification": classification, "comment": comment},
            )
            conn.commit()
        self.finish()
class EditHandler(APIHandler):
    """
    Edit a model's name, type, classification and comment.

    Request arguments:
        id (int): id of the model to edit; required.
        name (str): model name; required.
        model_type (int): must be a key of ``consts.model_type_map`` (the
            original comment lists 1001 = classical algorithm, 1002 = deep
            learning); defaults to ``consts.model_type_machine``.
        classification (int): id of an existing model classification; required.
        comment (str): optional free-text comment.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when a required argument is
            missing or invalid, the model does not exist (or is soft-deleted),
            or the classification does not exist.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)  # 1001/经典算法1002/深度学习
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        # ``id`` used to be unchecked: a missing id ran an UPDATE that matched no rows
        # and still reported success.
        if not mid or not name or not classification or model_type not in consts.model_type_map:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            model_row = conn.execute(
                text("select id from model where id=:id and del=0"), {"id": mid}
            ).fetchone()
            if model_row is None:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
            # Same classification validation that AddHandler performs on creation.
            classification_row = conn.execute(
                text("select id from model_classification where id=:id"), {"id": classification}
            ).fetchone()
            if classification_row is None:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")
            conn.execute(
                text(
                    "update model "
                    "set name=:name, model_type=:model_type, classification=:classification, "
                    "comment=:comment, update_time=NOW() "
                    "where id=:id"
                ),
                {"name": name, "model_type": model_type, "classification": classification,
                 "comment": comment, "id": mid},
            )
            conn.commit()
        self.finish()
class InfoHandler(APIHandler):
    """
    Return the basic information of one model.

    Request arguments:
        id (int): model id; required.

    Responds with ``name``, ``classification_name``, ``comment`` and
    ``update_time`` (stringified).

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when ``id`` is missing, or when no
            non-deleted model has that id.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # The old join condition referenced a non-existent alias ``c``. A LEFT JOIN
            # (as in ListHandler) keeps models whose classification was deleted, and
            # ``del=0`` hides soft-deleted models, matching the list view.
            cur = conn.execute(
                text(
                    "select m.name, mc.name as classification_name, m.comment, m.update_time "
                    "from model m left join model_classification mc on m.classification=mc.id "
                    "where m.id=:id and m.del=0"
                ),
                {"id": mid},
            )
            result = db_mysql.to_json(cur)
        if not result:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
        self.finish({
            "name": result["name"],
            "classification_name": result["classification_name"],
            "comment": result["comment"],
            "update_time": str(result["update_time"]),
        })
class DeleteHandler(APIHandler):
    """
    Soft-delete a model by setting its ``del`` flag to 1.

    The row is kept in the ``model`` table.

    Request arguments:
        id (int): id of the model to delete; required.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when ``id`` is missing or 0.
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("id")
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        soft_delete = text("update model set del=1 where id=:id")
        with self.app_mysql.connect() as connection:
            connection.execute(soft_delete, {"id": model_id})
            connection.commit()
        self.finish()
class VersionAddHandler(APIHandler):
    """
    Add a version to a model.

    Request method: POST

    Request arguments:
        model_id (int): id of the model; required.
        version (str): version label; required.
        comment (str): optional remark.
        model_file (str): md5 of the uploaded model file; required. Only one
            file may be uploaded.
        config_file (str): optional md5 of the model configuration file.
        config_str (str): optional model configuration parameters as a JSON
            string, e.g. '{"a":"xx","b":"xx"}'.

    Returns:
        An empty response on success.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when a required argument is
            missing, or an md5 value fails ``md5.md5_validate``.
    """

    @authenticated
    def post(self):
        # The dict keys double as the SQL bind-parameter names below.
        params = {
            "model_id": self.get_int_argument("model_id"),
            "version": self.get_escaped_argument("version", ""),
            "comment": self.get_escaped_argument("comment", ""),
            "model_file": self.get_escaped_argument("model_file", ""),
            "config_file": self.get_escaped_argument("config_file", ""),
            "config_str": self.get_escaped_argument("config_str", ""),
        }
        if not all(params[key] for key in ("model_id", "version", "model_file")):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(params["model_file"]):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        config_md5 = params["config_file"]
        if config_md5 and not md5.md5_validate(config_md5):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        insert_sql = text(
            "\n"
            "insert into model_version\n"
            "(model_id, version, comment,model_file, config_file, config_str, created_at, update_time)\n"
            "values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"
        )
        with self.app_mysql.connect() as connection:
            connection.execute(insert_sql, params)
            connection.commit()
        self.finish()
class VersionEditHandler(APIHandler):
    """
    Edit an existing model version.

    Request method: POST

    Request arguments:
        version_id (int): id of the model version to edit; required.
        version (str): version label; required.
        comment (str): optional remark.
        model_file (str): md5 of the uploaded model file; required. Only one
            file may be uploaded.
        config_file (str): optional md5 of the model configuration file.
        config_str (str): optional model configuration parameters as a JSON
            string, e.g. '{"a":"xx","b":"xx"}'.

    Returns:
        An empty response on success.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when a required argument is
            missing, or an md5 value fails ``md5.md5_validate``.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not version_id or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")
        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update model_version "
                    "set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
                    "config_str=:config_str, update_time=NOW() where id=:id"
                ),
                {"version": version, "comment": comment, "model_file": model_file, "config_file": config_file,
                 "config_str": config_str, "id": version_id},
            )
            # This commit was missing, so the update was rolled back when the connection
            # closed. Every other write handler in this module commits explicitly.
            conn.commit()
        self.finish()
class VersionListHandler(APIHandler):
    """
    Paginated list of a model's versions, ordered by version id descending.

    Request method: POST

    Request arguments:
        model_id (int): model id; required.
        pageNo (int): 1-based page number; defaults to 1.
        pageSize (int): rows per page; defaults to 20.

    Returns::

        {
            "count": 123,
            "data": [
                {
                    "version_id": 213,    # version id
                    "version": "xx",      # version label
                    "path": "xxx",        # file path
                    "size": 123,
                    "update_time": "xxx",
                    "is_default": 1       # 1 = default version, 0 = not default
                },
                ...
            ]
        }

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when ``model_id`` is missing.
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("model_id")
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", 20)
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # Guard against a negative OFFSET / non-positive LIMIT.
        pageNo = max(pageNo or 1, 1)
        if not pageSize or pageSize <= 0:
            pageSize = 20

        with self.app_mysql.connect() as conn:
            # Version rows joined with their model file's metadata.
            # - ``files`` is keyed by ``md5_str`` in every other query in this module,
            #   so the join uses that column instead of ``md5``.
            # - The ORDER BY uses the real column ``mv.id``; ``mv.version_id`` is only a
            #   select alias and cannot be qualified with ``mv``.
            # NOTE(review): ``f.filepath`` is only referenced here; confirm the column exists.
            cur = conn.execute(
                text(
                    "select mv.id as version_id, mv.version, mv.model_file, mv.update_time, "
                    "mv.is_default, f.filepath, f.filesize "
                    "from model_version mv "
                    "left join files f on mv.model_file=f.md5_str "
                    "where mv.model_id=:mid and mv.del=0 "
                    "order by mv.id desc limit :limit offset :offset"
                ),
                {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize},
            )
            # Multiple rows: to_json_list, not the single-row to_json.
            result = db_mysql.to_json_list(cur)
            # The count must be restricted to the same model as the page query.
            count = conn.execute(
                text("select count(*) from model_version mv where mv.model_id=:mid and mv.del=0"),
                {"mid": model_id},
            ).scalar()

        data = [
            {
                "version_id": item["version_id"],  # version id
                "version": item["version"],  # version label
                "path": item["filepath"],  # file path
                "size": item["filesize"],
                "update_time": str(item["update_time"]),
                "is_default": item["is_default"],  # 1 = default, 0 = not default
            }
            for item in result
        ]
        self.finish({"count": count, "data": data})
class VersionInfoHandler(APIHandler):
    """
    Return the details of one model version.

    Request method: POST

    Request arguments:
        version_id (int): model version id; required.

    Returns::

        {
            "model_name": "xxx",
            "version": "xxx",          # version label
            "comment": "xxx",          # remark
            "model_file_name": "xxx",  # model file name
            "model_file_size": 123,    # model file size
            "model_file_md5": "xx",    # model file md5
            "config_file_name": "xx",  # config file name
            "config_file_size": 123,   # config file size
            "config_file_md5": "xxx",  # config file md5
            "config_str": "xxx"        # configuration parameters
        }

    File name/size fields keep their empty defaults when the version has no
    such file, or when no ``files`` row exists for its md5.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when ``version_id`` is missing or
            the version does not exist.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        # Reject a missing id up front, as the other version handlers do.
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        response = {
            "model_name": "",
            "version": "",
            "comment": "",
            "model_file_name": "",
            "model_file_size": 0,
            "model_file_md5": "",
            "config_file_name": "",
            "config_file_size": 0,
            "config_file_md5": "",
            "config_str": ""
        }
        file_sql = text("select filename, filesize from files where md5_str=:md5_str")
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    "select m.name as model_name, mv.version, mv.comment, mv.model_file, "
                    "mv.config_file, mv.config_str "
                    "from model_version mv, model m where mv.id=:id and mv.model_id=m.id"
                ),
                {"id": version_id},
            )
            result = db_mysql.to_json(cur)
            if not result:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
            model_file = result["model_file"]
            config_file = result["config_file"]
            response["model_name"] = result["model_name"]
            response["version"] = result["version"]
            response["comment"] = result["comment"]
            response["model_file_md5"] = model_file
            response["config_file_md5"] = config_file
            # config_str is a column of the version itself and does not depend on
            # whether a config file was uploaded.
            response["config_str"] = result["config_str"]
            # Look up name/size for each referenced file. A missing ``files`` row now
            # keeps the defaults instead of raising TypeError on None.
            for prefix, md5_str in (("model_file", model_file), ("config_file", config_file)):
                if not md5_str:
                    continue
                info = db_mysql.to_json(conn.execute(file_sql, {"md5_str": md5_str}))
                if info:
                    response[prefix + "_name"] = info["filename"]
                    response[prefix + "_size"] = info["filesize"]
        self.finish(response)
class VersionSetDefaultHandler(APIHandler):
    """
    Mark a model version as its model's default version.

    Request method: POST

    Request arguments:
        version_id (int): id of the model version; required.
        model_id (int): id of the model that owns the version; required.

    Returns:
        An empty response on success.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when an argument is missing, or
            the version does not belong to the given model.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        model_id = self.get_int_argument("model_id")
        if not version_id or not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # Verify ownership before changing anything. Previously a version_id from
            # another model cleared this model's default and flagged the foreign version.
            owned = conn.execute(
                text("select id from model_version where id=:id and model_id=:model_id"),
                {"id": version_id, "model_id": model_id},
            ).fetchone()
            if owned is None:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
            conn.execute(
                text("update model_version set is_default=0 where model_id=:model_id"),
                {"model_id": model_id},
            )
            conn.execute(
                text("update model_version set is_default=1 where id=:id and model_id=:model_id"),
                {"id": version_id, "model_id": model_id},
            )
            conn.commit()
        self.finish()
class VersionDeleteHandler(APIHandler):
    """
    Delete a model version, and its uploaded files once nothing else uses them.

    Request method: POST

    Request arguments:
        version_id (int): id of the model version; required.

    Returns:
        An empty response on success.

    Raises:
        errors.HTTPAPIError: ERROR_BAD_REQUEST when ``version_id`` is missing or
            the version does not exist.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text("select model_id, model_file, config_file, is_default from model_version where id=:id"),
                {"id": version_id},
            )
            row = db_mysql.to_json(cur)
            if not row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
            # Only clear the model's default when the deleted version was the default.
            # NOTE(review): relies on model_version.is_default (maintained by
            # VersionSetDefaultHandler) agreeing with model.default_version.
            if row["is_default"]:
                conn.execute(
                    text("update model set default_version='' where id=:id"), {"id": row["model_id"]}
                )
            conn.execute(text("delete from model_version where id=:id"), {"id": version_id})
            # Files are addressed by md5, so several versions may share one upload.
            # Skip empty md5s, de-duplicate, and keep only files that no remaining
            # version references (this version's row was deleted just above).
            candidates = dict.fromkeys(m for m in (row["model_file"], row["config_file"]) if m)
            orphaned = [m for m in candidates if not self._is_referenced(conn, m)]
            for md5_str in orphaned:
                conn.execute(text("delete from files where md5_str=:md5_str"), {"md5_str": md5_str})
            conn.commit()
        # Touch the disk only after the DB changes are committed. Each file is removed
        # independently, so one failure no longer skips the other.
        for md5_str in orphaned:
            self._remove_upload(md5_str)
        self.finish()

    @staticmethod
    def _is_referenced(conn, md5_str):
        """Return True if any model_version row still uses ``md5_str`` as its model or config file."""
        count = conn.execute(
            text("select count(*) from model_version where model_file=:md5_str or config_file=:md5_str"),
            {"md5_str": md5_str},
        ).scalar()
        return bool(count)

    @staticmethod
    def _remove_upload(md5_str):
        """Best-effort removal of an uploaded model file from disk; failures are logged, not raised."""
        # Same path layout as before (presumably file_upload_dir ends with a separator).
        path = settings.file_upload_dir + "model/" + md5_str
        try:
            os.remove(path)
        except OSError:
            logging.warning("failed to remove model file %s", path, exc_info=True)