You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

698 lines
23 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import logging
import os
from sqlalchemy import text
from website import consts
from website import db_mysql
from website import errors
from website import settings
from website.handler import APIHandler, authenticated
from website.util import md5, shortuuid
class ClassificationAddHandler(APIHandler):
    """
    Create a new model classification.

    Request arguments:
        name: classification name (required, at most 128 characters).

    Responds with an empty body. Raises ERROR_BAD_REQUEST when the name is
    missing, longer than 128 characters, or already used by a classification.
    """

    @authenticated
    def post(self):
        new_name = self.get_escaped_argument("name", "")
        if not new_name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if len(new_name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")

        params = {"name": new_name}
        lookup_sql = text("select id from model_classification where name=:name")
        insert_sql = text(
            """insert into model_classification (name, create_time) values (:name, NOW())"""
        )
        with self.app_mysql.connect() as conn:
            # Names must be unique across classifications.
            existing = db_mysql.to_json(conn.execute(lookup_sql, params))
            if existing:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(insert_sql, params)
            conn.commit()
        self.finish()
class ClassificationEditHandler(APIHandler):
    """
    Rename an existing model classification.

    Request arguments:
        id: id of the classification to rename (required).
        name: new classification name (required, at most 128 characters).

    Responds with an empty body. Raises ERROR_BAD_REQUEST when an argument is
    missing, the name is longer than 128 characters, or another
    classification already uses the name.
    """

    @authenticated
    def post(self):
        classification_id = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        if not classification_id or not name:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # Same limit ClassificationAddHandler enforces, so renaming cannot bypass it.
        if len(name) > 128:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")
        with self.app_mysql.connect() as conn:
            # Reject a name already taken by a *different* classification;
            # re-saving the row's own current name is allowed.
            cur = conn.execute(
                text("select id from model_classification where name=:name and id!=:id"),
                {"name": name, "id": classification_id},
            )
            if db_mysql.to_json(cur):
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
            conn.execute(
                text("""update model_classification set name=:name where id=:id"""),
                {"name": name, "id": classification_id},
            )
            conn.commit()
        self.finish()
class ClassificationListHandler(APIHandler):
    """
    List all model classifications.

    Response: {"data": <rows of (id, name) from model_classification>}
    """

    @authenticated
    def post(self):
        query = text("""select id, name from model_classification""")
        with self.app_mysql.connect() as conn:
            rows = db_mysql.to_json_list(conn.execute(query))
        self.finish({"data": rows})
class ClassificationDeleteHandler(APIHandler):
    """
    Delete a model classification (hard delete of the model_classification row).

    Request arguments:
        id: id of the classification to delete (required).

    Responds with an empty body; raises ERROR_BAD_REQUEST when id is missing.
    """

    @authenticated
    def post(self):
        target_id = self.get_int_argument("id")
        if not target_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        delete_sql = text("""DELETE FROM model_classification WHERE id=:id""")
        with self.app_mysql.connect() as conn:
            conn.execute(delete_sql, {"id": target_id})
            conn.commit()
        self.finish()
class ListHandler(APIHandler):
    """
    Paginated list of models that have not been soft-deleted (del=0).

    Request arguments:
        pageNo: 1-based page number (default 1; values below 1 are treated as 1).
        pageSize: rows per page (default consts.PAGE_SIZE; values below 1 fall
            back to the default).
        name: optional substring filter on the model name.

    Response: {"data": [...], "count": <total number of matching models>}.
    Each data item has id, name, model_type (mapped through
    consts.model_type_map), classification_name, default_version and
    update_time (as str).
    """

    @authenticated
    def post(self):
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", consts.PAGE_SIZE)
        name = self.get_escaped_argument("name", "")
        # Non-positive paging values would produce a negative OFFSET/LIMIT,
        # which MySQL rejects.
        if not pageNo or pageNo < 1:
            pageNo = 1
        if not pageSize or pageSize < 1:
            pageSize = consts.PAGE_SIZE

        # One filter shared by the page query and the count query. Both queries
        # alias the model table as ``m`` so the same condition is valid in each
        # (the count query previously lacked the alias, breaking name search).
        where = "where m.del=0 "
        param = {}
        if name:
            where += "and m.name like :name "
            param["name"] = "%{}%".format(name)

        sql = (
            "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_time "
            "from model m left join model_classification mc on m.classification=mc.id "
            + where
            + "order by m.id desc limit :pageSize offset :offset"
        )
        sql_count = "select count(m.id) from model m " + where

        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(sql),
                dict(param, pageSize=pageSize, offset=(pageNo - 1) * pageSize),
            )
            result = db_mysql.to_json_list(cur)
            count = conn.execute(text(sql_count), param).scalar()

        data = [
            {
                "id": item["id"],
                "name": item["name"],
                "model_type": consts.model_type_map[item["model_type"]],
                "classification_name": item["classification_name"],
                "default_version": item["default_version"],
                "update_time": str(item["update_time"]),
            }
            for item in result
        ]
        self.finish({"data": data, "count": count})
class ListSimpleHandler(APIHandler):
    """
    Return the id and name of every model that has not been soft-deleted.

    Response: {"result": <rows of (id, name) from model where del=0>}
    """

    @authenticated
    def post(self):
        query = text("select id, name from model where del=0")
        with self.app_mysql.connect() as conn:
            models = db_mysql.to_json_list(conn.execute(query))
        self.finish({"result": models})
class AddHandler(APIHandler):
    """
    Create a new model.

    Request arguments:
        name: model name (required).
        model_type: model type code, must be a key of consts.model_type_map
            (default consts.model_type_machine; 1001 = classic algorithm,
            1002 = deep learning).
        classification: id of an existing model classification (required).
        comment: free-text remark (optional).

    A random 10-character short uuid is stored as the model's suid.
    Raises ERROR_BAD_REQUEST when arguments are missing/invalid or the
    classification does not exist.
    """

    @authenticated
    def post(self):
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument("model_type", consts.model_type_machine)
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")

        if not (name and classification and model_type in consts.model_type_map):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        insert_sql = text(
            "insert into model (suid, name, model_type, classification, comment, create_time, update_time)\n"
            "values (:suid, :name, :model_type, :classification, :comment, NOW(), NOW())"
        )
        with self.app_mysql.connect() as conn:
            found = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            ).fetchone()
            if not found:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")

            record = {
                "suid": shortuuid.ShortUUID().random(10),
                "name": name,
                "model_type": model_type,
                "classification": classification,
                "comment": comment,
            }
            conn.execute(insert_sql, record)
            conn.commit()
        self.finish()
class EditHandler(APIHandler):
    """
    Edit an existing model.

    Request arguments:
        id: id of the model to edit (required).
        name: model name (required).
        model_type: model type code, must be a key of consts.model_type_map
            (default consts.model_type_machine; 1001 = classic algorithm,
            1002 = deep learning).
        classification: id of an existing model classification (required).
        comment: free-text remark (optional).

    Also refreshes the model's update_time. Raises ERROR_BAD_REQUEST when
    arguments are missing/invalid or the classification does not exist.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        name = self.get_escaped_argument("name", "")
        model_type = self.get_int_argument(
            "model_type", consts.model_type_machine
        )
        classification = self.get_int_argument("classification")
        comment = self.get_escaped_argument("comment", "")
        # The model id is mandatory: without it the UPDATE would silently match nothing.
        if (
            not mid
            or not name
            or not classification
            or model_type not in consts.model_type_map
        ):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # Same check AddHandler performs, so a model cannot be re-pointed
            # at a classification that does not exist.
            cur = conn.execute(
                text("select id from model_classification where id=:id"),
                {"id": classification},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")
            conn.execute(
                text(
                    "update model "
                    "set name=:name, model_type=:model_type, classification=:classification, comment=:comment, "
                    "update_time=NOW() "
                    "where id=:id"
                ),
                {
                    "name": name,
                    "model_type": model_type,
                    "classification": classification,
                    "comment": comment,
                    "id": mid,
                },
            )
            conn.commit()
        self.finish()
class InfoHandler(APIHandler):
    """
    Return the details of a single model.

    Request arguments:
        id: model id (required).

    Response keys: name, model_type (raw type code), default_version,
    classification_id, classification_name, comment, update_time (as str).
    classification_id / classification_name are None when the model's
    classification row no longer exists.
    Raises ERROR_BAD_REQUEST when id is missing or no such model exists.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("id")
        if not mid:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # LEFT JOIN (as ListHandler does): with the previous implicit inner
            # join, a model whose classification had been deleted was reported
            # as nonexistent.
            cur = conn.execute(
                text(
                    """
                    select
                        m.name, m.model_type, m.default_version, m.comment, m.update_time,
                        mc.id as classification_id, mc.name as classification_name
                    from model m
                    left join model_classification mc on m.classification=mc.id
                    where m.id=:id
                    """
                ),
                {"id": mid},
            )
            result = db_mysql.to_json(cur)
        if not result:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
        data = {
            "name": result["name"],
            "model_type": result["model_type"],
            "default_version": result["default_version"],
            "classification_id": result["classification_id"],
            "classification_name": result["classification_name"],
            "comment": result["comment"],
            "update_time": str(result["update_time"]),
        }
        self.finish(data)
class DeleteHandler(APIHandler):
    """
    Soft-delete a model by setting its del flag to 1.

    Request arguments:
        id: id of the model to delete (required).

    Responds with an empty body; raises ERROR_BAD_REQUEST when id is missing.
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("id")
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")

        soft_delete_sql = text("update model set del=1 where id=:id")
        with self.app_mysql.connect() as conn:
            conn.execute(soft_delete_sql, {"id": model_id})
            conn.commit()
        self.finish()
class VersionAddHandler(APIHandler):
    """
    Add a model version.

    - Method: POST
    - Request arguments:
      > - model_id, int, model id (must refer to a model that is not soft-deleted)
      > - version, string, version label
      > - comment, string, remark
      > - model_file, string, md5 of the model file (exactly one file)
      > - config_file, string, md5 of the model config file (optional)
      > - config_str, string, model config parameters as a JSON string,
      >   e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing

    Raises ERROR_BAD_REQUEST when arguments are missing, an md5 fails
    md5.md5_validate, or the model does not exist.
    """

    @authenticated
    def post(self):
        mid = self.get_int_argument("model_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not mid or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")
        with self.app_mysql.connect() as conn:
            # Refuse to create orphan versions for missing or soft-deleted models.
            cur = conn.execute(
                text("select id from model where id=:id and del=0"),
                {"id": mid},
            )
            if not cur.fetchone():
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
            conn.execute(
                text(
                    """
                    insert into model_version
                    (model_id, version, comment,model_file, config_file, config_str, create_time, update_time)
                    values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"""
                ),
                {
                    "model_id": mid,
                    "version": version,
                    "comment": comment,
                    "model_file": model_file,
                    "config_file": config_file,
                    "config_str": config_str,
                },
            )
            conn.commit()
        self.finish()
class VersionEditHandler(APIHandler):
    """
    Edit an existing model version.

    - Method: POST
    - Request arguments:
      > - version_id, int, model version id
      > - version, string, version label
      > - comment, string, remark
      > - model_file, string, md5 of the model file (exactly one file)
      > - config_file, string, md5 of the model config file (optional)
      > - config_str, string, model config parameters as a JSON string,
      >   e.g. '{"a":"xx","b":"xx"}'
    - Returns: nothing

    If the edited version is its model's default version (is_default=1), the
    model's default_version label is updated to the new version label too.
    Raises ERROR_BAD_REQUEST when arguments are missing or an md5 fails
    md5.md5_validate.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        version = self.get_escaped_argument("version", "")
        comment = self.get_escaped_argument("comment", "")
        model_file = self.get_escaped_argument("model_file", "")
        config_file = self.get_escaped_argument("config_file", "")
        config_str = self.get_escaped_argument("config_str", "")
        if not version_id or not version or not model_file:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        if not md5.md5_validate(model_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
        if config_file and not md5.md5_validate(config_file):
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")
        with self.app_mysql.connect() as conn:
            conn.execute(
                text(
                    "update model_version "
                    "set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
                    "config_str=:config_str, update_time=NOW() where id=:id"
                ),
                {
                    "version": version,
                    "comment": comment,
                    "model_file": model_file,
                    "config_file": config_file,
                    "config_str": config_str,
                    "id": version_id,
                },
            )
            # VersionSetDefaultHandler copies the version label into
            # model.default_version; keep that copy in sync when the default
            # version's label is edited.
            conn.execute(
                text(
                    "update model m join model_version mv on mv.model_id=m.id "
                    "set m.default_version=mv.version "
                    "where mv.id=:id and mv.is_default=1"
                ),
                {"id": version_id},
            )
            # Every other write handler commits explicitly; without this the
            # transaction is discarded when the connection closes and the edit
            # is never persisted.
            conn.commit()
        self.finish()
class VersionListHandler(APIHandler):
    """
    Paginated list of a model's versions.

    - Method: POST
    - Request arguments:
      > - model_id, int, model id (required)
      > - pageNo, int, 1-based page number (default 1; values below 1 are treated as 1)
      > - pageSize, int, rows per page (default 20; values below 1 fall back to 20)
    - Returns:
    ```
    {
        "count": 123,
        "data": [
            {
                "version_id": 213,  # version id
                "version": "xx",  # version label
                "path": "xxx",  # stored file path of the model file
                "size": 123,
                "update_time": "xxx",
                "is_default": 1  # 1 = default version, 0 = not default
            },
            ...
        ]
    }
    ```
    """

    @authenticated
    def post(self):
        model_id = self.get_int_argument("model_id")
        pageNo = self.get_int_argument("pageNo", 1)
        pageSize = self.get_int_argument("pageSize", 20)
        if not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # Non-positive paging values would produce a negative LIMIT offset/count,
        # which MySQL rejects.
        if not pageNo or pageNo < 1:
            pageNo = 1
        if not pageSize or pageSize < 1:
            pageSize = 20
        with self.app_mysql.connect() as conn:
            # Join model_version with files to get each version's model-file info.
            cur = conn.execute(
                text(
                    """
                    select mv.id as version_id, mv.version, mv.model_file, mv.update_time, mv.is_default, f.filepath, f.filesize
                    from model_version mv
                    left join files f on mv.model_file=f.md5_str
                    where mv.model_id=:mid and mv.del=0
                    order by mv.id desc limit :offset, :limit
                    """
                ),
                {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize},
            )
            result = db_mysql.to_json_list(cur)
            # Total number of (non-deleted) versions for this model.
            count = conn.execute(
                text(
                    """
                    select count(*)
                    from model_version mv
                    where mv.del=0 and model_id=:mid
                    """
                ),
                {"mid": model_id},
            ).scalar()
        data = [
            {
                "version_id": item["version_id"],
                "version": item["version"],
                "path": item["filepath"],
                "size": item["filesize"],
                "update_time": str(item["update_time"]),
                "is_default": item["is_default"],
            }
            for item in result
        ]
        self.finish({"count": count, "data": data})
class VersionInfoHandler(APIHandler):
    """
    Details of a single model version.

    - Method: POST
    - Request arguments:
      > - version_id, int, model version id (required)
    - Returns:
    ```
    {
        "model_name": "xxx",
        "version": "xxx",  # version label
        "comment": "xxx",  # remark
        "model_file_name": "xxx",  # model file name
        "model_file_size": 123,  # model file size
        "model_file_md5": "xx",  # model file md5
        "config_file_name": "xx",  # config file name
        "config_file_size": 123,  # config file size
        "config_file_md5": "xxx",  # config file md5
        "config_str": "xxx"  # config parameters
    }
    ```
    File name/size default to "" / 0 when the md5 is empty or has no row in
    the files table. Raises ERROR_BAD_REQUEST when version_id is missing or
    the version does not exist.
    """

    @staticmethod
    def _file_info(conn, md5_str):
        """Return (filename, filesize) for an uploaded file md5, or ("", 0) if unknown."""
        if not md5_str:
            return "", 0
        cur = conn.execute(
            text("select filename, filesize from files where md5_str=:md5_str"),
            {"md5_str": md5_str},
        )
        info = db_mysql.to_json(cur)
        if not info:
            return "", 0
        return info["filename"], info["filesize"]

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        # Validate up front like the other handlers instead of querying id=NULL.
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        response = {
            "model_name": "",
            "version": "",
            "comment": "",
            "model_file_name": "",
            "model_file_size": 0,
            "model_file_md5": "",
            "config_file_name": "",
            "config_file_size": 0,
            "config_file_md5": "",
            "config_str": "",
        }
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    """select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str
                    from model_version mv, model m where mv.id=:id and mv.model_id=m.id"""
                ),
                {"id": version_id},
            )
            result = db_mysql.to_json(cur)
            if not result:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
            model_file = result["model_file"]
            config_file = result["config_file"]
            response["model_name"] = result["model_name"]
            response["version"] = result["version"]
            response["comment"] = result["comment"]
            response["model_file_md5"] = model_file
            response["config_file_md5"] = config_file
            # Resolve stored file metadata for both md5 references.
            (
                response["model_file_name"],
                response["model_file_size"],
            ) = self._file_info(conn, model_file)
            (
                response["config_file_name"],
                response["config_file_size"],
            ) = self._file_info(conn, config_file)
        response["config_str"] = self.unescape_string(result["config_str"])
        self.finish(response)
class VersionSetDefaultHandler(APIHandler):
    """
    Make a model version the model's default version.

    - Method: POST
    - Request arguments:
      > - version_id, int, model version id
      > - model_id, int, model id (the version must belong to this model)
    - Returns: nothing

    Clears is_default on all of the model's versions, sets it on the chosen
    version, and copies that version's label into model.default_version.
    Raises ERROR_BAD_REQUEST when arguments are missing or the version does
    not belong to the model.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        model_id = self.get_int_argument("model_id")
        if not version_id or not model_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        with self.app_mysql.connect() as conn:
            # Verify ownership first: with a mismatched pair the updates below
            # would clear this model's default and mark another model's version.
            cur = conn.execute(
                text("select version from model_version where id=:id and model_id=:model_id"),
                {"id": version_id, "model_id": model_id},
            )
            row = db_mysql.to_json(cur)
            if not row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
            conn.execute(
                text("update model_version set is_default=0 where model_id=:model_id"),
                {"model_id": model_id},
            )
            conn.execute(
                text("update model_version set is_default=1 where id=:id"),
                {"id": version_id},
            )
            conn.execute(
                text("update model set default_version=:version where id=:mid"),
                {"version": row["version"], "mid": model_id},
            )
            conn.commit()
        self.finish()
class VersionDeleteHandler(APIHandler):
    """
    Delete a model version.

    - Method: POST
    - Request arguments:
      > - version_id, int, model version id
    - Returns: nothing

    Removes the model_version row. If it was the model's default version,
    model.default_version is cleared. For its model_file / config_file md5s,
    the files rows and the stored files under settings.file_upload_dir +
    "model/" are removed unless another model_version row still references
    the same md5. Stored files are deleted only after the database changes
    are committed; a failed file removal is logged, not raised.
    Raises ERROR_BAD_REQUEST when version_id is missing or the version does
    not exist.
    """

    @authenticated
    def post(self):
        version_id = self.get_int_argument("version_id")
        if not version_id:
            raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
        # md5s whose stored files should be removed once the DB changes commit.
        removable = []
        with self.app_mysql.connect() as conn:
            cur = conn.execute(
                text(
                    "select model_id, model_file, config_file, is_default from model_version where id=:id"
                ),
                {"id": version_id},
            )
            row = db_mysql.to_json(cur)
            if not row:
                raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
            # Only clear the model's default label when this version was the default.
            if row["is_default"]:
                conn.execute(
                    text("update model set default_version='' where id=:id"),
                    {"id": row["model_id"]},
                )
            conn.execute(
                text("delete from model_version where id=:id"), {"id": version_id}
            )
            # Files are keyed by md5, so the same file may back several versions.
            # NOTE(review): the files table may also be referenced outside
            # model_version; only model_version references are checked here.
            for md5_str in dict.fromkeys((row["model_file"], row["config_file"])):
                if not md5_str:
                    continue
                still_used = conn.execute(
                    text(
                        "select count(*) from model_version "
                        "where model_file=:md5_str or config_file=:md5_str"
                    ),
                    {"md5_str": md5_str},
                ).scalar()
                if still_used:
                    continue
                conn.execute(
                    text("delete from files where md5_str=:md5_str"),
                    {"md5_str": md5_str},
                )
                removable.append(md5_str)
            conn.commit()
        # Remove stored files after commit so a failed transaction cannot leave
        # DB rows pointing at files that were already deleted. Each file is
        # handled independently so one failure does not skip the other.
        for md5_str in removable:
            path = settings.file_upload_dir + "model/" + md5_str
            try:
                os.remove(path)
            except OSError:
                logging.warning("failed to remove model file %s", path, exc_info=True)
        self.finish()