You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

738 lines
25 KiB
Python

# -*- coding: utf-8 -*-
import logging
import os
from sqlalchemy import text
from website import consts
12 months ago
from website import db_mysql
from website import errors
from website import settings
10 months ago
from website.handler import APIHandler, authenticated, operation_log
11 months ago
from website.util import md5, shortuuid
class ClassificationAddHandler(APIHandler):
    """Create a new model classification.

    Request argument: ``name`` (str). The request is rejected with a
    bad-request error when the name is empty, longer than 128 characters,
    or already present in ``model_classification``.
    """

    @authenticated
    @operation_log("模型管理", "模型列表", consts.op_type_add_str, "添加模型分类", "")
    def post(self):
        name = self.get_escaped_argument("name", "")
        if not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        lookup_sql = text("select id from model_classification where name=:name")
        insert_sql = text(
            """insert into model_classification (name, create_time) values (:name, NOW())"""
        )
        bind = {"name": name}

        with self.app_mysql.connect() as conn:
            # Refuse to create a second classification with the same name.
            existing = db_mysql.to_json(conn.execute(lookup_sql, bind))
            if existing:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")

            conn.execute(insert_sql, bind)
            conn.commit()

        self.finish()
class ClassificationEditHandler(APIHandler):
    """Rename an existing model classification.

    Request arguments: ``id`` (int) and ``name`` (str); both are required.

    NOTE(review): a class with this same name is defined again further down
    in this file, and that later definition replaces this one at import
    time. One of the two should be removed.
    """

    @authenticated
    @operation_log("模型管理", "模型列表", consts.op_type_edit_str, "编辑模型分类", "")
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not (classification_id and name):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        rename_sql = text("""update model_classification set name=:name where id=:id""")
        with self.app_mysql.connect() as conn:
            conn.execute(rename_sql, {"name": name, "id": classification_id})
            conn.commit()

        self.finish()
class ClassificationListHandler(APIHandler):
    """Return every model classification as ``{"data": [{id, name}, ...]}``."""

    @authenticated
    @operation_log("模型管理", "模型列表", consts.op_type_list_str, "查询模型分类", "")
    def post(self):
        query = text("""select id, name from model_classification""")
        with self.app_mysql.connect() as conn:
            rows = db_mysql.to_json_list(conn.execute(query))

        self.finish({"data": rows})
class ClassificationDeleteHandler(APIHandler):
    """Delete a model classification by id.

    Request argument: ``id`` (int, required). The row is removed from
    ``model_classification``; no other table is touched here.
    """

    @authenticated
    @operation_log("模型管理", "模型列表", consts.op_type_delete_str, "删除模型分类", "")
    def post(self):
        classification_id = self.get_int_argument("id")
        if not classification_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        delete_sql = text("""DELETE FROM model_classification WHERE id=:id""")
        with self.app_mysql.connect() as conn:
            conn.execute(delete_sql, {"id": classification_id})
            conn.commit()

        self.finish()
9 months ago
class ClassificationEditHandler(APIHandler):
    """Rename an existing model classification.

    Request arguments: ``id`` (int) and ``name`` (str); both are required.
    The new name must satisfy the same rules that are enforced when a
    classification is created: at most 128 characters and not already used
    by another classification.

    Raises:
        errors.HTTPAPIError: bad request when an argument is missing, the
            name is too long, or the name is taken by a different row.

    NOTE(review): a class with this same name is also defined earlier in this
    file; this later definition is the one bound to the name after import.
    One of the two should be removed.
    """

    @authenticated
    @operation_log("模型管理", "模型列表", consts.op_type_edit_str, "编辑模型分类", "")
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not classification_id or not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # Same limit as the create handler, so a rename cannot bypass it.
        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        with self.app_mysql.connect() as conn:
            # Reject the rename if a *different* classification already has
            # this name; renaming a row to its current name stays allowed.
            clash = conn.execute(
                text(
                    "select id from model_classification where name=:name and id<>:id"
                ),
                {"name": name, "id": classification_id},
            ).fetchone()
            if clash:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")

            conn.execute(
                text("""update model_classification set name=:name where id=:id"""),
                {"name": name, "id": classification_id},
            )
            conn.commit()

        self.finish()
class ListHandler(APIHandler):
    """Paginated list of non-deleted models, optionally filtered by name.

    Request arguments:
        pageNo (int): 1-based page number, default 1.
        pageSize (int): rows per page, default ``consts.PAGE_SIZE``.
        name (str): optional substring matched with ``LIKE %name%``.

    Response: ``{"data": [...], "count": <total matching rows>}``.
    """

    @authenticated
    @operation_log("模型管理", "模型列表", consts.op_type_list_str, "查询模型列表", "")
    def post(self):
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", consts.PAGE_SIZE)
        name = self.get_escaped_argument("name", "")

        # A zero/negative page or size would yield a negative OFFSET/LIMIT,
        # which MySQL rejects; fall back to sane values instead of failing.
        if pageNo < 1:
            pageNo = 1
        if pageSize < 1:
            pageSize = consts.PAGE_SIZE

        sql = (
            "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time "
            "from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
        )
        sql_count = "select count(id) from model m where m.del=0 "
        param = {}
        param_count = {}

        if name:
            pattern = "%{}%".format(name)
            sql += "and m.name like :name"
            sql_count += "and m.name like :name"
            param["name"] = pattern
            param_count["name"] = pattern

        sql += " order by m.id desc limit :pageSize offset :offset"
        param["pageSize"] = pageSize
        param["offset"] = (pageNo - 1) * pageSize

        with self.app_mysql.connect() as conn:
            result = db_mysql.to_json_list(conn.execute(text(sql), param))
            count = conn.execute(text(sql_count), param_count).fetchone()[0]

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                # Translate the stored type code through consts.model_type_map.
                "model_type": consts.model_type_map[item["model_type"]],
                "classification_name": item["classification_name"],
                "default_version": item["default_version"],
                "update_time": str(item["update_time"]),
            }
            for item in result
        ]
        self.finish({"data": data, "count": count})
11 months ago
class ListSimpleHandler(APIHandler):
    """Return id and name of every non-deleted model as ``{"result": [...]}``."""

    @authenticated
    @operation_log("模型管理", "模型列表", consts.op_type_list_str, "查询模型列表", "")
    def post(self):
        query = text("select id, name from model where del=0")
        with self.app_mysql.connect() as conn:
            models = db_mysql.to_json_list(conn.execute(query))

        self.finish({"result": models})
class AddHandler(APIHandler):
    """Create a model under an existing classification.

    Request arguments: ``name`` (str, required), ``model_type`` (int,
    defaults to ``consts.model_type_machine`` and must be a key of
    ``consts.model_type_map``), ``classification`` (int, required) and
    ``comment`` (str, optional).
    """

    @authenticated
    @operation_log("模型管理", "添加模型", consts.op_type_add_str, "添加模型", "")
    def post(self):
        name = self.get_escaped_argument("name", "")
        # Type code; per the original note 1001 = classic algorithm,
        # 1002 = deep learning.
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")

        if not (name and classification and model_type in consts.model_type_map):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        insert_sql = text(
            """insert into model (suid, name, model_type, classification, comment, create_time, update_time)
            values (:suid, :name, :model_type, :classification, :comment, NOW(), NOW())"""
        )

        with self.app_mysql.connect() as conn:
            # The classification must exist before a model can point at it.
            found = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            ).fetchone()
            if not found:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")

            values = {
                "suid": shortuuid.ShortUUID().random(10),
                "name": name,
                "model_type": model_type,
                "classification": classification,
                "comment": comment,
            }
            conn.execute(insert_sql, values)
            conn.commit()

        self.finish()
class EditHandler(APIHandler):
    """Update name, type, classification and comment of an existing model.

    Request arguments: ``id`` (int, required), ``name`` (str, required),
    ``model_type`` (int, defaults to ``consts.model_type_machine`` and must
    be a key of ``consts.model_type_map``), ``classification`` (int,
    required) and ``comment`` (str, optional).

    Raises:
        errors.HTTPAPIError: bad request when a required argument is missing
            or invalid, or when the classification does not exist.
    """

    @authenticated
    @operation_log("模型管理", "编辑模型", consts.op_type_edit_str, "编辑模型", "")
    def post(self):
        mid = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        # Type code; per the original note 1001 = classic algorithm,
        # 1002 = deep learning.
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")

        # ``id`` is required too: without it the UPDATE below matches no row
        # and the request would report success while changing nothing.
        if (
            not mid
            or not name
            or not classification
            or model_type not in consts.model_type_map
        ):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Same existence check the add handler performs, so an edit cannot
            # point a model at a classification that does not exist.
            found = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            ).fetchone()
            if not found:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")

            conn.execute(
                text(
                    """update model
                    set name=:name, model_type=:model_type, classification=:classification, comment=:comment,
                    update_time=NOW()
                    where id=:id"""
                ),
                {
                    "name": name,
                    "model_type": model_type,
                    "classification": classification,
                    "comment": comment,
                    "id": mid,
                },
            )
            conn.commit()

        self.finish()
class InfoHandler(APIHandler):
    """Return the details of a single model.

    Request argument: ``id`` (int, required). The response carries the
    model's name, type, default version, classification id/name, comment and
    update time (as a string).
    """

    @authenticated
    @operation_log("模型管理", "模型信息", consts.op_type_list_str, "查询模型信息", "")
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        query = text(
            """
            select
                m.name, m.model_type, m.default_version, m.comment, m.update_time,
                mc.id as classification_id, mc.name as classification_name
            from model m
            left join model_classification mc
            on m.classification=mc.id
            where m.id=:id
            """
        )
        with self.app_mysql.connect() as conn:
            result = db_mysql.to_json(conn.execute(query, {"id": mid}))

        if not result:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")

        # Copy the plain columns through unchanged, then stringify the
        # timestamp so it can be serialised.
        passthrough = (
            "name",
            "model_type",
            "default_version",
            "classification_id",
            "classification_name",
            "comment",
        )
        data = {key: result[key] for key in passthrough}
        data["update_time"] = str(result["update_time"])
        self.finish(data)
class DeleteHandler(APIHandler):
    """Soft-delete a model by setting its ``del`` flag to 1.

    Request argument: ``id`` (int, required).
    """

    @authenticated
    @operation_log("模型管理", "模型列表", consts.op_type_delete_str, "删除模型", "")
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        flag_sql = text("update model set del=1 where id=:id")
        with self.app_mysql.connect() as conn:
            conn.execute(flag_sql, {"id": mid})
            conn.commit()

        self.finish()
class VersionAddHandler(APIHandler):
    """
    Add a model version
    - Description: add a model version
    - Method: post
    - Request parameters
    > - model_id, int, model id
    > - version, string, version
    > - comment, string, remark
    > - model_file, string, md5 of the model file; only one file may be uploaded
    > - config_file, string, md5 of the model configuration file
    > - config_str, string, model configuration as a JSON string, eg: '{"a":"xx","b":"xx"}'
    - Return value: none
    """

    @authenticated
    @operation_log("模型管理", "模型版本", consts.op_type_add_str, "添加模型版本", "")
    def post(self):
        mid = self.get_int_argument("model_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")

        if not (mid and version and model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # Both file references are passed through md5.md5_validate; the config
        # file is optional and only checked when supplied.
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        insert_sql = text(
            """
            insert into model_version
            (model_id, version, comment,model_file, config_file, config_str, create_time, update_time)
            values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"""
        )
        values = {
            "model_id": mid,
            "version": version,
            "comment": comment,
            "model_file": model_file,
            "config_file": config_file,
            "config_str": config_str,
        }
        with self.app_mysql.connect() as conn:
            conn.execute(insert_sql, values)
            conn.commit()

        self.finish()
class VersionEditHandler(APIHandler):
    """
    Edit a model version
    - Description: edit a model version
    - Method: post
    - Request parameters
    > - version_id, int, model version id
    > - version, string, version
    > - comment, string, remark
    > - model_file, string, md5 of the model file; only one file may be uploaded
    > - config_file, string, md5 of the model configuration file
    > - config_str, string, model configuration as a JSON string, eg: '{"a":"xx","b":"xx"}'
    - Return value: none
    - Raises errors.HTTPAPIError (bad request) when a required parameter is
      missing or a file reference fails md5 validation
    """

    @authenticated
    @operation_log("模型管理", "模型版本", consts.op_type_edit_str, "编辑模型版本", "")
    def post(self):
        version_id = self.get_int_argument("version_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")

        if not version_id or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")

        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update model_version "
                    "set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
                    "config_str=:config_str, update_time=NOW() where id=:id"
                ),
                {
                    "version": version,
                    "comment": comment,
                    "model_file": model_file,
                    "config_file": config_file,
                    "config_str": config_str,
                    "id": version_id,
                },
            )
            # Every other writing handler in this file commits explicitly;
            # without this the update is discarded when the connection closes.
            conn.commit()

        self.finish()
class VersionListHandler(APIHandler):
    """
    Model version list
    - Description: list the versions of a model
    - Method: post
    - Request parameters
    > - model_id, int, model id
    > - pageNo, int, 1-based page number (default 1)
    > - pageSize, int, rows per page (default 20)
    - Return value
    ```
    {
        "count": 123,
        "data": [
            {
                "version_id": 213,      # version id
                "version": "xx",        # version number
                "path": "xxx",          # file path
                "size": 123,
                "update_time": "xxx",
                "is_default": 1         # default flag: 1 = default, 0 = not default
            },
            ...
        ]
    }
    ```
    """

    @authenticated
    @operation_log("模型管理", "模型版本", consts.op_type_list_str, "模型版本列表", "")
    def post(self):
        model_id = self.get_int_argument("model_id")
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", 20)
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        # A zero/negative page or size would yield a negative offset/limit,
        # which MySQL rejects; fall back to the defaults instead of failing.
        if pageNo < 1:
            pageNo = 1
        if pageSize < 1:
            pageSize = 20

        with self.app_mysql.connect() as conn:
            # Version rows joined with the files table (by md5) to pick up
            # the stored path and size of each model file.
            cur = conn.execute(
                text(
                    """
                    select mv.id as version_id, mv.version, mv.model_file, mv.update_time, mv.is_default, f.filepath, f.filesize
                    from model_version mv
                    left join files f on mv.model_file=f.md5_str
                    where mv.model_id=:mid and mv.del=0
                    order by mv.id desc limit :offset, :limit
                    """
                ),
                {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize},
            )
            result = db_mysql.to_json_list(cur)

            # Total number of non-deleted versions for this model.
            count = conn.execute(
                text(
                    """
                    select count(*)
                    from model_version mv
                    where mv.del=0 and model_id=:mid
                    """
                ),
                {"mid": model_id},
            ).scalar()

        data = [
            {
                "version_id": item["version_id"],
                "version": item["version"],
                "path": item["filepath"],
                "size": item["filesize"],
                "update_time": str(item["update_time"]),
                "is_default": item["is_default"],
            }
            for item in result
        ]
        self.finish({"count": count, "data": data})
class VersionInfoHandler(APIHandler):
    """
    Model version details
    - Description: details of one model version
    - Method: post
    - Request parameters
    > - version_id, int, model version id
    - Return value
    ```
    {
        "model_name": "xxx",
        "version": "xxx",           # version
        "comment": "xxx",           # remark
        "model_file_name": "xxx",   # model file name
        "model_file_size": 123,     # model file size
        "model_file_md5": "xx",     # model file md5
        "config_file_name": "xx",   # config file name
        "config_file_size": "xx",   # config file size
        "config_file_md5": "xxx",   # config file md5
        "config_str": "xxx"         # configuration parameters
    }
    ```
    """

    @authenticated
    @operation_log("模型管理", "模型版本", consts.op_type_list_str, "查询模型版本详情", "")
    def post(self):
        version_id = self.get_int_argument("version_id")

        # Pre-populate every key so fields without data keep their defaults.
        response = {
            "model_name": "",
            "version": "",
            "comment": "",
            "model_file_name": "",
            "model_file_size": 0,
            "model_file_md5": "",
            "config_file_name": "",
            "config_file_size": 0,
            "config_file_md5": "",
            "config_str": "",
        }

        version_sql = text(
            """select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str
            from model_version mv, model m where mv.id=:id and mv.model_id=m.id"""
        )
        file_sql = text("select filename, filesize from files where md5_str=:md5_str")

        with self.app_mysql.connect() as conn:
            result = db_mysql.to_json(conn.execute(version_sql, {"id": version_id}))
            if not result:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            model_file = result["model_file"]
            config_file = result["config_file"]
            response["model_name"] = result["model_name"]
            response["version"] = result["version"]
            response["comment"] = result["comment"]
            response["model_file_md5"] = model_file
            response["config_file_md5"] = config_file

            # Resolve name/size for each referenced file; this fills the
            # "<prefix>_name" and "<prefix>_size" keys of the response.
            for prefix, md5_str in (
                ("model_file", model_file),
                ("config_file", config_file),
            ):
                if not md5_str:
                    continue
                info = db_mysql.to_json(conn.execute(file_sql, {"md5_str": md5_str}))
                response[prefix + "_name"] = info["filename"] if info else ""
                response[prefix + "_size"] = info["filesize"] if info else 0

            response["config_str"] = self.unescape_string(result["config_str"])

        self.finish(response)
class VersionSetDefaultHandler(APIHandler):
    """
    Set a model version as the default version
    - Description: set the default version of a model
    - Method: post
    - Request parameters
    > - version_id, int, model version id
    > - model_id, int, model id
    - Return value: none
    - Raises errors.HTTPAPIError (bad request) when a parameter is missing or
      the version does not belong to the given model
    """

    @authenticated
    @operation_log("模型管理", "模型版本", consts.op_type_edit_str, "设置模型版本为默认版本", "")
    def post(self):
        version_id = self.get_int_argument("version_id")
        model_id = self.get_int_argument("model_id")
        if not version_id or not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Make sure the version really belongs to this model. Otherwise a
            # mismatched pair would clear this model's default flags and mark
            # a version of a different model as default.
            owned = conn.execute(
                text(
                    "select id from model_version where id=:id and model_id=:model_id"
                ),
                {"id": version_id, "model_id": model_id},
            ).fetchone()
            if not owned:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            # Clear the flag on every version of the model, then set it on
            # the chosen one.
            conn.execute(
                text("update model_version set is_default=0 where model_id=:model_id"),
                {"model_id": model_id},
            )
            conn.execute(
                text("update model_version set is_default=1 where id=:id"),
                {"id": version_id},
            )
            # Mirror the chosen version string onto the model row.
            conn.execute(
                text(
                    "update model m set m.default_version=(select version from model_version v where v.id=:vid) "
                    "where m.id=:mid"
                ),
                {"vid": version_id, "mid": model_id},
            )
            conn.commit()

        self.finish()
class VersionDeleteHandler(APIHandler):
    """
    Delete a model version
    - Description: delete a model version together with its file records
      and stored files
    - Method: post
    - Request parameters
    > - version_id, int, model version id
    - Return value: none
    - Raises errors.HTTPAPIError (bad request) when the parameter is missing
      or the version does not exist
    """

    @authenticated
    @operation_log("模型管理", "模型版本", consts.op_type_delete_str, "删除模型版本", "")
    def post(self):
        version_id = self.get_int_argument("version_id")
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        with self.app_mysql.connect() as conn:
            # Look up the files referenced by this version so their records
            # and stored copies can be removed as well.
            cur = conn.execute(
                text(
                    "select model_id, model_file, config_file, is_default "
                    "from model_version where id=:id"
                ),
                {"id": version_id},
            )
            row = db_mysql.to_json(cur)
            if not row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

            # config_file is optional, so skip empty references: an empty
            # name would otherwise make os.remove target the directory.
            md5_list = [m for m in (row["model_file"], row["config_file"]) if m]

            # Only clear the model's default version when the version being
            # deleted is the default one; deleting any other version must
            # leave the model's default untouched.
            if row["is_default"]:
                conn.execute(
                    text("update model set default_version='' where id=:id"),
                    {"id": row["model_id"]},
                )

            for md5_str in md5_list:
                conn.execute(
                    text("delete from files where md5_str=:md5_str"),
                    {"md5_str": md5_str},
                )

            conn.execute(
                text("delete from model_version where id=:id"), {"id": version_id}
            )
            conn.commit()

        # Remove the stored files only after the database change is
        # committed. Each file is handled on its own so one failure (for
        # example a file that is already gone) does not stop the other from
        # being removed. This stays best-effort and never fails the request.
        for md5_str in md5_list:
            try:
                os.remove(settings.file_upload_dir + "model/" + md5_str)
            except OSError as e:
                logging.warning("failed to remove model file %s: %s", md5_str, e)

        self.finish()