更新代码结构

main
周平 1 year ago
parent c5e4d44827
commit a2abb09955

@ -53,7 +53,7 @@ except ImportError:
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../.."))) sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))
from website import settings from website import settings
from website.db import app_engine from website.db_mysql import app_engine
from website.handler import APIErrorHandler from website.handler import APIErrorHandler
from website.urls import handlers, page_handlers from website.urls import handlers, page_handlers
# from website.urls import handlers_v2 # from website.urls import handlers_v2

@ -2,12 +2,13 @@
import logging import logging
from sqlalchemy import Column, Integer, String, DateTime from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from website.db import get_session from website.db_mysql import get_session
from website.handlers.enterprise_entity.db import EnterpriseEntityDB from website.util import shortuuid
from website.handlers.enterprise_node.db import EnterpriseNodeDB from website.db.enterprise_entity import EnterpriseEntityDB
from website.db.enterprise_node import EnterpriseNodeDB
Base = declarative_base() Base = declarative_base()
@ -16,7 +17,7 @@ class EnterpriseDevice(Base):
__tablename__ = 'enterprise_device' __tablename__ = 'enterprise_device'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
suid = Column(String, len=10, default="") suid = Column(String(length=10), default="")
entity_id = Column(Integer) entity_id = Column(Integer)
entity_suid = Column(String) entity_suid = Column(String)
node_id = Column(Integer) node_id = Column(Integer)
@ -28,8 +29,8 @@ class EnterpriseDevice(Base):
param = Column(String) param = Column(String)
comment = Column(String) comment = Column(String)
delete = Column("del", Integer, default=0) delete = Column("del", Integer, default=0)
create_time = Column(DateTime) create_time = Column(DateTime, default=func.now())
update_time = Column(DateTime) update_time = Column(DateTime, default=func.now())
class EnterpriseDeviceDB(object): class EnterpriseDeviceDB(object):
@ -42,6 +43,7 @@ class EnterpriseDeviceDB(object):
device["entity_suid"] = entity_suid device["entity_suid"] = entity_suid
device["node_suid"] = node_suid device["node_suid"] = node_suid
device["suid"] = shortuuid.ShortUUID().random(10)
new_device = EnterpriseDevice(**device) new_device = EnterpriseDevice(**device)
with get_session() as session: with get_session() as session:
@ -91,3 +93,4 @@ class EnterpriseDeviceDB(object):
logging.error("Failed to list devices") logging.error("Failed to list devices")
raise e raise e
return devices return devices

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging import logging
from website.db import get_session, to_json from website.db_mysql import get_session, to_json
from sqlalchemy import text from sqlalchemy import text
class EnterpriseEntityDB(object): class EnterpriseEntityDB(object):

@ -4,7 +4,7 @@ from typing import List, Dict, Any
from sqlalchemy import text from sqlalchemy import text
from website.db import get_session, to_json_list, to_json from website.db_mysql import get_session, to_json_list, to_json
""" """
CREATE TABLE `enterprise_node` ( CREATE TABLE `enterprise_node` (

@ -5,7 +5,7 @@ import os
from sqlalchemy import text from sqlalchemy import text
from website import consts from website import consts
from website import db from website import db_mysql
from website import errors from website import errors
from website import settings from website import settings
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
@ -69,7 +69,7 @@ class ClassificationListHandler(APIHandler):
def post(self): def post(self):
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text("""select id, name from model_classification""")) cur = conn.execute(text("""select id, name from model_classification"""))
result = db.to_json_list(cur) result = db_mysql.to_json_list(cur)
self.finish({"data": result}) self.finish({"data": result})
@ -125,7 +125,7 @@ class ListHandler(APIHandler):
param["offset"] = (pageNo - 1) * pageSize param["offset"] = (pageNo - 1) * pageSize
cur = conn.execute(text(sql), param) cur = conn.execute(text(sql), param)
result = db.to_json_list(cur) result = db_mysql.to_json_list(cur)
count = conn.execute(text(sql_count), param_count).fetchone()[0] count = conn.execute(text(sql_count), param_count).fetchone()[0]
@ -227,7 +227,7 @@ class InfoHandler(APIHandler):
{"id": mid} {"id": mid}
) )
result = db.to_json(cur) result = db_mysql.to_json(cur)
if not result: if not result:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型不存在")
@ -403,7 +403,7 @@ class VersionListHandler(APIHandler):
{"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize} {"mid": model_id, "offset": (pageNo - 1) * pageSize, "limit": pageSize}
) )
result = db.to_json(cur) result = db_mysql.to_json(cur)
# 获取记录数量 # 获取记录数量
count = conn.execute( count = conn.execute(
@ -479,7 +479,7 @@ class VersionInfoHandler(APIHandler):
{"id": version_id} {"id": version_id}
) )
result = db.to_json(cur) result = db_mysql.to_json(cur)
if not result: if not result:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")
@ -497,7 +497,7 @@ class VersionInfoHandler(APIHandler):
cur_model_file = conn.execute( cur_model_file = conn.execute(
text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": model_file} text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": model_file}
) )
model_file_info = db.to_json(cur_model_file) model_file_info = db_mysql.to_json(cur_model_file)
response["model_file_name"] = model_file_info["filename"] response["model_file_name"] = model_file_info["filename"]
response["model_file_size"] = model_file_info["filesize"] response["model_file_size"] = model_file_info["filesize"]
@ -505,7 +505,7 @@ class VersionInfoHandler(APIHandler):
cur_config_file = conn.execute( cur_config_file = conn.execute(
text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": config_file} text("select filename, filesize from files where md5_str=:md5_str"), {"md5_str": config_file}
) )
config_file_info = db.to_json(cur_config_file) config_file_info = db_mysql.to_json(cur_config_file)
response["config_file_name"] = config_file_info["filename"] response["config_file_name"] = config_file_info["filename"]
response["config_file_size"] = config_file_info["filesize"] response["config_file_size"] = config_file_info["filesize"]
@ -568,7 +568,7 @@ class VersionDeleteHandler(APIHandler):
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text("select model_id, model_file, config_file from model_version where id=:id"), cur = conn.execute(text("select model_id, model_file, config_file from model_version where id=:id"),
{"id": version_id}) {"id": version_id})
row = db.to_json(cur) row = db_mysql.to_json(cur)
if not row: if not row:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "模型版本不存在")

@ -6,7 +6,7 @@ import requests
from sqlalchemy import text from sqlalchemy import text
from website import consts from website import consts
from website import db from website import db_mysql
from website import errors from website import errors
from website import settings from website import settings
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
@ -63,7 +63,7 @@ class ListHandler(APIHandler):
param["offset"] = (pageNo - 1) * pageSize param["offset"] = (pageNo - 1) * pageSize
cur = conn.execute(text(sql), param) cur = conn.execute(text(sql), param)
result = db.to_json_list(cur) result = db_mysql.to_json_list(cur)
count = conn.execute(text(sql_count), param_count).fetchone()[0] count = conn.execute(text(sql_count), param_count).fetchone()[0]
@ -206,7 +206,7 @@ class InfoHandler(APIHandler):
result = {} result = {}
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text("""select name, host, port, path, comment from model_hub where id=:id"""), {"id": hid}) cur = conn.execute(text("""select name, host, port, path, comment from model_hub where id=:id"""), {"id": hid})
result = db.to_json(cur) result = db_mysql.to_json(cur)
if not result: if not result:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "model hub not found") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "model hub not found")
self.finish({"data": result}) self.finish({"data": result})

@ -1,11 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging
from sqlalchemy import text from sqlalchemy import text
from website import errors, db from website import db_mysql, errors
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
from website.util import shortuuid from website.util import shortuuid
from website.handlers.enterprise_device import db as DB_Device from website.db import enterprise_device as DB_Device
class DeviceClassificationAddHandler(APIHandler): class DeviceClassificationAddHandler(APIHandler):
""" """
@ -30,6 +30,9 @@ class DeviceClassificationAddHandler(APIHandler):
text("SELECT id FROM device_classification WHERE name=:name"), {"name": name} text("SELECT id FROM device_classification WHERE name=:name"), {"name": name}
) )
ex = cur.fetchone() ex = cur.fetchone()
logging.info("##############################")
logging.info(ex)
logging.info("##############################")
if ex: if ex:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, '设备分类已存在') raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, '设备分类已存在')
@ -37,7 +40,7 @@ class DeviceClassificationAddHandler(APIHandler):
text("INSERT INTO device_classification (name, suid) VALUES (:name, :suid)"), text("INSERT INTO device_classification (name, suid) VALUES (:name, :suid)"),
{"name": name, "suid": shortuuid.ShortUUID().random(10)} {"name": name, "suid": shortuuid.ShortUUID().random(10)}
) )
conn.commit()
self.finish() self.finish()
@ -67,7 +70,7 @@ class DeviceClassificationHandler(APIHandler):
cur = conn.execute( cur = conn.execute(
text("SELECT id, name FROM device_classification where del=0 ORDER BY id DESC") text("SELECT id, name FROM device_classification where del=0 ORDER BY id DESC")
) )
res = db.to_json_list(cur) res = db_mysql.to_json_list(cur)
res = res and res or [] res = res and res or []
self.finish({"data": res}) self.finish({"data": res})
@ -89,6 +92,7 @@ class DeviceClassificationDeleteHandler(APIHandler):
conn.execute( conn.execute(
text("update device_classification set del=1 WHERE id=:id"), {"id": did} text("update device_classification set del=1 WHERE id=:id"), {"id": did}
) )
conn.commit()
self.finish() self.finish()

@ -4,7 +4,7 @@ import logging
from sqlalchemy import text from sqlalchemy import text
from website import consts from website import consts
from website import db from website import db_mysql
from website import errors from website import errors
from website import settings from website import settings
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
@ -43,7 +43,7 @@ class EntityIndexHandler(APIHandler):
param["offset"] = (pageNo - 1) * pageSize param["offset"] = (pageNo - 1) * pageSize
cur = conn.execute(text(sql_text), param) cur = conn.execute(text(sql_text), param)
result = db.to_json_list(cur) result = db_mysql.to_json_list(cur)
count = conn.execute(text(count_sql_text), count_param).fetchone()[0] count = conn.execute(text(count_sql_text), count_param).fetchone()[0]
logging.info(count) logging.info(count)
@ -211,7 +211,7 @@ class EntityInfoHandler(APIHandler):
# row = dict(zip(keys, one)) # row = dict(zip(keys, one))
# logging.info(db.Row(itertools.zip_longest(keys, one))) # logging.info(db.Row(itertools.zip_longest(keys, one)))
row = db.to_json(cur) row = db_mysql.to_json(cur)
cur.close() cur.close()
@ -262,7 +262,7 @@ class EntityPwdcheckHandler(APIHandler):
cur = conn.execute(text("select pwd from enterprise where id=:id"), {"id": eid}) cur = conn.execute(text("select pwd from enterprise where id=:id"), {"id": eid})
# row = cur.fetchone() # row = cur.fetchone()
logging.info(cur) logging.info(cur)
row = db.to_json(cur) row = db_mysql.to_json(cur)
if not row: if not row:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")
pwd = row["pwd"] pwd = row["pwd"]

@ -2,8 +2,8 @@
from website import errors from website import errors
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
from website.handlers.enterprise_entity import db as DB_Entity from website.db import enterprise_entity as DB_Entity
from website.handlers.enterprise_node import db as DB_Node from website.db import enterprise_node as DB_Node
from website.util import shortuuid from website.util import shortuuid

@ -8,7 +8,7 @@ import aiofiles
from sqlalchemy import text from sqlalchemy import text
from website import errors from website import errors
from website import db from website import db_mysql
from website import settings from website import settings
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
@ -69,7 +69,7 @@ class DeleteHandler(APIHandler):
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
sql = text("select filepath from files where md5_str=:md5_str") sql = text("select filepath from files where md5_str=:md5_str")
cur = conn.execute(sql, {"md5_str": md5_str}) cur = conn.execute(sql, {"md5_str": md5_str})
row = db.to_json(cur) row = db_mysql.to_json(cur)
if not row: if not row:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "file not found") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "file not found")

@ -11,7 +11,7 @@ from io import StringIO, BytesIO
from website import errors from website import errors
from website import settings from website import settings
from website import db from website import db_mysql
from website import consts from website import consts
from website.handler import APIHandler from website.handler import APIHandler
from website.util import aes from website.util import aes
@ -112,7 +112,7 @@ class LoginHandler(APIHandler):
# row = dict(zip(keys, one)) # row = dict(zip(keys, one))
# logging.info(db.Row(itertools.zip_longest(keys, one))) # logging.info(db.Row(itertools.zip_longest(keys, one)))
row = db.to_json(cur) row = db_mysql.to_json(cur)
cur.close() cur.close()
# data = [dict(zip(keys, res)) for res in cur.fetchall()] # data = [dict(zip(keys, res)) for res in cur.fetchall()]

Loading…
Cancel
Save