|
|
import contextlib
import itertools
import logging
from typing import Any, List, Optional
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from website import settings
|
|
|
|
|
|
|
|
|
class Row(dict):
    """A dict that also allows attribute-style read access to its keys.

    ``row.name`` is equivalent to ``row["name"]``. A missing key raises
    ``AttributeError`` (not ``KeyError``), so ``hasattr(row, name)`` and
    ``getattr(row, name, default)`` behave as they do for ordinary objects.
    Attribute *assignment* is not mapped onto keys.
    """

    def __getattr__(self, name):
        # Python calls __getattr__ only after normal attribute lookup fails,
        # so genuine dict methods (keys, items, get, ...) take precedence
        # over keys of the same name.
        try:
            return self[name]
        except KeyError:
            # ``from None`` suppresses the internal KeyError so tracebacks
            # show a single, clean AttributeError.
            raise AttributeError(name) from None
|
|
|
|
|
|
|
|
|
def to_json_list(cursor: Any) -> List[Row]:
    """Materialise every remaining row of *cursor* as a list of ``Row``.

    :param cursor: a result object exposing ``keys()`` (column names) and
        ``fetchall()`` (remaining rows) — presumably a SQLAlchemy result
        proxy; confirm against callers.
    :returns: one ``Row`` per fetched row, mapping column name to value.
        An empty list when no rows remain; never ``None``.
    """
    column_names = list(cursor.keys())
    # ``or []`` keeps the original's behaviour of yielding [] for any falsy
    # fetchall() result.
    rows = cursor.fetchall() or []
    # zip_longest (not zip): a row/header length mismatch shows up as None
    # keys or values instead of being silently truncated.
    return [Row(itertools.zip_longest(column_names, row)) for row in rows]
|
|
|
|
|
|
|
|
|
def to_json(cursor: Any) -> Row:
    """Fetch the next row of *cursor* as a ``Row``.

    :param cursor: a result object exposing ``keys()`` (column names) and
        ``fetchone()`` (next row, or ``None`` when exhausted) — presumably a
        SQLAlchemy result proxy; confirm against callers.
    :returns: a ``Row`` mapping column name to value, or an empty ``Row`` when
        no row is available. An empty ``Row`` is falsy and compares equal to
        ``{}``, so existing ``if not result`` / ``== {}`` checks keep working.
    """
    column_names = list(cursor.keys())
    row = cursor.fetchone()
    if row is None:
        # Return an (empty) Row rather than a bare dict so callers always get
        # the same type, matching to_json_list.
        return Row()
    return Row(itertools.zip_longest(column_names, row))
|
|
|
|
|
|
|
|
|
# Application-wide engine for the MySQL database described by
# ``settings.mysql_app``; all parameters are read from ``settings`` once, at
# import time.
app_engine = create_engine(
    # SQLAlchemy URL: mysql+pymysql://<user>:<password>@<host>/<database>.
    # The password is percent-encoded so characters that are special inside a
    # URL (``@``, ``:``, ``/``, ``%``, ``#``, ...) cannot corrupt the parse;
    # SQLAlchemy decodes it again when it parses the URL. ``str()`` keeps the
    # original behaviour for non-string settings values.
    'mysql+pymysql://{}:{}@{}/{}?charset=utf8mb4'.format(
        settings.mysql_app['user'],
        quote(str(settings.mysql_app['password']), safe=''),
        settings.mysql_app['host'],
        settings.mysql_app['database'],
    ),
    # Whether to log every executed SQL statement; usually only for debugging.
    # NOTE(review): bool() of any non-empty string is True, so a string setting
    # such as "False" or "0" would still enable echo — confirm the type.
    echo=bool(settings.SQLALCHEMY_ECHO),
    # Ping each pooled connection before handing it out, replacing ones the
    # server has dropped.
    pool_pre_ping=True,
    # Number of connections kept open in the pool.
    pool_size=int(settings.SQLALCHEMY_POOL_SIZE),
    # Extra connections allowed beyond pool_size under load; the hard cap is
    # pool_size + max_overflow. NOTE(review): the setting's name suggests it
    # was meant as the total maximum — confirm the intended value.
    max_overflow=int(settings.SQLALCHEMY_POOL_MAX_SIZE),
    # Seconds after which a pooled connection is recycled (closed and reopened).
    pool_recycle=int(settings.SQLALCHEMY_POOL_RECYCLE),
)

# Session factory bound to app_engine.
Session = sessionmaker(bind=app_engine)

# Base = declarative_base(app_engine)
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def get_session():
    """Provide a transactional scope around a unit of ORM work.

    Yields a new ``Session``. When the ``with`` body finishes normally the
    session is committed. If the body or the commit raises an ``Exception``,
    the session is rolled back and the exception propagates unchanged. The
    session is closed in every case.

    Usage::

        with get_session() as s:
            s.add(obj)
    """
    session = Session()
    try:
        yield session
        # Commit inside the try so a failed commit is also rolled back.
        session.commit()
    except Exception:
        session.rollback()
        # Bare ``raise`` re-raises the active exception with its original
        # traceback, without appending an extra entry as ``raise e`` does.
        raise
    finally:
        session.close()
|