添加异步数据库处理

main
周平 1 year ago
parent 85b1e044d0
commit c1f4dc7778

@ -7,6 +7,8 @@ from typing import List, Any, Optional
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from contextlib import asynccontextmanager
from website import settings
@ -41,11 +43,11 @@ def to_json(cursor):
app_engine = create_engine(
'mysql+pymysql://{}:{}@{}/{}?charset=utf8mb4'.format(
settings.mysql_app['user'],
settings.mysql_app['password'],
settings.mysql_app['host'],
settings.mysql_app['database']
"mysql+pymysql://{}:{}@{}/{}?charset=utf8mb4".format(
settings.mysql_app["user"],
settings.mysql_app["password"],
settings.mysql_app["host"],
settings.mysql_app["database"],
), # SQLAlchemy 数据库连接串,格式见下面
echo=bool(settings.SQLALCHEMY_ECHO), # 是不是要把所执行的SQL打印出来一般用于调试
pool_pre_ping=True,
@ -69,3 +71,33 @@ def get_session():
raise e
finally:
s.close()
async_app_engine = create_async_engine(
"mysql+aiomysql://{}:{}@{}/{}?charset=utf8mb4".format(
settings.mysql_app["user"],
settings.mysql_app["password"],
settings.mysql_app["host"],
settings.mysql_app["database"],
),
echo=bool(settings.SQLALCHEMY_ECHO),
pool_pre_ping=True,
pool_size=int(settings.SQLALCHEMY_POOL_SIZE),
max_overflow=int(settings.SQLALCHEMY_POOL_MAX_SIZE),
pool_recycle=int(settings.SQLALCHEMY_POOL_RECYCLE),
)
AsyncSession = sessionmaker(bind=async_app_engine, class_=AsyncSession)
@asynccontextmanager
async def get_async_session():
async with AsyncSession() as s:
try:
yield s
await s.commit()
except Exception as e:
await s.rollback()
raise e
finally:
await s.close()

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import logging
import asyncio
from sqlalchemy import text
@ -10,22 +11,24 @@ from website import settings
from website.handler import APIHandler, authenticated
from website.service import enterprise
from website.util import shortuuid, aes
from concurrent.futures import ThreadPoolExecutor
from functools import partial
class EntityIndexHandler(APIHandler):
"""首页"""
@authenticated
def post(self):
async def post(self):
pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_int_argument("pageSize", 10)
name = self.tostr(self.get_escaped_argument("name", ""))
with self.app_mysql.connect() as conn:
sql_text = "select id, name, industry, logo, create_time from enterprise where 1=1 "
sql_text = "select id, suid, name, industry, logo, create_time from enterprise where 1=1 "
param = {}
count_sql_text = "select count(id) from enterprise "
count_sql_text = "select count(id) c from enterprise where 1=1 "
count_param = {}
if name:
@ -45,26 +48,64 @@ class EntityIndexHandler(APIHandler):
cur = conn.execute(text(sql_text), param)
result = db_mysql.to_json_list(cur)
count = conn.execute(text(count_sql_text), count_param).fetchone()[0]
logging.info(count)
logging.info(result)
cur_count = conn.execute(text(count_sql_text), count_param)
count = db_mysql.to_json(cur_count)
count = count["c"] if count else 0
# data_index = [item["id"] for item in result]
data = []
for item in result:
modelCount = enterprise.get_enterprise_model_count(item["id"])
deviceCount = enterprise.get_enterprise_device_count(item["id"])
data.append(
{
"id": item["id"],
"name": item["name"],
"industry": item["industry"],
"modelCount": modelCount,
"deviceCount": deviceCount,
"logo": item["logo"],
"createTime": str(item["create_time"])
}
# for item in result:
# modelCount = enterprise.get_enterprise_model_count(item["id"])
# deviceCount = enterprise.get_enterprise_device_count(item["id"])
# data.append(
# {
# "id": item["id"],
# "name": item["name"],
# "industry": consts.industry_map[item["industry"]],
# "modelCount": modelCount,
# "deviceCount": deviceCount,
# "logo": item["logo"],
# "createTime": str(item["create_time"]),
# }
# )
# with ThreadPoolExecutor() as executor:
# get_count = partial(enterprise.get_enterprise_model_and_device_count)
# futures = [
# executor.submit(get_count, entity_id=item["id"], entity_suid="")
# for item in result
# ]
# results = [future.result() for future in futures]
# model_counts = [result[0] for result in results]
# device_counts = [result[1] for result in results]
count_results = await asyncio.gather(
*[
enterprise.get_enterprise_model_and_device_count(entity_id=item["id"])
for item in result
]
)
model_counts = [result[0] for result in count_results]
device_counts = [result[1] for result in count_results]
data = [
{
"id": item["id"],
"name": item["name"],
"industry": consts.industry_map[item["industry"]],
"modelCount": model_count,
"deviceCount": device_count,
"logo": item["logo"],
"createTime": str(item["create_time"]),
}
for item, model_count, device_count in zip(
result, model_counts, device_counts
)
]
self.finish({"count": count, "data": data})
@ -96,14 +137,25 @@ class EntityAddHandler(APIHandler):
summary = self.get_escaped_argument("summary", "")
logo = self.get_escaped_argument("logo", "")
if not name or not province or not city or not addr or not industry or not contact or not phone or not summary:
if (
not name
or not province
or not city
or not addr
or not industry
or not contact
or not phone
or not summary
):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
if industry not in consts.industry_map:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型")
if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制")
raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
)
short_uid = shortuuid.ShortUUID().random(length=8)
pwd = aes.encrypt(settings.enterprise_aes_key, short_uid)
@ -127,7 +179,7 @@ class EntityAddHandler(APIHandler):
"logo": logo,
"account": "admin",
"pwd": pwd,
}
},
)
conn.commit()
@ -157,21 +209,31 @@ class EntityEditHandler(APIHandler):
logo = self.get_escaped_argument("logo", "")
account = self.get_escaped_argument("account", "")
if not name or not province or not city or not addr or not industry or not contact or not phone or not summary:
if (
not name
or not province
or not city
or not addr
or not industry
or not contact
or not phone
or not summary
):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
if industry not in consts.industry_map:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型")
if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制")
raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
)
with self.app_mysql.connect() as conn:
conn.execute(
text(
# "insert into enterprise(name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) "
# "values(:name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)"
"update enterprise set name=:name, province=:province, city=:city, addr=:addr, industry=:industry, contact"
"=:contact, phone=:phone, summary=:summary, logo=:logo, account=:account where id=:id",
),
@ -186,7 +248,8 @@ class EntityEditHandler(APIHandler):
"summary": summary,
"logo": logo,
"account": account,
}
"id": eid,
},
)
conn.commit()
@ -222,6 +285,7 @@ class EntityInfoHandler(APIHandler):
"name": row["name"],
"province": row["province"],
"city": row["city"],
"addr": row["addr"],
"industry": row["industry"],
"contact": row["contact"],
"phone": row["phone"],
@ -242,9 +306,7 @@ class EntityDeleteHandler(APIHandler):
eid = self.get_int_argument("id")
with self.app_mysql.connect() as conn:
conn.execute(
text("update enterprise set del=1 where id=:id"), {"id": eid}
)
conn.execute(text("update enterprise set del=1 where id=:id"), {"id": eid})
conn.commit()
@ -259,7 +321,9 @@ class EntityPwdcheckHandler(APIHandler):
eid = self.get_int_argument("id")
with self.app_mysql.connect() as conn:
cur = conn.execute(text("select pwd from enterprise where id=:id"), {"id": eid})
cur = conn.execute(
text("select pwd from enterprise where id=:id"), {"id": eid}
)
# row = cur.fetchone()
logging.info(cur)
row = db_mysql.to_json(cur)
@ -271,3 +335,9 @@ class EntityPwdcheckHandler(APIHandler):
pwd_dcrypt = aes.decrypt(settings.enterprise_aes_key, pwd)
self.finish({"pwd": pwd_dcrypt})
class IndustryMapHandler(APIHandler):
@authenticated
def post(self):
self.finish(consts.industry_map)

@ -9,9 +9,7 @@ handlers = [
("/enterprise/entity/info", handler.EntityInfoHandler),
("/enterprise/entity/delete", handler.EntityDeleteHandler),
("/enterprise/entity/pwdcheck", handler.EntityPwdcheckHandler),
("/enterprise/industrymap", handler.IndustryMapHandler),
]
page_handlers = [
]
page_handlers = []

@ -1,6 +1,9 @@
from website.handler import BaseHandler
from sqlalchemy import text
from typing import Any
from website.db_mysql import get_session, get_async_session, to_json
import json
# 获取企业模型数量
def get_enterprise_model_count(id: int) -> int:
@ -20,13 +23,107 @@ def get_enterprise_entity_count(engine: Any) -> int:
count = conn.execute(text(count_sql_text)).fetchone()
if count:
return count[0]
return 0
# 获取所有企业模型数量
def get_enterprise_model_count() -> int:
return 0
def get_enterprise_model_count(entity_id: int = 0, entity_suid: str = "") -> int:
with get_session() as session:
sql = "select base_models from enterprise_busi_model where "
param = {}
if entity_id:
sql += "entity_id = :entity_id"
param["entity_id"] = entity_id
elif entity_suid:
sql += "entity_suid = :entity_suid"
param["entity_suid"] = entity_suid
cur = session.execute(text(sql), param)
res = to_json(cur)
if res:
base_model_list = json.loads(res)
return len(base_model_list)
return 0
# 获取所有企业设备数量
def get_enterprise_device_count() -> int:
return 0
def get_enterprise_device_count(entity_id: int = 0, entity_suid: str = "") -> int:
with get_session() as session:
sql = "select count(id) as device_count from enterprise_device "
param = {}
if entity_id:
sql += "where entity_id = :entity_id"
param = {"entity_id": entity_id}
elif entity_suid:
sql += "where entity_suid = :entity_suid"
param = {"entity_suid": entity_suid}
cur = session.execute(text(sql), param)
res = to_json(cur)
if res:
return res["device_count"]
return 0
async def get_enterprise_model_and_device_count(
entity_id: int = 0, entity_suid: str = ""
) -> (int, int):
model_count = 0
device_count = 0
async with get_async_session() as session:
# sql_model = "select base_models from enterprise_busi_model where "
# param_model = {}
# if entity_id:
# sql_model += "entity_id = :entity_id"
# param_model["entity_id"] = entity_id
# elif entity_suid:
# sql_model += "entity_suid = :entity_suid"
# param_model["entity_suid"] = entity_suid
# cur_model = await session.execute(text(sql_model), param_model)
# res_model = to_json(cur_model)
# if res_model:
# base_model_list = json.loads(res_model)
# model_count = len(base_model_list)
# sql_device = "select count(id) as device_count from enterprise_device where "
# param_device = {}
# if entity_id:
# sql_device += "entity_id = :entity_id"
# param_device["entity_id"] = entity_id
# elif entity_suid:
# sql_device += "where entity_suid = :entity_suid"
# param_device = {"entity_suid": entity_suid}
# cur_device = await session.execute(text(sql_device), param_device)
# res_device = to_json(cur_device)
# if res_device:
# device_count = res_device["device_count"]
sql = """
SELECT
(SELECT base_models FROM enterprise_busi_model WHERE {where_clause}) as base_models,
(SELECT COUNT(*) FROM enterprise_device WHERE {where_clause}) AS device_count
"""
where_clause = ""
params = {}
if entity_id:
where_clause = "entity_id = :entity_id"
params["entity_id"] = entity_id
elif entity_suid:
where_clause = "entity_suid = :entity_suid"
params["entity_suid"] = entity_suid
sql = sql.format(where_clause=where_clause)
result = await session.execute(text(sql), params)
# row = result.fetchone()
row = to_json(result)
base_models, device_count = row["base_models"], row["device_count"]
if base_models:
model_count = len(json.loads(base_models))
return model_count, device_count

Loading…
Cancel
Save