diff --git a/website/consts.py b/website/consts.py index 940a84e..971885d 100644 --- a/website/consts.py +++ b/website/consts.py @@ -19,4 +19,9 @@ industry_map = { 1014: u"其他行业", } - +model_type_classic = 1001 +model_type_machine = 1002 +model_type_map = { + model_type_classic: u"经典算法", + model_type_machine: u"机器学习", +} \ No newline at end of file diff --git a/website/db.py b/website/db.py index 85b5f9a..bc1b602 100644 --- a/website/db.py +++ b/website/db.py @@ -2,6 +2,8 @@ import itertools import contextlib import logging +from typing import List, Any, Optional + from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy import create_engine @@ -19,7 +21,7 @@ class Row(dict): raise AttributeError(name) -def to_json_list(cursor): +def to_json_list(cursor: Any) -> Optional[List[Row]]: column_names = list(cursor.keys()) result = cursor.fetchall() if not result: diff --git a/website/handlers/file/handler.py b/website/handlers/file/handler.py index ba5e880..06af4cb 100644 --- a/website/handlers/file/handler.py +++ b/website/handlers/file/handler.py @@ -42,17 +42,16 @@ class UploadHandler(APIHandler): cur = conn.execute(sql, {"md5_str": md5_str}) row = cur.fetchone() - if not row: - filepath = os.path.join(settings.file_upload_dir, md5_str + '_' + filename) - if not os.path.exists(filepath): - for meta in file_metas: - # filename = meta['filename'] - with open(filepath, 'wb') as f: - f.write(meta['body']) - - with self.app_mysql.connect() as conn: - sql = text("insert into files(filename, filepath, md5_str, filetype, user) values(:filename, :filepath, :md5_str, :filetype, :user)") - conn.execute(sql, {"filename": filename, "filepath": filepath, "md5_str": md5_str, "filetype": filetype, "user": self.current_user.id}) + if not row: + filepath = os.path.join(settings.file_upload_dir, md5_str + '_' + filename) + if not os.path.exists(filepath): + for meta in file_metas: + # filename = meta['filename'] + with open(filepath, 'wb') as f: + f.write(meta['body']) + + sql_insert = text("insert into files(filename, filepath, md5_str, filesize, filetype, user) values(:filename, :filepath, :md5_str, :file_size, :filetype, :user)") + conn.execute(sql_insert, {"filename": filename, "filepath": filepath, "md5_str": md5_str, "file_size": int(file_size/1024/1024), "filetype": filetype, "user": self.current_user.id}) conn.commit() self.finish({"result": md5_str}) diff --git a/website/handlers/model/handler.py b/website/handlers/model/handler.py index bff2373..cd58bdd 100644 --- a/website/handlers/model/handler.py +++ b/website/handlers/model/handler.py @@ -6,7 +6,7 @@ from website import errors from website import settings from website import consts from website import db -from website.util import shortuuid, aes +from website.util import md5 from website.handler import APIHandler, authenticated @@ -25,14 +25,32 @@ class ClassificationAddHandler(APIHandler): raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称过长") with self.app_mysql.connect() as conn: - cur = conn.execute(text("select id from model_classification where name=:name"), name=name) + cur = conn.execute(text("select id from model_classification where name=:name"), {"name":name}) classification_id = cur.fetchone()[0] if classification_id: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类名称重复") - conn.execute(text(""" - insert into model_classification (name, created_at) values (:name, NOW())"""), - name=name) + conn.execute( + text("""insert into model_classification (name, created_at) values (:name, NOW())"""), {"name": name} + ) + + self.finish() + + +class ClassificationEditHandler(APIHandler): + """ + 编辑模型分类 + """ + + @authenticated + def post(self): + classification_id = self.get_int_argument("id") + name = self.get_escaped_argument("name", "") + if not classification_id or not name: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + with self.app_mysql.connect() as conn: + conn.execute(text("""update model_classification set name=:name where id=:id"""), {"name": name, "id": classification_id}) + conn.commit() self.finish() @@ -46,8 +64,7 @@ class ClassificationListHandler(APIHandler): @authenticated def post(self): with self.app_mysql.connect() as conn: - cur = conn.execute(text(""" - select id, name from model_classification""")) + cur = conn.execute(text("""select id, name from model_classification""")) result = db.to_json_list(cur) self.finish({"data": result}) @@ -63,87 +80,261 @@ class ClassificationDeleteHandler(APIHandler): classification_id = self.get_int_argument("id") if not classification_id: raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + with self.app_mysql.connect() as conn: - conn.execute(text(""" - DELETE FROM model_classification WHERE id=:id"""), - id=classification_id) - + conn.execute(text("""DELETE FROM model_classification WHERE id=:id"""), {"id": classification_id}) + conn.commit() + self.finish() class ListHandler(APIHandler): """ - + 模型列表 """ @authenticated def post(self): - self.finish() + pageNo = self.get_int_argument("pageNo", 1) + pageSize = self.get_int_argument("pageSize", 10) + name = self.get_escaped_argument("name", "") + + result = [] + with self.app_mysql.connect() as conn: + sql = "select m.id, m.name, m.model_type, m.default_version, mc.name classification_name, m.update_at " \ + "from model m left join model_classification mc on m.classification=mc.id where m.del=0 " + + param = {} + + sql_count = "select count(id) from model where del=0 " + param_count = {} + + if name: + sql += "and m.name like :name" + param["name"] = "%{}%".format(name) + + sql_count += "and m.name like :name" + param_count["name"] = "%{}%".format(name) + + sql += " order by m.id desc limit :pageSize offset :offset" + param["pageSize"] = pageSize + param["offset"] = (pageNo - 1) * pageSize + + cur = conn.execute(text(sql), param) + result = db.to_json_list(cur) + + count = conn.execute(text(sql_count), param_count).fetchone()[0] + + data = [] + for item in result: + data.append({ + "id": item["id"], + "name": item["name"], + "model_type": consts.model_type_map[item["model_type"]], + "classification_name": item["classification_name"], + "default_version": item["default_version"], + "update_time": str(item["update_at"]) + }) + + self.finish({"data": data, "count": count}) class AddHandler(APIHandler): """ - + 添加模型 """ @authenticated def post(self): + name = self.get_escaped_argument("name", "") + model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法,1002/深度学习 + classification = self.get_int_argument("classification") + comment = self.get_escaped_argument("comment", "") + + if not name or not classification or model_type not in consts.MODEL_TYPE: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + + with self.app_mysql.connect() as conn: + sql = text("select id from model_classification where id=:id", {"id": classification}) + cur = conn.execute(sql) + if not cur.fetchone(): + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "分类不存在") + + conn.execute( + text("""insert into model (name, model_type, classification, comment, created_at) values (:name, :model_type, :classification, :comment, NOW())"""), + {"name": name, "model_type": model_type, "classification": classification, "comment": comment} + ) + + conn.commit() + self.finish() class EditHandler(APIHandler): """ - + 编辑模型 """ @authenticated def post(self): + mid = self.get_int_argument("id") + name = self.get_escaped_argument("name", "") + model_type = self.get_int_argument("model_type", consts.model_type_machine) # 1001/经典算法,1002/深度学习 + classification = self.get_int_argument("classification") + comment = self.get_escaped_argument("comment", "") + + if not name or not classification or model_type not in consts.MODEL_TYPE: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + + with self.app_mysql.connect() as conn: + conn.execute( + text("""update model set name=:name, model_type=:model_type, classification=:classification, comment=:comment, update_at=NOW() where id=:id"""), + {"name": name, "model_type": model_type, "classification": classification, "comment": comment, "id": mid} + ) + + conn.commit() + self.finish() class InfoHandler(APIHandler): """ - + 模型信息 """ @authenticated def post(self): - self.finish() + mid = self.get_int_argument("id") + if not mid: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + result = {} + with self.app_mysql.connect() as conn: + cur = conn.execute( + text("""select m.name, mc.name as classification_name, m.comment, m.update_at from model m, model_classification mc where m.id=:id and m.classification=c.id"""), + {"id": mid} + ) + + result = db.to_json(cur) + if not result: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在") + + data = { + "name": result["name"], + "classification_name": result["classification_name"], + "comment": result["comment"], + "update_time": str(result["update_at"]) + } + + + self.finish(data) class DeleteHandler(APIHandler): """ - + 删除模型 """ @authenticated def post(self): + mid = self.get_int_argument("id") + if not mid: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + with self.app_mysql.connect() as conn: + conn.execute(text("update model set del=1 where id=:id"), {"id": mid}) + conn.commit() + self.finish() class VersionAddHandler(APIHandler): """ + 添加模型版本 + + - 描述: 添加模型版本 + - 请求方式:post + - 请求参数: + > - model_id, int, 模型id + > - version, string, 版本 + > - comment, string,备注 + > - model_file, string, 模型文件的md5, 只允许上传一个文件 + > - config_file, string, 模型配置文件的md5 + > - config_str, string, 模型配置参数,json格式字符串,eg: '{"a":"xx","b":"xx"}' + - 返回值:无 """ @authenticated def post(self): + mid = self.get_int_argument("model_id") + version = self.get_escaped_argument("version", "") + comment = self.get_escaped_argument("comment", "") + model_file = self.get_escaped_argument("model_file", "") + config_file = self.get_escaped_argument("config_file", "") + config_str = self.get_escaped_argument("config_str", "") + if not mid or not version or not model_file: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + + if not md5.md5_validate(model_file): + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误") + if config_file and not md5.md5_validate(config_file): + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误") + + with self.app_mysql.connect() as conn: + conn.execute( + text( + """ + insert into model_version (model_id, version, comment,model_file, config_file, config_str, created_at, update_at) + values (:model_id, :version, :comment, :model_file, :config_file, :config_str, NOW(), NOW())""" + ), + + {"model_id": mid, "version": version, "comment": comment, "model_file": model_file, "config_file": config_file, "config_str": config_str}) + conn.commit() self.finish() class VersionEditHandler(APIHandler): """ - + 编辑模型版本 + + - 描述: 编辑模型版本 + - 请求方式:post + - 请求参数: + > - version_id, int, 模型id + > - version, string, 版本 + > - comment, string,备注 + > - model_file, string, 模型文件的md5, 只允许上传一个文件 + > - config_file, string, 模型配置文件的md5 + > - config_str, string, 模型配置参数,json格式字符串,eg: '{"a":"xx","b":"xx"}' + - 返回值:无 """ @authenticated def post(self): + version_id = self.get_int_argument("version_id") + version = self.get_escaped_argument("version", "") + comment = self.get_escaped_argument("comment", "") + model_file = self.get_escaped_argument("model_file", "") + config_file = self.get_escaped_argument("config_file", "") + config_str = self.get_escaped_argument("config_str", "") + if not version_id or not version or not model_file: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + + if not md5.md5_validate(model_file): + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型文件格式错误") + if config_file and not md5.md5_validate(config_file): + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型配置文件格式错误") + with self.app_mysql.connect() as conn: + conn.execute( + text("update model_version set version=:version, comment=:comment, model_file=:model_file, config_file=:config_file, config_str=:config_str, update_at=NOW() where id=:id"), + {"version": version, "comment": comment, "model_file": model_file, "config_file": config_file, "config_str": config_str, "id": version_id} + ) self.finish() + class VersionListHandler(APIHandler): """ - + 模型版本列表 """ @authenticated @@ -153,7 +344,7 @@ class VersionListHandler(APIHandler): class VersionInfoHandler(APIHandler): """ - + 模型版本信息 """ @authenticated @@ -163,17 +354,39 @@ class VersionInfoHandler(APIHandler): class VersionSetDefaultHandler(APIHandler): """ - + 设置模型版本为默认版本 + + - 描述: 设置模型默认版本 + - 请求方式:post + - 请求参数: + > - version_id, int, 模型版本id + > - model_id, int, 模型id + - 返回值:无 """ @authenticated def post(self): + version_id = self.get_int_argument("version_id") + model_id = self.get_int_argument("model_id") + if not version_id or not model_id: + raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") + + with self.app_mysql.connect() as conn: + conn.execute( + text("update model_version set is_default=0 where model_id=:model_id"), + {"model_id": model_id}) + + conn.execute( + text("update model_version set is_default=1 where id=:id"), + {"id": version_id}) + conn.commit() + self.finish() class VersionDeleteHandler(APIHandler): """ - + 删除模型版本 """ @authenticated diff --git a/website/util/md5.py b/website/util/md5.py new file mode 100644 index 0000000..a2bf2b7 --- /dev/null +++ b/website/util/md5.py @@ -0,0 +1,10 @@ +import re + +def md5_validate(v: str) -> bool: + md5_pattern = re.compile(r'^[a-fA-F0-9]{32}$') + # 匹配字符串 + if md5_pattern.match(v): + return True + + return False +