添加异步数据库处理

main
周平 1 year ago
parent 85b1e044d0
commit c1f4dc7778

@ -7,6 +7,8 @@ from typing import List, Any, Optional
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from contextlib import asynccontextmanager
from website import settings from website import settings
@ -41,11 +43,11 @@ def to_json(cursor):
app_engine = create_engine( app_engine = create_engine(
'mysql+pymysql://{}:{}@{}/{}?charset=utf8mb4'.format( "mysql+pymysql://{}:{}@{}/{}?charset=utf8mb4".format(
settings.mysql_app['user'], settings.mysql_app["user"],
settings.mysql_app['password'], settings.mysql_app["password"],
settings.mysql_app['host'], settings.mysql_app["host"],
settings.mysql_app['database'] settings.mysql_app["database"],
), # SQLAlchemy 数据库连接串,格式见下面 ), # SQLAlchemy 数据库连接串,格式见下面
echo=bool(settings.SQLALCHEMY_ECHO), # 是不是要把所执行的SQL打印出来一般用于调试 echo=bool(settings.SQLALCHEMY_ECHO), # 是不是要把所执行的SQL打印出来一般用于调试
pool_pre_ping=True, pool_pre_ping=True,
@ -69,3 +71,33 @@ def get_session():
raise e raise e
finally: finally:
s.close() s.close()
async_app_engine = create_async_engine(
"mysql+aiomysql://{}:{}@{}/{}?charset=utf8mb4".format(
settings.mysql_app["user"],
settings.mysql_app["password"],
settings.mysql_app["host"],
settings.mysql_app["database"],
),
echo=bool(settings.SQLALCHEMY_ECHO),
pool_pre_ping=True,
pool_size=int(settings.SQLALCHEMY_POOL_SIZE),
max_overflow=int(settings.SQLALCHEMY_POOL_MAX_SIZE),
pool_recycle=int(settings.SQLALCHEMY_POOL_RECYCLE),
)
AsyncSession = sessionmaker(bind=async_app_engine, class_=AsyncSession)
@asynccontextmanager
async def get_async_session():
async with AsyncSession() as s:
try:
yield s
await s.commit()
except Exception as e:
await s.rollback()
raise e
finally:
await s.close()

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging import logging
import asyncio
from sqlalchemy import text from sqlalchemy import text
@ -10,22 +11,24 @@ from website import settings
from website.handler import APIHandler, authenticated from website.handler import APIHandler, authenticated
from website.service import enterprise from website.service import enterprise
from website.util import shortuuid, aes from website.util import shortuuid, aes
from concurrent.futures import ThreadPoolExecutor
from functools import partial
class EntityIndexHandler(APIHandler): class EntityIndexHandler(APIHandler):
"""首页""" """首页"""
@authenticated @authenticated
def post(self): async def post(self):
pageNo = self.get_int_argument("pageNo", 1) pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_int_argument("pageSize", 10) pageSize = self.get_int_argument("pageSize", 10)
name = self.tostr(self.get_escaped_argument("name", "")) name = self.tostr(self.get_escaped_argument("name", ""))
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
sql_text = "select id, name, industry, logo, create_time from enterprise where 1=1 " sql_text = "select id, suid, name, industry, logo, create_time from enterprise where 1=1 "
param = {} param = {}
count_sql_text = "select count(id) from enterprise " count_sql_text = "select count(id) c from enterprise where 1=1 "
count_param = {} count_param = {}
if name: if name:
@ -45,26 +48,64 @@ class EntityIndexHandler(APIHandler):
cur = conn.execute(text(sql_text), param) cur = conn.execute(text(sql_text), param)
result = db_mysql.to_json_list(cur) result = db_mysql.to_json_list(cur)
count = conn.execute(text(count_sql_text), count_param).fetchone()[0] cur_count = conn.execute(text(count_sql_text), count_param)
logging.info(count) count = db_mysql.to_json(cur_count)
logging.info(result) count = count["c"] if count else 0
# data_index = [item["id"] for item in result]
data = [] data = []
for item in result: # for item in result:
modelCount = enterprise.get_enterprise_model_count(item["id"]) # modelCount = enterprise.get_enterprise_model_count(item["id"])
deviceCount = enterprise.get_enterprise_device_count(item["id"]) # deviceCount = enterprise.get_enterprise_device_count(item["id"])
data.append( # data.append(
{ # {
"id": item["id"], # "id": item["id"],
"name": item["name"], # "name": item["name"],
"industry": item["industry"], # "industry": consts.industry_map[item["industry"]],
"modelCount": modelCount, # "modelCount": modelCount,
"deviceCount": deviceCount, # "deviceCount": deviceCount,
"logo": item["logo"], # "logo": item["logo"],
"createTime": str(item["create_time"]) # "createTime": str(item["create_time"]),
} # }
# )
# with ThreadPoolExecutor() as executor:
# get_count = partial(enterprise.get_enterprise_model_and_device_count)
# futures = [
# executor.submit(get_count, entity_id=item["id"], entity_suid="")
# for item in result
# ]
# results = [future.result() for future in futures]
# model_counts = [result[0] for result in results]
# device_counts = [result[1] for result in results]
count_results = await asyncio.gather(
*[
enterprise.get_enterprise_model_and_device_count(entity_id=item["id"])
for item in result
]
)
model_counts = [result[0] for result in count_results]
device_counts = [result[1] for result in count_results]
data = [
{
"id": item["id"],
"name": item["name"],
"industry": consts.industry_map[item["industry"]],
"modelCount": model_count,
"deviceCount": device_count,
"logo": item["logo"],
"createTime": str(item["create_time"]),
}
for item, model_count, device_count in zip(
result, model_counts, device_counts
) )
]
self.finish({"count": count, "data": data}) self.finish({"count": count, "data": data})
@ -96,14 +137,25 @@ class EntityAddHandler(APIHandler):
summary = self.get_escaped_argument("summary", "") summary = self.get_escaped_argument("summary", "")
logo = self.get_escaped_argument("logo", "") logo = self.get_escaped_argument("logo", "")
if not name or not province or not city or not addr or not industry or not contact or not phone or not summary: if (
not name
or not province
or not city
or not addr
or not industry
or not contact
or not phone
or not summary
):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
if industry not in consts.industry_map: if industry not in consts.industry_map:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型")
if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2: if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制") raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
)
short_uid = shortuuid.ShortUUID().random(length=8) short_uid = shortuuid.ShortUUID().random(length=8)
pwd = aes.encrypt(settings.enterprise_aes_key, short_uid) pwd = aes.encrypt(settings.enterprise_aes_key, short_uid)
@ -127,7 +179,7 @@ class EntityAddHandler(APIHandler):
"logo": logo, "logo": logo,
"account": "admin", "account": "admin",
"pwd": pwd, "pwd": pwd,
} },
) )
conn.commit() conn.commit()
@ -157,21 +209,31 @@ class EntityEditHandler(APIHandler):
logo = self.get_escaped_argument("logo", "") logo = self.get_escaped_argument("logo", "")
account = self.get_escaped_argument("account", "") account = self.get_escaped_argument("account", "")
if not name or not province or not city or not addr or not industry or not contact or not phone or not summary: if (
not name
or not province
or not city
or not addr
or not industry
or not contact
or not phone
or not summary
):
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
if industry not in consts.industry_map: if industry not in consts.industry_map:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型") raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型")
if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2: if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制") raise errors.HTTPAPIError(
errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制"
)
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
conn.execute( conn.execute(
text( text(
# "insert into enterprise(name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) " # "insert into enterprise(name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) "
# "values(:name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)" # "values(:name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)"
"update enterprise set name=:name, province=:province, city=:city, addr=:addr, industry=:industry, contact" "update enterprise set name=:name, province=:province, city=:city, addr=:addr, industry=:industry, contact"
"=:contact, phone=:phone, summary=:summary, logo=:logo, account=:account where id=:id", "=:contact, phone=:phone, summary=:summary, logo=:logo, account=:account where id=:id",
), ),
@ -186,7 +248,8 @@ class EntityEditHandler(APIHandler):
"summary": summary, "summary": summary,
"logo": logo, "logo": logo,
"account": account, "account": account,
} "id": eid,
},
) )
conn.commit() conn.commit()
@ -222,6 +285,7 @@ class EntityInfoHandler(APIHandler):
"name": row["name"], "name": row["name"],
"province": row["province"], "province": row["province"],
"city": row["city"], "city": row["city"],
"addr": row["addr"],
"industry": row["industry"], "industry": row["industry"],
"contact": row["contact"], "contact": row["contact"],
"phone": row["phone"], "phone": row["phone"],
@ -242,9 +306,7 @@ class EntityDeleteHandler(APIHandler):
eid = self.get_int_argument("id") eid = self.get_int_argument("id")
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
conn.execute( conn.execute(text("update enterprise set del=1 where id=:id"), {"id": eid})
text("update enterprise set del=1 where id=:id"), {"id": eid}
)
conn.commit() conn.commit()
@ -259,7 +321,9 @@ class EntityPwdcheckHandler(APIHandler):
eid = self.get_int_argument("id") eid = self.get_int_argument("id")
with self.app_mysql.connect() as conn: with self.app_mysql.connect() as conn:
cur = conn.execute(text("select pwd from enterprise where id=:id"), {"id": eid}) cur = conn.execute(
text("select pwd from enterprise where id=:id"), {"id": eid}
)
# row = cur.fetchone() # row = cur.fetchone()
logging.info(cur) logging.info(cur)
row = db_mysql.to_json(cur) row = db_mysql.to_json(cur)
@ -271,3 +335,9 @@ class EntityPwdcheckHandler(APIHandler):
pwd_dcrypt = aes.decrypt(settings.enterprise_aes_key, pwd) pwd_dcrypt = aes.decrypt(settings.enterprise_aes_key, pwd)
self.finish({"pwd": pwd_dcrypt}) self.finish({"pwd": pwd_dcrypt})
class IndustryMapHandler(APIHandler):
@authenticated
def post(self):
self.finish(consts.industry_map)

@ -9,9 +9,7 @@ handlers = [
("/enterprise/entity/info", handler.EntityInfoHandler), ("/enterprise/entity/info", handler.EntityInfoHandler),
("/enterprise/entity/delete", handler.EntityDeleteHandler), ("/enterprise/entity/delete", handler.EntityDeleteHandler),
("/enterprise/entity/pwdcheck", handler.EntityPwdcheckHandler), ("/enterprise/entity/pwdcheck", handler.EntityPwdcheckHandler),
("/enterprise/industrymap", handler.IndustryMapHandler),
] ]
page_handlers = [ page_handlers = []
]

@ -1,6 +1,9 @@
from website.handler import BaseHandler from website.handler import BaseHandler
from sqlalchemy import text from sqlalchemy import text
from typing import Any from typing import Any
from website.db_mysql import get_session, get_async_session, to_json
import json
# 获取企业模型数量 # 获取企业模型数量
def get_enterprise_model_count(id: int) -> int: def get_enterprise_model_count(id: int) -> int:
@ -23,10 +26,104 @@ def get_enterprise_entity_count(engine: Any) -> int:
return 0 return 0
# 获取所有企业模型数量 # 获取所有企业模型数量
def get_enterprise_model_count() -> int: def get_enterprise_model_count(entity_id: int = 0, entity_suid: str = "") -> int:
with get_session() as session:
sql = "select base_models from enterprise_busi_model where "
param = {}
if entity_id:
sql += "entity_id = :entity_id"
param["entity_id"] = entity_id
elif entity_suid:
sql += "entity_suid = :entity_suid"
param["entity_suid"] = entity_suid
cur = session.execute(text(sql), param)
res = to_json(cur)
if res:
base_model_list = json.loads(res)
return len(base_model_list)
return 0 return 0
# 获取所有企业设备数量 # 获取所有企业设备数量
def get_enterprise_device_count() -> int: def get_enterprise_device_count(entity_id: int = 0, entity_suid: str = "") -> int:
with get_session() as session:
sql = "select count(id) as device_count from enterprise_device "
param = {}
if entity_id:
sql += "where entity_id = :entity_id"
param = {"entity_id": entity_id}
elif entity_suid:
sql += "where entity_suid = :entity_suid"
param = {"entity_suid": entity_suid}
cur = session.execute(text(sql), param)
res = to_json(cur)
if res:
return res["device_count"]
return 0 return 0
async def get_enterprise_model_and_device_count(
entity_id: int = 0, entity_suid: str = ""
) -> (int, int):
model_count = 0
device_count = 0
async with get_async_session() as session:
# sql_model = "select base_models from enterprise_busi_model where "
# param_model = {}
# if entity_id:
# sql_model += "entity_id = :entity_id"
# param_model["entity_id"] = entity_id
# elif entity_suid:
# sql_model += "entity_suid = :entity_suid"
# param_model["entity_suid"] = entity_suid
# cur_model = await session.execute(text(sql_model), param_model)
# res_model = to_json(cur_model)
# if res_model:
# base_model_list = json.loads(res_model)
# model_count = len(base_model_list)
# sql_device = "select count(id) as device_count from enterprise_device where "
# param_device = {}
# if entity_id:
# sql_device += "entity_id = :entity_id"
# param_device["entity_id"] = entity_id
# elif entity_suid:
# sql_device += "where entity_suid = :entity_suid"
# param_device = {"entity_suid": entity_suid}
# cur_device = await session.execute(text(sql_device), param_device)
# res_device = to_json(cur_device)
# if res_device:
# device_count = res_device["device_count"]
sql = """
SELECT
(SELECT base_models FROM enterprise_busi_model WHERE {where_clause}) as base_models,
(SELECT COUNT(*) FROM enterprise_device WHERE {where_clause}) AS device_count
"""
where_clause = ""
params = {}
if entity_id:
where_clause = "entity_id = :entity_id"
params["entity_id"] = entity_id
elif entity_suid:
where_clause = "entity_suid = :entity_suid"
params["entity_suid"] = entity_suid
sql = sql.format(where_clause=where_clause)
result = await session.execute(text(sql), params)
# row = result.fetchone()
row = to_json(result)
base_models, device_count = row["base_models"], row["device_count"]
if base_models:
model_count = len(json.loads(base_models))
return model_count, device_count

Loading…
Cancel
Save