You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

101 lines
3.6 KiB
Python

# -*- coding: utf-8 -*-
11 months ago
import logging
from typing import Union
from sqlalchemy import Column, Integer, String, DateTime, func
11 months ago
from sqlalchemy.ext.declarative import declarative_base
from website.db_mysql import get_session
# Declarative base that the ORM classes in this module (e.g. Model) inherit from.
Base = declarative_base()
# Reference DDL for the `model` table mapped by the Model class.
# This bare string literal is not used at runtime; it only documents the schema.
# English gloss of the COMMENT values: model_type 1001 = classical algorithm,
# 1002 = deep learning; classification = id of the model's category;
# comment = remarks; del = deletion status (1 = deleted, 0 = normal).
"""
CREATE TABLE `model` (
`id` int NOT NULL AUTO_INCREMENT,
`suid` varchar(10) DEFAULT NULL COMMENT 'short uuid',
`name` varchar(255) NOT NULL DEFAULT '',
`model_type` int DEFAULT '1002' COMMENT '模型类型1001/经典算法1002/深度学习',
`classification` int DEFAULT '0' COMMENT '模型分类的id',
`comment` varchar(255) DEFAULT '' COMMENT '备注',
`default_version` varchar(100) DEFAULT '',
`del` tinyint(1) DEFAULT '0' COMMENT '删除状态1/删除0/正常',
`create_time` datetime DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='模型表';
"""
class Model(Base):
    """ORM mapping for the ``model`` table.

    Every attribute maps one column, in DDL order. The attribute ``delete`` is
    bound to the SQL column ``del`` because ``del`` is a Python keyword.
    """

    __tablename__ = "model"

    id = Column(Integer, autoincrement=True, primary_key=True)
    # Short UUID string (at most 10 characters).
    suid = Column(String(10), comment="short uuid")
    name = Column(String(255), default="", nullable=False)
    # Model type: 1001 = classical algorithm, 1002 = deep learning (the default).
    model_type = Column(Integer, comment="模型类型1001/经典算法1002/深度学习", default=1002)
    # Id of the model's category.
    classification = Column(Integer, comment="模型分类的id", default=0)
    # Free-form remark.
    comment = Column(String(255), comment="备注", default="")
    default_version = Column(String(100), default="")
    # Soft-delete flag: 1 = deleted, 0 = normal.
    delete = Column("del", Integer, comment="删除状态1/删除0/正常", default=0)
    # SQL-expression default / onupdate: the now() function is rendered into
    # the INSERT (create_time) or UPDATE (update_time) statement when no
    # explicit value is supplied for the column.
    create_time = Column(DateTime, default=func.now())
    update_time = Column(DateTime, onupdate=func.now())

    def __repr__(self):
        template = "Model(id={}, name='{}', model_type={})"
        return template.format(self.id, self.name, self.model_type)
class ModelRepositry:
    """Read-only queries against the ``model`` table.

    Each method opens its own session through ``get_session()``. (The class
    name's spelling is kept as-is because callers import it by this name.)

    NOTE(review): ``get_model_by_id`` and ``get_model_by_ids`` return ORM
    instances after their ``with get_session()`` block has exited. Whether
    their attributes are still readable depends on how ``get_session``
    commits/closes the session (e.g. ``expire_on_commit``). Confirm this
    against ``website.db_mysql``. ``get_model_dict_by_id`` sidesteps the
    issue by copying values while the session is still open.

    NOTE(review): only ``get_model_count`` excludes soft-deleted rows
    (``del`` = 1). The lookup methods also return deleted models. Confirm
    that this is intended.
    """

    def get_suid(self, model_id: int) -> str:
        """Return the short UUID of the model with primary key ``model_id``.

        Returns ``""`` when no such row exists or its ``suid`` is NULL/empty.
        """
        with get_session() as session:
            model = session.query(Model).filter(Model.id == model_id).first()
            if not model or not model.suid:
                return ""
            return model.suid

    def get_model_by_id(self, model_id: int) -> Union["Model", None]:
        """Return the ``Model`` row with primary key ``model_id``, or ``None``."""
        with get_session() as session:
            # Query.first() already yields None when nothing matches.
            return session.query(Model).filter(Model.id == model_id).first()

    def get_model_dict_by_id(self, model_id: int) -> dict:
        """Return the model row ``model_id`` as a plain dict, or ``{}`` if absent.

        Keys are the mapped attribute names (``delete`` holds the ``del``
        column). Values are read while the session is open, so the result is
        safe to use after the session is closed.
        """
        with get_session() as session:
            # Lazy %-style arguments: the message is only formatted if emitted.
            logging.info("model id is : %s", model_id)
            model = session.query(Model).filter(Model.id == model_id).first()
            if not model:
                return {}
            model_dict = {
                'id': model.id,
                'suid': model.suid,
                'name': model.name,
                'model_type': model.model_type,
                'classification': model.classification,
                'comment': model.comment,
                'default_version': model.default_version,
                'delete': model.delete,
                'create_time': model.create_time,
                'update_time': model.update_time,
            }
            logging.info("model dict is : %s", model_dict)
            return model_dict

    def get_model_by_ids(self, model_ids: list) -> list:
        """Return the ``Model`` rows whose primary key is in ``model_ids``.

        ``model_ids`` may be any iterable of ids (or ``None``). An empty input
        returns ``[]`` without opening a session, which avoids a pointless
        ``IN ()`` round-trip to the database.
        """
        ids = list(model_ids or ())
        if not ids:
            return []
        with get_session() as session:
            # Query.all() returns a list (empty when nothing matches).
            return session.query(Model).filter(Model.id.in_(ids)).all()

    def get_model_count(self) -> int:
        """Return the number of models that are not soft-deleted (``del`` = 0)."""
        with get_session() as session:
            return session.query(Model).filter(Model.delete == 0).count()