模型版本的部分功能

main
周平 1 year ago
parent eee871a2da
commit 933b1878af

@ -19,4 +19,9 @@ industry_map = {
1014: u"其他行业", 1014: u"其他行业",
} }
model_type_classic = 1001
model_type_machine = 1002
model_type_map = {
model_type_classic: u"经典算法",
model_type_machine: u"机器学习",
}

@ -2,6 +2,8 @@ import itertools
import contextlib import contextlib
import logging import logging
from typing import List, Any, Optional
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine from sqlalchemy import create_engine
@ -19,7 +21,7 @@ class Row(dict):
raise AttributeError(name) raise AttributeError(name)
def to_json_list(cursor): def to_json_list(cursor: Any) -> Optional[List[Row]]:
column_names = list(cursor.keys()) column_names = list(cursor.keys())
result = cursor.fetchall() result = cursor.fetchall()
if not result: if not result:

@ -42,17 +42,16 @@ class UploadHandler(APIHandler):
cur = conn.execute(sql, {"md5_str": md5_str}) cur = conn.execute(sql, {"md5_str": md5_str})
row = cur.fetchone() row = cur.fetchone()
if not row: if not row:
filepath = os.path.join(settings.file_upload_dir, md5_str + '_' + filename) filepath = os.path.join(settings.file_upload_dir, md5_str + '_' + filename)
if not os.path.exists(filepath): if not os.path.exists(filepath):
for meta in file_metas: for meta in file_metas:
# filename = meta['filename'] # filename = meta['filename']
with open(filepath, 'wb') as f: with open(filepath, 'wb') as f:
f.write(meta['body']) f.write(meta['body'])
with self.app_mysql.connect() as conn: sql_insert = text("insert into files(filename, filepath, md5_str, filesize, filetype, user) values(:filename, :filepath, :md5_str, :file_size, :filetype, :user)")
sql = text("insert into files(filename, filepath, md5_str, filetype, user) values(:filename, :filepath, :md5_str, :filetype, :user)") conn.execute(sql_insert, {"filename": filename, "filepath": filepath, "md5_str": md5_str, "file_size": int(file_size/1024/1024), "filetype": filetype, "user": self.current_user.id})
conn.execute(sql, {"filename": filename, "filepath": filepath, "md5_str": md5_str, "filetype": filetype, "user": self.current_user.id})
conn.commit() conn.commit()
self.finish({"result": md5_str}) self.finish({"result": md5_str})

@ -6,7 +6,7 @@ from website import errors
from website import settings from website import settings
from website import consts from website import consts
from website import db from website import db
from website.util import shortuuid, aes from website.util import md5
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
@ -25,14 +25,32 @@ class ClassificationAddHandler(APIHandler):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长")
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text("select id from model_classification where name=:name"), name=name) cur = conn.execute(text("select id from model_classification where name=:name"), {"name":name})
classification_id = cur.fetchone()[0] classification_id = cur.fetchone()[0]
if classification_id: if classification_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复")
conn.execute(text(""" conn.execute(
insert into model_classification (name, created_at) values (:name, NOW())"""), text("""insert into model_classification (name, created_at) values (:name, NOW())"""), {"name": name}
name=name) )
self.finish()
class ClassificationEditHandler(APIHandler):
"""
编辑模型分类
"""
@authenticated
def post(self):
classification_id = self.get_int_argument("id")
name = self.get_escaped_argument("name", "")
if not classification_id or not name:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn:
conn.execute(text("""update model_classification set name=:name where id=:id"""), {"name": name, "id": classification_id})
conn.commit() conn.commit()
self.finish() self.finish()
@ -46,8 +64,7 @@ class ClassificationListHandler(APIHandler):
@authenticated @authenticated
def post(self): def post(self):
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text(""" cur = conn.execute(text("""select id, name from model_classification"""))
select id, name from model_classification"""))
result = db.to_json_list(cur) result = db.to_json_list(cur)
self.finish({"data": result}) self.finish({"data": result})
@ -63,87 +80,261 @@ class ClassificationDeleteHandler(APIHandler):
classification_id = self.get_int_argument("id") classification_id = self.get_int_argument("id")
if not classification_id: if not classification_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
conn.execute(text(""" conn.execute(text("""DELETE FROM model_classification WHERE id=:id"""), {"id": classification_id})
DELETE FROM model_classification WHERE id=:id"""), conn.commit()
id=classification_id)
self.finish() self.finish()
class ListHandler(APIHandler): class ListHandler(APIHandler):
""" """
模型列表
""" """
@authenticated @authenticated
def post(self): def post(self):
self.finish() pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_int_argument("pageSize", 10)
name = self.get_escaped_argument("name", "")
result = []
with self.app_mysql.connect() as conn:
sql = "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_at " \
"from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
param = {}
sql_count = "select count(id) from model where del=0 "
param_count = {}
if name:
sql += "and m.name like :name"
param["name"] = "%{}%".format(name)
sql_count += "and m.name like :name"
param_count["name"] = "%{}%".format(name)
sql += " order by m.id desc limit :pageSize offset :offset"
param["pageSize"] = pageSize
param["offset"] = (pageNo - 1) * pageSize
cur = conn.execute(text(sql), param)
result = db.to_json_list(cur)
count = conn.execute(text(sql_count), param_count).fetchone()[0]
data = []
for item in result:
data.append({
"id": item["id"],
"name": item["name"],
"model_type": consts.model_type_map[item["model_type"]],
"classification_name": item["classification_name"],
"default_version": item["default_version"],
"update_time": str(item["update_at"])
})
self.finish({"data": data, "count": count})
class AddHandler(APIHandler): class AddHandler(APIHandler):
""" """
添加模型
""" """
@authenticated @authenticated
def post(self): def post(self):
name = self.get_escaped_argument("name", "")
model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法1002/深度学习
classification = self.get_int_argument("classification")
comment = self.get_escaped_argument("comment", "")
if not name or not classification or model_type not in consts.MODEL_TYPE:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn:
sql = text("select id from model_classification where id=:id", {"id": classification})
cur = conn.execute(sql)
if not cur.fetchone():
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在")
conn.execute(
text("""insert into model (name, model_type, classification, comment, created_at) values (:name, :model_type, :classification, :comment, NOW())"""),
{"name": name, "model_type": model_type, "classification": classification, "comment": comment}
)
conn.commit()
self.finish() self.finish()
class EditHandler(APIHandler): class EditHandler(APIHandler):
""" """
编辑模型
""" """
@authenticated @authenticated
def post(self): def post(self):
mid = self.get_int_argument("id")
name = self.get_escaped_argument("name", "")
model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法1002/深度学习
classification = self.get_int_argument("classification")
comment = self.get_escaped_argument("comment", "")
if not name or not classification or model_type not in consts.MODEL_TYPE:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn:
conn.execute(
text("""update model set name=:name, model_type=:model_type, classification=:classification, comment=:comment, update_at=NOW() where id=:id"""),
{"name": name, "model_type": model_type, "classification": classification, "comment": comment, "id": mid}
)
conn.commit()
self.finish() self.finish()
class InfoHandler(APIHandler): class InfoHandler(APIHandler):
""" """
模型信息
""" """
@authenticated @authenticated
def post(self): def post(self):
self.finish() mid = self.get_int_argument("id")
if not mid:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
result = {}
with self.app_mysql.connect() as conn:
cur = conn.execute(
text("""select m.name, mc.name as classification_name, m.comment, m.update_at from model m, model_classification mc where m.id=:id and m.classification=c.id"""),
{"id": mid}
)
result = db.to_json(cur)
if not result:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
data = {
"name": result["name"],
"classification_name": result["classification_name"],
"comment": result["comment"],
"update_time": str(result["update_at"])
}
self.finish(data)
class DeleteHandler(APIHandler): class DeleteHandler(APIHandler):
""" """
删除模型
""" """
@authenticated @authenticated
def post(self): def post(self):
mid = self.get_int_argument("id")
if not mid:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn:
conn.execute(text("update model set del=1 where id=:id"), {"id": mid})
conn.commit()
self.finish() self.finish()
class VersionAddHandler(APIHandler): class VersionAddHandler(APIHandler):
""" """
添加模型版本
- 描述 添加模型版本
- 请求方式post
- 请求参数
> - model_id, int, 模型id
> - version, string, 版本
> - comment, string备注
> - model_file, string, 模型文件的md5, 只允许上传一个文件
> - config_file, string, 模型配置文件的md5
> - config_str, string, 模型配置参数json格式字符串eg: '{"a":"xx","b":"xx"}'
- 返回值
""" """
@authenticated @authenticated
def post(self): def post(self):
mid = self.get_int_argument("model_id")
version = self.get_escaped_argument("version", "")
comment = self.get_escaped_argument("comment", "")
model_file = self.get_escaped_argument("model_file", "")
config_file = self.get_escaped_argument("config_file", "")
config_str = self.get_escaped_argument("config_str", "")
if not mid or not version or not model_file:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
if not md5.md5_validate(model_file):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
if config_file and not md5.md5_validate(config_file):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")
with self.app_mysql.connect() as conn:
conn.execute(
text(
"""
insert into model_version (model_id, version, comment,model_file, config_file, config_str, created_at, update_at)
values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())"""
),
{"model_id": mid, "version": version, "comment": comment, "model_file": model_file, "config_file": config_file, "config_str": config_str})
conn.commit()
self.finish() self.finish()
class VersionEditHandler(APIHandler): class VersionEditHandler(APIHandler):
""" """
编辑模型版本
- 描述 编辑模型版本
- 请求方式post
- 请求参数
> - version_id, int, 模型id
> - version, string, 版本
> - comment, string备注
> - model_file, string, 模型文件的md5, 只允许上传一个文件
> - config_file, string, 模型配置文件的md5
> - config_str, string, 模型配置参数json格式字符串eg: '{"a":"xx","b":"xx"}'
- 返回值
""" """
@authenticated @authenticated
def post(self): def post(self):
version_id = self.get_int_argument("version_id")
version = self.get_escaped_argument("version", "")
comment = self.get_escaped_argument("comment", "")
model_file = self.get_escaped_argument("model_file", "")
config_file = self.get_escaped_argument("config_file", "")
config_str = self.get_escaped_argument("config_str", "")
if not version_id or not version or not model_file:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
if not md5.md5_validate(model_file):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误")
if config_file and not md5.md5_validate(config_file):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误")
with self.app_mysql.connect() as conn:
conn.execute(
text("update model_version set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, config_str=:config_str, update_at=NOW() where id=:id"),
{"version": version, "comment": comment, "model_file": model_file, "config_file": config_file, "config_str": config_str, "id": version_id}
)
self.finish() self.finish()
class VersionListHandler(APIHandler): class VersionListHandler(APIHandler):
""" """
模型版本列表
""" """
@authenticated @authenticated
@ -153,7 +344,7 @@ class VersionListHandler(APIHandler):
class VersionInfoHandler(APIHandler): class VersionInfoHandler(APIHandler):
""" """
模型版本信息
""" """
@authenticated @authenticated
@ -163,17 +354,39 @@ class VersionInfoHandler(APIHandler):
class VersionSetDefaultHandler(APIHandler): class VersionSetDefaultHandler(APIHandler):
""" """
设置模型版本为默认版本
- 描述 设置模型默认版本
- 请求方式post
- 请求参数
> - version_id, int, 模型版本id
> - model_id, int, 模型id
- 返回值
""" """
@authenticated @authenticated
def post(self): def post(self):
version_id = self.get_int_argument("version_id")
model_id = self.get_int_argument("model_id")
if not version_id or not model_id:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
with self.app_mysql.connect() as conn:
conn.execute(
text("update model_version set is_default=0 where model_id=:model_id"),
{"model_id": model_id})
conn.execute(
text("update model_version set is_default=1 where id=:id"),
{"id": version_id})
conn.commit()
self.finish() self.finish()
class VersionDeleteHandler(APIHandler): class VersionDeleteHandler(APIHandler):
""" """
删除模型版本
""" """
@authenticated @authenticated

@ -0,0 +1,10 @@
import re
def md5_validate(v: str) -> bool:
md5_pattern = re.compile(r'^[a-fA-F0-9]{32}$')
# 匹配字符串
if md5_pattern.match(v):
return True
return False
Loading…
Cancel
Save