@ -6,7 +6,7 @@ from website import errors
from website import settings
from website import consts
from website import db
from website . util import shortuuid, aes
from website . util import md5
from website . handler import APIHandler , authenticated
@ -25,14 +25,32 @@ class ClassificationAddHandler(APIHandler):
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 分类名称过长 " )
with self . app_mysql . connect ( ) as conn :
cur = conn . execute ( text ( " select id from model_classification where name=:name " ) , name = name )
cur = conn . execute ( text ( " select id from model_classification where name=:name " ) , { " name " : name } )
classification_id = cur . fetchone ( ) [ 0 ]
if classification_id :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 分类名称重复 " )
conn . execute ( text ( """
insert into model_classification ( name , created_at ) values ( : name , NOW ( ) ) """ ),
name = name )
conn . execute (
text ( """ insert into model_classification (name, created_at) values (:name, NOW()) """ ) , { " name " : name }
)
self . finish ( )
class ClassificationEditHandler ( APIHandler ) :
"""
编辑模型分类
"""
@authenticated
def post ( self ) :
classification_id = self . get_int_argument ( " id " )
name = self . get_escaped_argument ( " name " , " " )
if not classification_id or not name :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute ( text ( """ update model_classification set name=:name where id=:id """ ) , { " name " : name , " id " : classification_id } )
conn . commit ( )
self . finish ( )
@ -46,8 +64,7 @@ class ClassificationListHandler(APIHandler):
@authenticated
def post ( self ) :
with self . app_mysql . connect ( ) as conn :
cur = conn . execute ( text ( """
select id , name from model_classification """ ))
cur = conn . execute ( text ( """ select id, name from model_classification """ ) )
result = db . to_json_list ( cur )
self . finish ( { " data " : result } )
@ -63,87 +80,261 @@ class ClassificationDeleteHandler(APIHandler):
classification_id = self . get_int_argument ( " id " )
if not classification_id :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute ( text ( """
DELETE FROM model_classification WHERE id = : id """ ),
id = classification_id )
conn . execute ( text ( """ DELETE FROM model_classification WHERE id=:id """ ) , { " id " : classification_id } )
conn . commit ( )
self . finish ( )
class ListHandler ( APIHandler ) :
"""
模型列表
"""
@authenticated
def post ( self ) :
self . finish ( )
pageNo = self . get_int_argument ( " pageNo " , 1 )
pageSize = self . get_int_argument ( " pageSize " , 10 )
name = self . get_escaped_argument ( " name " , " " )
result = [ ]
with self . app_mysql . connect ( ) as conn :
sql = " select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_at " \
" from model m left join model_classification mc on m.classification=mc.id where m.del=0 "
param = { }
sql_count = " select count(id) from model where del=0 "
param_count = { }
if name :
sql + = " and m.name like :name "
param [ " name " ] = " % {} % " . format ( name )
sql_count + = " and m.name like :name "
param_count [ " name " ] = " % {} % " . format ( name )
sql + = " order by m.id desc limit :pageSize offset :offset "
param [ " pageSize " ] = pageSize
param [ " offset " ] = ( pageNo - 1 ) * pageSize
cur = conn . execute ( text ( sql ) , param )
result = db . to_json_list ( cur )
count = conn . execute ( text ( sql_count ) , param_count ) . fetchone ( ) [ 0 ]
data = [ ]
for item in result :
data . append ( {
" id " : item [ " id " ] ,
" name " : item [ " name " ] ,
" model_type " : consts . model_type_map [ item [ " model_type " ] ] ,
" classification_name " : item [ " classification_name " ] ,
" default_version " : item [ " default_version " ] ,
" update_time " : str ( item [ " update_at " ] )
} )
self . finish ( { " data " : data , " count " : count } )
class AddHandler ( APIHandler ) :
"""
添加模型
"""
@authenticated
def post ( self ) :
name = self . get_escaped_argument ( " name " , " " )
model_type = self . get_int_argument ( " model_type " , consts . model_type_machine ) # 1001/经典算法, 1002/深度学习
classification = self . get_int_argument ( " classification " )
comment = self . get_escaped_argument ( " comment " , " " )
if not name or not classification or model_type not in consts . MODEL_TYPE :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
sql = text ( " select id from model_classification where id=:id " , { " id " : classification } )
cur = conn . execute ( sql )
if not cur . fetchone ( ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 分类不存在 " )
conn . execute (
text ( """ insert into model (name, model_type, classification, comment, created_at) values (:name, :model_type, :classification, :comment, NOW()) """ ) ,
{ " name " : name , " model_type " : model_type , " classification " : classification , " comment " : comment }
)
conn . commit ( )
self . finish ( )
class EditHandler ( APIHandler ) :
"""
编辑模型
"""
@authenticated
def post ( self ) :
mid = self . get_int_argument ( " id " )
name = self . get_escaped_argument ( " name " , " " )
model_type = self . get_int_argument ( " model_type " , consts . model_type_machine ) # 1001/经典算法, 1002/深度学习
classification = self . get_int_argument ( " classification " )
comment = self . get_escaped_argument ( " comment " , " " )
if not name or not classification or model_type not in consts . MODEL_TYPE :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute (
text ( """ update model set name=:name, model_type=:model_type, classification=:classification, comment=:comment, update_at=NOW() where id=:id """ ) ,
{ " name " : name , " model_type " : model_type , " classification " : classification , " comment " : comment , " id " : mid }
)
conn . commit ( )
self . finish ( )
class InfoHandler ( APIHandler ) :
"""
模型信息
"""
@authenticated
def post ( self ) :
self . finish ( )
mid = self . get_int_argument ( " id " )
if not mid :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
result = { }
with self . app_mysql . connect ( ) as conn :
cur = conn . execute (
text ( """ select m.name, mc.name as classification_name, m.comment, m.update_at from model m, model_classification mc where m.id=:id and m.classification=c.id """ ) ,
{ " id " : mid }
)
result = db . to_json ( cur )
if not result :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型不存在 " )
data = {
" name " : result [ " name " ] ,
" classification_name " : result [ " classification_name " ] ,
" comment " : result [ " comment " ] ,
" update_time " : str ( result [ " update_at " ] )
}
self . finish ( data )
class DeleteHandler ( APIHandler ) :
"""
删除模型
"""
@authenticated
def post ( self ) :
mid = self . get_int_argument ( " id " )
if not mid :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute ( text ( " update model set del=1 where id=:id " ) , { " id " : mid } )
conn . commit ( )
self . finish ( )
class VersionAddHandler ( APIHandler ) :
"""
添加模型版本
- 描述 : 添加模型版本
- 请求方式 : post
- 请求参数 :
> - model_id , int , 模型id
> - version , string , 版本
> - comment , string , 备注
> - model_file , string , 模型文件的md5 , 只允许上传一个文件
> - config_file , string , 模型配置文件的md5
> - config_str , string , 模型配置参数 , json格式字符串 , eg : ' { " a " : " xx " , " b " : " xx " } '
- 返回值 : 无
"""
@authenticated
def post ( self ) :
mid = self . get_int_argument ( " model_id " )
version = self . get_escaped_argument ( " version " , " " )
comment = self . get_escaped_argument ( " comment " , " " )
model_file = self . get_escaped_argument ( " model_file " , " " )
config_file = self . get_escaped_argument ( " config_file " , " " )
config_str = self . get_escaped_argument ( " config_str " , " " )
if not mid or not version or not model_file :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
if not md5 . md5_validate ( model_file ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型文件格式错误 " )
if config_file and not md5 . md5_validate ( config_file ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型配置文件格式错误 " )
with self . app_mysql . connect ( ) as conn :
conn . execute (
text (
"""
insert into model_version ( model_id , version , comment , model_file , config_file , config_str , created_at , update_at )
values ( : model_id , : version , : comment , : model_file , : config_file , : config_str , NOW ( ) , NOW ( ) ) """
) ,
{ " model_id " : mid , " version " : version , " comment " : comment , " model_file " : model_file , " config_file " : config_file , " config_str " : config_str } )
conn . commit ( )
self . finish ( )
class VersionEditHandler ( APIHandler ) :
"""
编辑模型版本
- 描述 : 编辑模型版本
- 请求方式 : post
- 请求参数 :
> - version_id , int , 模型id
> - version , string , 版本
> - comment , string , 备注
> - model_file , string , 模型文件的md5 , 只允许上传一个文件
> - config_file , string , 模型配置文件的md5
> - config_str , string , 模型配置参数 , json格式字符串 , eg : ' { " a " : " xx " , " b " : " xx " } '
- 返回值 : 无
"""
@authenticated
def post ( self ) :
version_id = self . get_int_argument ( " version_id " )
version = self . get_escaped_argument ( " version " , " " )
comment = self . get_escaped_argument ( " comment " , " " )
model_file = self . get_escaped_argument ( " model_file " , " " )
config_file = self . get_escaped_argument ( " config_file " , " " )
config_str = self . get_escaped_argument ( " config_str " , " " )
if not version_id or not version or not model_file :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
if not md5 . md5_validate ( model_file ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型文件格式错误 " )
if config_file and not md5 . md5_validate ( config_file ) :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 模型配置文件格式错误 " )
with self . app_mysql . connect ( ) as conn :
conn . execute (
text ( " update model_version set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, config_str=:config_str, update_at=NOW() where id=:id " ) ,
{ " version " : version , " comment " : comment , " model_file " : model_file , " config_file " : config_file , " config_str " : config_str , " id " : version_id }
)
self . finish ( )
class VersionListHandler ( APIHandler ) :
"""
模型版本列表
"""
@authenticated
@ -153,7 +344,7 @@ class VersionListHandler(APIHandler):
class VersionInfoHandler ( APIHandler ) :
"""
模型版本信息
"""
@authenticated
@ -163,17 +354,39 @@ class VersionInfoHandler(APIHandler):
class VersionSetDefaultHandler ( APIHandler ) :
"""
设置模型版本为默认版本
- 描述 : 设置模型默认版本
- 请求方式 : post
- 请求参数 :
> - version_id , int , 模型版本id
> - model_id , int , 模型id
- 返回值 : 无
"""
@authenticated
def post ( self ) :
version_id = self . get_int_argument ( " version_id " )
model_id = self . get_int_argument ( " model_id " )
if not version_id or not model_id :
raise errors . HTTPAPIError ( errors . ERROR_BAD_REQUEST , " 参数缺失 " )
with self . app_mysql . connect ( ) as conn :
conn . execute (
text ( " update model_version set is_default=0 where model_id=:model_id " ) ,
{ " model_id " : model_id } )
conn . execute (
text ( " update model_version set is_default=1 where id=:id " ) ,
{ " id " : version_id } )
conn . commit ( )
self . finish ( )
class VersionDeleteHandler ( APIHandler ) :
"""
删除模型版本
"""
@authenticated