You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

101 lines
3.6 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import logging
from typing import Union
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy.ext.declarative import declarative_base
from website.db_mysql import get_session
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
# Reference MySQL DDL for the `model` table mapped by the Model class.
# It is a bare string literal: nothing in this module reads it, so it serves
# purely as documentation of the expected schema.
"""
CREATE TABLE `model` (
`id` int NOT NULL AUTO_INCREMENT,
`suid` varchar(10) DEFAULT NULL COMMENT 'short uuid',
`name` varchar(255) NOT NULL DEFAULT '',
`model_type` int DEFAULT '1002' COMMENT '模型类型1001/经典算法1002/深度学习',
`classification` int DEFAULT '0' COMMENT '模型分类的id',
`comment` varchar(255) DEFAULT '' COMMENT '备注',
`default_version` varchar(100) DEFAULT '',
`del` tinyint(1) DEFAULT '0' COMMENT '删除状态1/删除0/正常',
`create_time` datetime DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='模型表';
"""
class Model(Base):
    """ORM mapping for the ``model`` table.

    The ``comment=`` strings are schema metadata and are kept verbatim
    (in Chinese); their English meaning is given in the comments below.
    """

    __tablename__ = "model"

    # Auto-incrementing integer primary key.
    id = Column(Integer, autoincrement=True, primary_key=True)
    # Short uuid (up to 10 characters); nullable.
    suid = Column(String(length=10), comment="short uuid")
    name = Column(String(length=255), default="", nullable=False)
    # Model type: 1001 = classical algorithm, 1002 = deep learning.
    model_type = Column(Integer, comment="模型类型1001/经典算法1002/深度学习", default=1002)
    # Id of the model's classification (category).
    classification = Column(Integer, comment="模型分类的id", default=0)
    # Free-form remark.
    comment = Column(String(length=255), comment="备注", default="")
    default_version = Column(String(length=100), default="")
    # Soft-delete flag stored in the SQL column `del`: 1 = deleted, 0 = normal.
    # `del` is a Python keyword, hence the attribute name `delete`.
    delete = Column("del", Integer, comment="删除状态1/删除0/正常", default=0)
    # Filled with the SQL NOW() expression on INSERT when no value is given.
    create_time = Column(DateTime, default=func.now())
    # Refreshed with the SQL NOW() expression on UPDATE when not set explicitly.
    update_time = Column(DateTime, onupdate=func.now())

    def __repr__(self):
        return "Model(id={}, name='{}', model_type={})".format(
            self.id, self.name, self.model_type
        )
class ModelRepositry(object):
    """Read-only data-access helpers for the ``model`` table.

    Each method opens its own session through ``get_session()``.

    NOTE(review): ``get_model_by_id`` and ``get_model_by_ids`` return ORM
    instances after the ``with get_session()`` block has exited. Whether their
    attributes remain loadable then depends on what ``get_session`` does on
    exit (commit / close / expire settings) -- confirm against
    ``website.db_mysql``. ``get_model_dict_by_id`` avoids the question by
    copying values out while the session is open.
    """

    def get_suid(self, model_id: int) -> str:
        """Return the short uuid of the model with id ``model_id``.

        Returns ``""`` when no such row exists or its ``suid`` is empty/NULL.
        Soft-deleted rows are not filtered out.
        """
        with get_session() as session:
            model = session.query(Model).filter(Model.id == model_id).first()
            if model is None or not model.suid:
                return ""
            return model.suid

    def get_model_by_id(self, model_id: int) -> Union[Model, None]:
        """Return the ``Model`` with id ``model_id``, or ``None`` if absent.

        Soft-deleted rows are returned as well (no filter on ``delete``).
        """
        with get_session() as session:
            # .first() already yields None when nothing matches.
            return session.query(Model).filter(Model.id == model_id).first()

    def get_model_dict_by_id(self, model_id: int) -> dict:
        """Return the model with id ``model_id`` as a plain dict.

        Keys mirror the ``Model`` attribute names; note that ``'delete'``
        holds the ``del`` column. Returns ``{}`` when no such row exists.
        Values are copied while the session is open, so the dict does not
        depend on the session afterwards.
        """
        with get_session() as session:
            # Lazy %-formatting: the message is only built if INFO is enabled.
            logging.info("model id is : %s", model_id)
            model = session.query(Model).filter(Model.id == model_id).first()
            if model is None:
                return {}
            model_dict = {
                'id': model.id,
                'suid': model.suid,
                'name': model.name,
                'model_type': model.model_type,
                'classification': model.classification,
                'comment': model.comment,
                'default_version': model.default_version,
                'delete': model.delete,
                'create_time': model.create_time,
                'update_time': model.update_time,
            }
            logging.info("model dict is : %s", model_dict)
            return model_dict

    def get_model_by_ids(self, model_ids: list) -> list:
        """Return the ``Model`` rows whose id appears in ``model_ids``.

        Returns ``[]`` without touching the database when ``model_ids`` is
        empty or ``None``. Soft-deleted rows are not filtered out, and the
        result order is whatever the database returns.
        """
        if not model_ids:
            # An empty IN list can never match, so skip the session and query.
            return []
        with get_session() as session:
            # .all() already returns a (possibly empty) list.
            return session.query(Model).filter(Model.id.in_(model_ids)).all()

    def get_model_count(self) -> int:
        """Return the number of models that are not soft-deleted (``del`` == 0)."""
        with get_session() as session:
            return session.query(Model).filter(Model.delete == 0).count()