@ -1,18 +1,21 @@
# -*- coding: utf-8 -*-
import logging
import os
from sqlalchemy import text
from website import errors
from website import settings
from website import consts
from website import db
from website . util import md5
from website import errors
from website import settings
from website . handler import APIHandler , authenticated
from website . util import md5
class ClassificationAddHandler ( APIHandler ) :
"""
添加模型分类
"""
@authenticated
@ -25,11 +28,11 @@ class ClassificationAddHandler(APIHandler):
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 分类名称过长 " )
with self . app_mysql . connect ( ) as conn :
cur = conn . execute ( text ( " select id from model_classification where name=:name " ) , { " name " : name } )
cur = conn . execute ( text ( " select id from model_classification where name=:name " ) , { " name " : name } )
classification_id = cur . fetchone ( ) [ 0 ]
if classification_id :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 分类名称重复 " )
conn . execute (
text ( """ insert into model_classification (name, created_at) values (:name, NOW()) """ ) , { " name " : name }
)
@ -49,7 +52,8 @@ class ClassificationEditHandler(APIHandler):
if not classification_id or not name :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute ( text ( """ update model_classification set name=:name where id=:id """ ) , { " name " : name , " id " : classification_id } )
conn . execute ( text ( """ update model_classification set name=:name where id=:id """ ) ,
{ " name " : name , " id " : classification_id } )
conn . commit ( )
@ -80,7 +84,7 @@ class ClassificationDeleteHandler(APIHandler):
classification_id = self . get_int_argument ( " id " )
if not classification_id :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute ( text ( """ DELETE FROM model_classification WHERE id=:id """ ) , { " id " : classification_id } )
conn . commit ( )
@ -102,7 +106,7 @@ class ListHandler(APIHandler):
result = [ ]
with self . app_mysql . connect ( ) as conn :
sql = " select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_at " \
" from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
" from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
param = { }
@ -115,7 +119,7 @@ class ListHandler(APIHandler):
sql_count + = " and m.name like :name "
param_count [ " name " ] = " % {} % " . format ( name )
sql + = " order by m.id desc limit :pageSize offset :offset "
param [ " pageSize " ] = pageSize
param [ " offset " ] = ( pageNo - 1 ) * pageSize
@ -135,7 +139,7 @@ class ListHandler(APIHandler):
" default_version " : item [ " default_version " ] ,
" update_time " : str ( item [ " update_at " ] )
} )
self . finish ( { " data " : data , " count " : count } )
@ -147,24 +151,26 @@ class AddHandler(APIHandler):
@authenticated
def post ( self ) :
name = self . get_escaped_argument ( " name " , " " )
model_type = self . get_int_argument ( " model_type " , consts . model_type_machine ) # 1001/经典算法, 1002/深度学习
model_type = self . get_int_argument ( " model_type " , consts . model_type_machine ) # 1001/经典算法, 1002/深度学习
classification = self . get_int_argument ( " classification " )
comment = self . get_escaped_argument ( " comment " , " " )
if not name or not classification or model_type not in consts . MODEL_TYPE :
if not name or not classification or model_type not in consts . model_type_map :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
sql = text ( " select id from model_classification where id=:id " , { " id " : classification } )
cur = conn . execute ( sql )
if not cur . fetchone ( ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 分类不存在 " )
conn . execute (
text ( """ insert into model (name, model_type, classification, comment, created_at) values (:name, :model_type, :classification, :comment, NOW()) """ ) ,
text (
""" insert into model (name, model_type, classification, comment, created_at)
values ( : name , : model_type , : classification , : comment , NOW ( ) ) """ ),
{ " name " : name , " model_type " : model_type , " classification " : classification , " comment " : comment }
)
)
conn . commit ( )
self . finish ( )
@ -179,18 +185,23 @@ class EditHandler(APIHandler):
def post ( self ) :
mid = self . get_int_argument ( " id " )
name = self . get_escaped_argument ( " name " , " " )
model_type = self . get_int_argument ( " model_type " , consts . model_type_machine ) # 1001/经典算法, 1002/深度学习
model_type = self . get_int_argument ( " model_type " , consts . model_type_machine ) # 1001/经典算法, 1002/深度学习
classification = self . get_int_argument ( " classification " )
comment = self . get_escaped_argument ( " comment " , " " )
if not name or not classification or model_type not in consts . MODEL_TYPE :
if not name or not classification or model_type not in consts . model_type_map :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute (
text ( """ update model set name=:name, model_type=:model_type, classification=:classification, comment=:comment, update_at=NOW() where id=:id """ ) ,
{ " name " : name , " model_type " : model_type , " classification " : classification , " comment " : comment , " id " : mid }
)
text (
""" update model
set name = : name , model_type = : model_type , classification = : classification , comment = : comment ,
update_at = NOW ( )
where id = : id """ ),
{ " name " : name , " model_type " : model_type , " classification " : classification , " comment " : comment ,
" id " : mid }
)
conn . commit ( )
@ -210,22 +221,23 @@ class InfoHandler(APIHandler):
result = { }
with self . app_mysql . connect ( ) as conn :
cur = conn . execute (
text ( """ select m.name, mc.name as classification_name, m.comment, m.update_at from model m, model_classification mc where m.id=:id and m.classification=c.id """ ) ,
text (
""" select m.name, mc.name as classification_name, m.comment, m.update_at
from model m , model_classification mc where m . id = : id and m . classification = c . id """ ),
{ " id " : mid }
)
)
result = db . to_json ( cur )
if not result :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型不存在 " )
data = {
" name " : result [ " name " ] ,
" classification_name " : result [ " classification_name " ] ,
" comment " : result [ " comment " ] ,
" comment " : result [ " comment " ] ,
" update_time " : str ( result [ " update_at " ] )
}
self . finish ( data )
@ -273,7 +285,7 @@ class VersionAddHandler(APIHandler):
config_str = self . get_escaped_argument ( " config_str " , " " )
if not mid or not version or not model_file :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
if not md5 . md5_validate ( model_file ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型文件格式错误 " )
if config_file and not md5 . md5_validate ( config_file ) :
@ -283,11 +295,13 @@ class VersionAddHandler(APIHandler):
conn . execute (
text (
"""
insert into model_version ( model_id , version , comment , model_file , config_file , config_str , created_at , update_at )
insert into model_version
( model_id , version , comment , model_file , config_file , config_str , created_at , update_at )
values ( : model_id , : version , : comment , : model_file , : config_file , : config_str , NOW ( ) , NOW ( ) ) """
) ,
{ " model_id " : mid , " version " : version , " comment " : comment , " model_file " : model_file , " config_file " : config_file , " config_str " : config_str } )
) ,
{ " model_id " : mid , " version " : version , " comment " : comment , " model_file " : model_file ,
" config_file " : config_file , " config_str " : config_str } )
conn . commit ( )
self . finish ( )
@ -318,38 +332,186 @@ class VersionEditHandler(APIHandler):
config_str = self . get_escaped_argument ( " config_str " , " " )
if not version_id or not version or not model_file :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
if not md5 . md5_validate ( model_file ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型文件格式错误 " )
if config_file and not md5 . md5_validate ( config_file ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型配置文件格式错误 " )
with self . app_mysql . connect ( ) as conn :
conn . execute (
text ( " update model_version set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, config_str=:config_str, update_at=NOW() where id=:id " ) ,
{ " version " : version , " comment " : comment , " model_file " : model_file , " config_file " : config_file , " config_str " : config_str , " id " : version_id }
)
text (
" update model_version "
" set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, "
" config_str=:config_str, update_at=NOW() where id=:id " ) ,
{ " version " : version , " comment " : comment , " model_file " : model_file , " config_file " : config_file ,
" config_str " : config_str , " id " : version_id }
)
self . finish ( )
class VersionListHandler ( APIHandler ) :
"""
模型版本列表
- 描述 : 模型版本列表
- 请求方式 : post
- 请求参数 :
> - model_id , int , 模型id
> - pageNo , int
> - pageSize , int
- 返回值 :
` ` `
{
" count " : 123 ,
" data " : [
{
" version_id " : 213 , # 版本id
" version " : " xx " , # 版本号
" path " : " xxx " , # 文件路径
" size " : 123 ,
" update_time " : " xxx " ,
" is_default " : 1 # 是否默认, 1/默认, 0/非默认
} ,
. . .
]
}
` ` `
"""
@authenticated
def post ( self ) :
self . finish ( )
model_id = self . get_int_argument ( " model_id " )
pageNo = self . get_int_argument ( " pageNo " , 1 )
pageSize = self . get_int_argument ( " pageSize " , 20 )
if not model_id :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
count = 0
data = [ ]
with self . app_mysql . connect ( ) as conn :
# 查询数据, 从model_version, files查询version以及file的信息
cur = conn . execute (
text (
"""
select mv . id as version_id , mv . version , mv . model_file , mv . update_at , mv . is_default , f . filepath , f . filesize
from model_version mv
left join files f on mv . model_file = f . md5
where mv . model_id = : mid and mv . del = 0
order by mv . version_id desc limit : offset , : limit
"""
) ,
{ " mid " : model_id , " offset " : ( pageNo - 1 ) * pageSize , " limit " : pageSize }
)
result = db . to_json ( cur )
# 获取记录数量
count = conn . execute (
text (
"""
select count ( * )
from model_version mv
where mv . del = 0
"""
) ,
{ " mid " : model_id }
) . scalar ( )
for item in result :
data . append ( {
" version_id " : item [ " version_id " ] , # 版本id
" version " : item [ " version " ] , # 版本号
" path " : item [ " filepath " ] , # 文件路径
" size " : item [ " filesize " ] ,
" update_time " : str ( item [ " update_at " ] ) ,
" is_default " : item [ " is_default " ] # 是否默认, 1/默认, 0/非默认
} )
self . finish ( { " count " : count , " data " : data } )
class VersionInfoHandler ( APIHandler ) :
"""
模型版本信息
- 描述 : 模型版本信息
- 请求方式 : post
- 请求参数 :
> - version_id , int , 模型id
- 返回值 :
` ` `
{
" model_name " : " xxx " ,
" version " : " xxx " , # 版本
" comment " : " xxx " , # 备注
" model_file_name " : " xxx " , # 模型文件名
" model_file_size " : 123 , # 模型文件大小
" model_file_md5 " : " xx " , # 模型文件md5
" config_file_name " : " xx " , # 配置文件名
" config_file_size " : " xx " , # 配置文件大小
" config_file_md5 " : " xxx " , # 配置文件md5
" config_str " : " xxx " # 配置参数
}
` ` `
"""
@authenticated
def post ( self ) :
self . finish ( )
version_id = self . get_int_argument ( " version_id " )
response = {
" model_name " : " " ,
" version " : " " ,
" comment " : " " ,
" model_file_name " : " " ,
" model_file_size " : 0 ,
" model_file_md5 " : " " ,
" config_file_name " : " " ,
" config_file_size " : 0 ,
" config_file_md5 " : " " ,
" config_str " : " "
}
with self . app_mysql . connect ( ) as conn :
cur = conn . execute (
text (
""" select m.name as model_name, mv.version, mv.comment, mv.model_file, mv.config_file, mv.config_str
from model_version mv , model m where mv . id = : id and mv . model_id = m . id """ ),
{ " id " : version_id }
)
result = db . to_json ( cur )
if not result :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型版本不存在 " )
model_file = result [ " model_file " ]
config_file = result [ " config_file " ]
response [ " model_name " ] = result [ " model_name " ]
response [ " version " ] = result [ " version " ]
response [ " comment " ] = result [ " comment " ]
response [ " model_file_md5 " ] = model_file
response [ " config_file_md5 " ] = config_file
# 获取文件信息
if model_file :
cur_model_file = conn . execute (
text ( " select filename, filesize from files where md5_str=:md5_str " ) , { " md5_str " : model_file }
)
model_file_info = db . to_json ( cur_model_file )
response [ " model_file_name " ] = model_file_info [ " filename " ]
response [ " model_file_size " ] = model_file_info [ " filesize " ]
if config_file :
cur_config_file = conn . execute (
text ( " select filename, filesize from files where md5_str=:md5_str " ) , { " md5_str " : config_file }
)
config_file_info = db . to_json ( cur_config_file )
response [ " config_file_name " ] = config_file_info [ " filename " ]
response [ " config_file_size " ] = config_file_info [ " filesize " ]
response [ " config_str " ] = result [ " config_str " ]
self . finish ( response )
class VersionSetDefaultHandler ( APIHandler ) :
@ -370,14 +532,14 @@ class VersionSetDefaultHandler(APIHandler):
model_id = self . get_int_argument ( " model_id " )
if not version_id or not model_id :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute (
text ( " update model_version set is_default=0 where model_id=:model_id " ) ,
text ( " update model_version set is_default=0 where model_id=:model_id " ) ,
{ " model_id " : model_id } )
conn . execute (
text ( " update model_version set is_default=1 where id=:id " ) ,
text ( " update model_version set is_default=1 where id=:id " ) ,
{ " id " : version_id } )
conn . commit ( )
@ -387,8 +549,50 @@ class VersionSetDefaultHandler(APIHandler):
class VersionDeleteHandler ( APIHandler ) :
"""
删除模型版本
- 描述 : 删除模型版本
- 请求方式 : post
- 请求参数 :
> - version_id , int , 模型版本id
- 返回值 : 无
"""
@authenticated
def post ( self ) :
version_id = self . get_int_argument ( " version_id " )
if not version_id :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
row = { }
# 获取模型对应的model_file, config_file, 使用model_file, config_file删除对应的存储文件
with self . app_mysql . connect ( ) as conn :
cur = conn . execute ( text ( " select model_id, model_file, config_file from model_version where id=:id " ) ,
{ " id " : version_id } )
row = db . to_json ( cur )
if not row :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型版本不存在 " )
model_file = row [ " model_file " ]
config_file = row [ " config_file " ]
model_id = row [ " model_id " ]
# 清空模型默认版本
conn . execute (
text ( " update model set default_version= ' ' where id=:id " ) , { " id " : model_id }
)
# 删除文件
try :
conn . execute ( text ( " delete from files where md5_str=:md5_str " ) , { " md5_str " : model_file } )
conn . execute ( text ( " delete from files where md5_str=:md5_str " ) , { " md5_str " : config_file } )
os . remove ( settings . file_upload_dir + " model/ " + model_file )
os . remove ( settings . file_upload_dir + " model/ " + config_file )
except Exception as e :
logging . info ( e )
conn . execute ( text ( " delete from model_version where id=:id " ) , { " id " : version_id } )
conn . commit ( )
self . finish ( )