添加工程文件

main
周平 1 year ago
parent 9c036d7cb7
commit 2406ef5c95

@ -0,0 +1,226 @@
# -*- coding: utf-8 -*-
import logging
import os.path
import sys
import time
import redis
import tornado.escape
import tornado.ioloop
import tornado.options
import tornado.web
# import tornado.websocket
import torndb
import importlib
# from confluent_kafka import Producer
# from rediscluster import StrictRedisCluster
# from redis import sentinel
from rediscluster import RedisCluster
# from redis.sentinel import Sentinel
# from tornado.options import define, options
from tornado.options import options, define as _define, parse_command_line
# from elasticsearch import Elasticsearch
# from tornado_swagger import swagger
def define(name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None, callback=None):
if name not in options._options:
return _define(name, default, type, help, metavar,
multiple, group, callback)
tornado.options.define = define
sys.dont_write_bytecode = True
define("port", default=8888, help="run on the given port", type=int)
define("debug", default=0)
_ROOT = os.path.dirname(os.path.abspath(__file__))
importlib.reload(sys)
# sys.setdefaultencoding('utf-8')
try:
import website
except ImportError:
print("app package import error")
logging.info("app import error")
# sys.path.append(os.path.join(_ROOT, "../.."))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))
from website import settings
from website.db import app_engine
from website.handler import APIErrorHandler
from website.urls import handlers, page_handlers
# from website.urls import handlers_v2
# print(os.path.dirname(os.path.abspath(__file__)))
class Connection(torndb.Connection):
def __init__(self,
host,
database,
user=None,
password=None,
max_idle_time=7 * 3600,
connect_timeout=500,
time_zone="+0:00"):
self.host = host
self.database = database
self.max_idle_time = float(max_idle_time)
args = dict(conv=torndb.CONVERSIONS,
use_unicode=True,
charset="utf8",
db=database,
init_command=('SET time_zone = "%s";' %
time_zone),
connect_timeout=connect_timeout,
sql_mode="TRADITIONAL")
if user is not None:
args["user"] = user
if password is not None:
args["passwd"] = password
# We accept a path to a MySQL socket file or a host(:port) string
if "/" in host:
args["unix_socket"] = host
else:
self.socket = None
pair = host.split(":")
if len(pair) == 2:
args["host"] = pair[0]
args["port"] = int(pair[1])
else:
args["host"] = host
args["port"] = 3306
self._db = None
self._db_args = args
self._last_use_time = time.time()
try:
self.reconnect()
except Exception:
logging.error("Cannot connect to MySQL on %s",
self.host,
exc_info=True)
# class NoCacheStaticFileHandler(tornado.web.StaticFileHandler):
# def set_extra_headers(self, path):
# self.set_header("Cache-control", "no-cache")
class Application(tornado.web.Application):
def __init__(self):
# from website.handlers import Model
handlers_ = []
for handler in handlers:
handlers_.append(("%s%s" % (settings.api_prefix, handler[0]),
handler[1]))
for handler in page_handlers:
handlers_.append((handler[0], handler[1]))
# for handler in handlers_v2:
# handlers_.append(("%s%s" % (settings.api_prefix_v2, handler[0]),
# handler[1]))
# handlers_.append((r"/wap/s", tornado.web.RedirectHandler, dict(url=r"//wap/s.html")))
handlers_.append((r".*", APIErrorHandler))
# handlers_.append((r"/static/(.*)", NoCacheStaticFileHandler, {"path": os.path.join(_ROOT, "static")}))
settings_ = dict(
debug=options.debug,
# login_url="/login",
login_url="",
cookie_secret=settings.cookie_secret,
template_path=os.path.join(_ROOT, "templates"),
static_path=os.path.join(_ROOT, "static"),
xsrf_cookies=False,
autoescape=None,
)
self.db_app = Connection(
settings.mysql_app["host"],
settings.mysql_app["database"],
user=settings.mysql_app["user"],
password=settings.mysql_app["password"],
time_zone=settings.mysql_app["time_zone"])
self.app_mysql = app_engine
# if settings.redis_sentinel == 1:
# rs = Sentinel(settings.redis_sentinel_nodes, socket_timeout=0.1)
# self.r_app = rs.master_for(settings.redis_sentinel_master,
# socket_timeout=0.1,
# password=settings.redis_sentinel_pwd)
if settings.redis_cluster == 1:
self.r_app = RedisCluster(startup_nodes=settings.redis_app_cluster_notes, decode_responses=True,
password=settings.redis_cluster_pwd)
else:
self.r_app = redis.Redis(*settings.redis_app, decode_responses=True)
# self.r_app = redis.Redis(*settings.redis_app)
# self.kafka_producer = Producer(**settings.kafka_conf)
# self.es = Elasticsearch(settings.es_nodes)
# Model.setup_dbs({"db_app": self.db_app,
# "r_app": self.r_app
# })
tornado.web.Application.__init__(self, handlers_, **settings_)
# swagger.Application.__init__(self, handlers_, **settings_)
def sig_handler(signum, frame):
tornado.ioloop.IOLoop.instance().stop()
class PwdFilter(logging.Filter):
def filter(self, record):
try:
print("##########")
print("{}, {}".format(record.name, record.msg))
except Exception as e:
print(e)
pass
return True
def main():
tornado.options.parse_command_line()
# options.parse_command_line()
formatter = logging.Formatter(
'[%(levelname)1.1s %(asctime)s.%(msecs)d '
'%(module)s:%(funcName)s:%(lineno)d] %(message)s',
"%Y-%m-%d %H:%M:%S"
) # creating own format
for handler in logging.getLogger().handlers: # setting format for all handlers
handler.setFormatter(formatter)
# handler.addFilter(PwdFilter())
app = Application()
app.listen(options.port)
# def ping():
# try:
# row = app.db_app.get("select id from user limit 1")
# if row:
# logging.info("db check ok")
# except Exception as e:
# logging.info(e)
# logging.info("db connection err, reconnect")
# app.db_app.reconnect()
logging.info("start app server...")
# tornado.ioloop.PeriodicCallback(ping, 600000).start()
tornado.ioloop.IOLoop.instance().start()
if __name__ == "__main__":
main()

@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""静态配置"""
# 行业分类
industry_map = {
1001: u"IT服务",
1002: u"制造业",
1003: u"批发/零售",
1004: u"生活服务",
1005: u"文化/体育/娱乐业",
1006: u"建筑/房地产",
1007: u"教育",
1008: u"运输/物流/仓储",
1009: u"医疗",
1010: u"政府",
1011: u"金融",
1012: u"能源/采矿",
1013: u"农林渔牧",
1014: u"其他行业",
}

@ -0,0 +1,69 @@
import itertools
import contextlib
import logging
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from website import settings
class Row(dict):
"""A dict that allows for object-like property access syntax."""
def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise AttributeError(name)
def to_json_list(cursor):
column_names = list(cursor.keys())
result = cursor.fetchall()
if not result:
return None
return [Row(itertools.zip_longest(column_names, row)) for row in result]
def to_json(cursor):
column_names = list(cursor.keys())
result = cursor.fetchone()
if not result:
return None
return Row(itertools.zip_longest(column_names, result))
app_engine = create_engine(
'mysql+pymysql://{}:{}@{}/{}?charset=utf8mb4'.format(
settings.mysql_app['user'],
settings.mysql_app['password'],
settings.mysql_app['host'],
settings.mysql_app['database']
), # SQLAlchemy 数据库连接串,格式见下面
echo=bool(settings.SQLALCHEMY_ECHO), # 是不是要把所执行的SQL打印出来一般用于调试
# pool_pre_ping=True,
# pool_size=int(settings.SQLALCHEMY_POOL_SIZE), # 连接池大小
# max_overflow=int(settings.SQLALCHEMY_POOL_MAX_SIZE), # 连接池最大的大小
# pool_recycle=int(settings.SQLALCHEMY_POOL_RECYCLE), # 多久时间回收连接
)
# Session = sessionmaker(bind=engine)
# Base = declarative_base(engine)
#
#
# @contextlib.contextmanager
# def get_session():
# s = Session()
# try:
# yield s
# s.commit()
# except Exception as e:
# s.rollback()
# raise e
# finally:
# s.close()

@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
import logging
from tornado import escape
from tornado.web import HTTPError
# HTTP status code
HTTP_OK = 200
ERROR_BAD_REQUEST = 400
ERROR_UNAUTHORIZED = 401
ERROR_FORBIDDEN = 403
ERROR_NOT_FOUND = 404
ERROR_METHOD_NOT_ALLOWED = 405
ERROR_INTERNAL_SERVER_ERROR = 500
# Custom error code
ERROR_WARNING = 1001
ERROR_DEPRECATED = 1002
ERROR_MAINTAINING = 1003
ERROR_UNKNOWN_ERROR = 9999
ERROR_LICENSE_NOT_ACTIVE = 9000
ERROR_LICENSE_EXPIRE_ATALL = 9003
class HTTPAPIError(HTTPError):
"""API error handling exception
API server always returns formatted JSON to client even there is
an internal server error.
"""
def __init__(self, status_code=ERROR_UNKNOWN_ERROR, message=None,
error=None, data=None, log_message=None, *args):
assert isinstance(data, dict) or data is None
message = message if message else ""
assert isinstance(message, str)
super(HTTPAPIError, self).__init__(int(status_code),
log_message, *args)
self.error = error if error else \
_error_types.get(self.status_code, _unknow_error)
self.message = message
self.data = data if data is not None else {}
def __str__(self):
err = {"meta": {"code": self.status_code, "error": self.error}}
if self.data:
err["data"] = self.data
if self.message:
err["meta"]["message"] = self.message
return escape.json_encode(err)
# default errors
_unknow_error = "unknow_error"
_error_types = {400: "bad_request",
401: "unauthorized",
403: "forbidden",
404: "not_found",
405: "method_not_allowed",
500: "internal_server_error",
1001: "warning",
1002: "deprecated",
1003: "maintaining",
9000: "license_not_active",
9003: "license_expire_at_all",
9999: _unknow_error}

@ -0,0 +1,556 @@
# -*- coding: utf-8 -*-
import ast
import functools
import hashlib
import logging
import re
import time
import traceback
import urllib
import json
# import urlparse
from urllib.parse import parse_qs, unquote
# from urllib import unquote
import tornado
from tornado import escape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.options import options
# from tornado.web import RequestHandler as BaseRequestHandler, HTTPError, asynchronous
from tornado.web import RequestHandler as BaseRequestHandler, HTTPError
from torndb import Row
from website import errors
from website import settings
from website.service.license import get_license_status
if settings.enable_curl_async_http_client:
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
else:
AsyncHTTPClient.configure(None, max_clients=settings.max_clients)
REMOVE_SLASH_RE = re.compile(".+/$")
def _callback_wrapper(callback):
"""A wrapper to handling basic callback error"""
def _wrapper(response):
if response.error:
logging.error("call remote api err: %s" % response)
if isinstance(response.error, tornado.httpclient.HTTPError):
raise errors.HTTPAPIError(response.error.code, "网络连接失败")
else:
raise errors.HTTPAPIError(errors.ERROR_UNKNOWN_ERROR, "未知错误")
else:
callback(response)
return _wrapper
class BaseHandler(BaseRequestHandler):
# def __init__(self, application, request, **kwargs):
# super(BaseHandler, self).__init__(application, request, **kwargs)
# self.xsrf_form_html()
def set_default_headers(self):
self.set_header("Access-Control-Allow-Origin", "*")
self.set_header("Access-Control-Allow-Headers",
"DNT,token,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,"
"Content-Type")
# self.set_header("Access-Control-Allow-Headers",
# "DNT,web-token,app-token,Authorization,Accept,Origin,Keep-Alive,User-Agent,X-Mx-ReqToken,"
# "X-Data-Type,X-Auth-Token,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range")
self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')
# def _request_summary(self):
#
# return "%s %s %s(@%s)" % (self.request.method, self.request.uri, self.request.body.decode(),
# self.request.remote_ip)
def get(self, *args, **kwargs):
# enable GET request when enable delegate get to post
if settings.app_get_to_post:
self.post(*args, **kwargs)
else:
raise HTTPError(405)
def render(self, template_name, **kwargs):
if self.current_user:
if 'username' not in kwargs:
kwargs["username"] = self.current_user.name
if 'current_userid' not in kwargs:
kwargs["current_userid"] = self.current_user.id
if 'role' not in kwargs:
kwargs["role"] = self.current_user.role
self.set_header("Cache-control", "no-cache")
return super(BaseHandler, self).render(template_name, **kwargs)
@property
def db_app(self):
return self.application.db_app
@property
def app_mysql(self):
return self.application.app_mysql
@property
def r_app(self):
return self.application.r_app
@property
def kafka_producer(self):
return self.application.kafka_producer
#
@property
def es(self):
return self.application.es
# @property
# def nsq(self):
# return self.application.nsq
def _call_api(self, url, headers, body, callback, method, callback_wrapper=_callback_wrapper):
start = 0
if callback_wrapper:
callback = callback_wrapper(callback)
try:
start = time.time()
AsyncHTTPClient().fetch(HTTPRequest(url=url,
method=method,
body=body,
headers=headers,
allow_nonstandard_methods=True,
connect_timeout=settings.remote_connect_timeout,
request_timeout=settings.remote_request_timeout,
follow_redirects=False),
callback)
except tornado.httpclient.HTTPError:
logging.error("requet from %s, take time: %s" % (url, (time.time() - start) * 1000))
# if hasattr(x, "response") and x.response:
# callback(x.response)
# else:
# logging.error("Tornado signalled HTTPError %s", x)
# raise x
# @asynchronous
async def call_api(self, url, headers=None, body=None, callback=None, method="POST"):
if callback is None:
callback = self.call_api_callback
if headers is None:
headers = self.request.headers
if body is None:
body = self.request.body
else:
# make sure it is a post request
headers["Content-Type"] = "application/x-www-form-urlencoded"
self._call_api(url, headers, body, callback, method)
def get_current_user(self, token_body=None):
# jid = self.get_secure_cookie(settings.cookie_key)
token = self.request.headers.get("token")
if token_body:
token = token_body
if not token:
return None
token = unquote(token)
jid = tornado.web.decode_signed_value(
settings.cookie_secret,
settings.secure_cookie_name,
token
)
jid = jid and str(jid, encoding="utf-8") or ""
key = settings.session_key_prefix % jid
user = self.r_app.get(key)
if user:
if "/user/info" in self.request.path:
pass
else:
self.r_app.expire(key, settings.session_ttl)
user = str(user, encoding='utf8') if isinstance(user, bytes) else user
# return Row(ast.literal_eval(str(user, encoding="utf-8")))
return Row(ast.literal_eval(user))
else:
return None
def md5compare(self, name):
string = unquote(name)
num1 = string.split("|")[0]
num2 = string.split("|")[1]
num3 = string.split("|")[2]
num = num1 + num2
md5string = hashlib.md5(num).hexdigest().upper()
if md5string == num3:
return True
else:
return False
def tostr(self, src):
return str(src, encoding='utf8') if isinstance(src, bytes) else src
def prepare(self):
self.remove_slash()
self.prepare_context()
self.set_default_jsonbody()
# self.traffic_threshold()
def set_default_jsonbody(self):
if self.request.headers.get('Content-Type') == 'application/json;charset=UTF-8' and self.request.body:
# logging.info(self.request.headers.get('Content-Type'))
# if self.request.headers.get('Content-Type') == 'application/json; charset=UTF-8':
json_body = tornado.escape.json_decode(self.request.body)
for key, value in json_body.items():
if value is not None:
if type(value) is list:
self.request.arguments.setdefault(key, []).extend(value)
elif type(value) is dict:
self.request.arguments[key] = value
else:
self.request.arguments.setdefault(key, []).extend([bytes(str(value), 'utf-8')])
def traffic_threshold(self):
# if self.request.uri in ["/api/download"]:
# return
if not self.current_user:
user_id = self.request.remote_ip
else:
user_id = self.current_user.id
freq_key = "API:FREQ:%s:%s" % (user_id, int(time.time()) / 10)
send_count = self.r_app.incr(freq_key)
if send_count > settings.api_count_in_ten_second:
freq_key = "API:FREQ:%s:%s" % (user_id, int(time.time()) / 10 + 1)
self.r_app.setex(freq_key, send_count, 10)
raise errors.HTTPAPIError(
errors.ERROR_METHOD_NOT_ALLOWED, "请勿频繁操作")
if send_count == 1:
self.r_app.expire(freq_key, 10)
def prepare_context(self):
# self.nsq.pub(settings.nsq_topic_stats, escape.json_encode(self._create_stats_msg()))
pass
def remove_slash(self):
if self.request.method == "GET":
if REMOVE_SLASH_RE.match(self.request.path):
# remove trail slash in path
uri = self.request.path.rstrip("/")
if self.request.query:
uri += "?" + self.request.query
self.redirect(uri)
# def get_json_argument(self, name, default=BaseRequestHandler._ARG_DEFAULT):
def get_json_argument(self, name, default=object()):
json_body = tornado.escape.json_decode(self.request.body)
value = json_body.get(name, default)
return escape.utf8(value) if isinstance(value, str) else value
def get_int_json_argument(self, name, default=0):
try:
return int(self.get_json_argument(name, default))
except ValueError:
return default
def get_escaped_json_argument(self, name, default=None):
if default is not None:
return self.escape_string(self.get_json_argument(name, default))
else:
return self.get_json_argument(name, default)
# def get_argument(self, name, default=BaseRequestHandler._ARG_DEFAULT, strip=True):
def get_argument(self, name, default=object(), strip=True):
value = super(BaseHandler, self).get_argument(name, default, strip)
return escape.utf8(value) if isinstance(value, str) else value
def get_int_argument(self, name, default=0):
try:
return int(self.get_argument(name, default))
except ValueError:
return default
def get_float_argument(self, name, default=0.0):
try:
return float(self.get_argument(name, default))
except ValueError:
return default
def get_uint_arg(self, name, default=0):
try:
return abs(int(self.get_argument(name, default)))
except ValueError:
return default
def unescape_string(self, s):
return escape.xhtml_unescape(s)
def escape_string(self, s):
return escape.xhtml_escape(s)
def get_escaped_argument(self, name, default=None):
if default is not None:
return self.escape_string(self.get_argument(name, default))
else:
return self.get_argument(name, default)
def get_page_url(self, page, form_id=None, tab=None):
if form_id:
return "javascript:goto_page('%s',%s);" % (form_id.strip(), page)
path = self.request.path
query = self.request.query
# qdict = urlparse.parse_qs(query)
qdict = parse_qs(query)
for k, v in qdict.items():
if isinstance(v, list):
qdict[k] = v and v[0] or ''
qdict['page'] = page
if tab:
qdict['tab'] = tab
return path + '?' + urllib.urlencode(qdict)
def find_all(self, target, substring):
current_pos = target.find(substring)
while current_pos != -1:
yield current_pos
current_pos += len(substring)
current_pos = target.find(substring, current_pos)
class WebHandler(BaseHandler):
def finish(self, chunk=None, message=None):
callback = escape.utf8(self.get_argument("callback", None))
if callback:
self.set_header("Content-Type", "application/x-javascript")
if isinstance(chunk, dict):
chunk = escape.json_encode(chunk)
self._write_buffer = [callback, "(", chunk, ")"] if chunk else []
super(WebHandler, self).finish()
else:
self.set_header("Cache-control", "no-cache")
super(WebHandler, self).finish(chunk)
def write_error(self, status_code, **kwargs):
try:
exc_info = kwargs.pop('exc_info')
e = exc_info[1]
if isinstance(e, errors.HTTPAPIError):
pass
elif isinstance(e, HTTPError):
if e.status_code == 401:
self.redirect("/", permanent=True)
return
e = errors.HTTPAPIError(e.status_code)
else:
e = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)
exception = "".join([ln for ln
in traceback.format_exception(*exc_info)])
if status_code == errors.ERROR_INTERNAL_SERVER_ERROR \
and not options.debug:
self.send_error_mail(exception)
if options.debug:
e.data["exception"] = exception
self.clear()
# always return 200 OK for Web errors
self.set_status(errors.HTTP_OK)
self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(str(e))
except Exception:
logging.error(traceback.format_exc())
return super(WebHandler, self).write_error(status_code, **kwargs)
def send_error_mail(self, exception):
pass
class APIHandler(BaseHandler):
def finish(self, chunk=None, message=None):
if chunk is None:
chunk = {}
if isinstance(chunk, dict):
chunk = {"meta": {"code": errors.HTTP_OK}, "data": chunk}
if message:
chunk["message"] = message
callback = escape.utf8(self.get_argument("callback", None))
if callback:
self.set_header("Content-Type", "application/x-javascript")
if isinstance(chunk, dict):
chunk = escape.json_encode(chunk)
self._write_buffer = [callback, "(", chunk, ")"] if chunk else []
super(APIHandler, self).finish()
else:
self.set_header("Content-Type", "application/json; charset=UTF-8")
super(APIHandler, self).finish(chunk)
def write_error(self, status_code, **kwargs):
try:
exc_info = kwargs.pop('exc_info')
e = exc_info[1]
if isinstance(e, errors.HTTPAPIError):
pass
elif isinstance(e, HTTPError):
e = errors.HTTPAPIError(e.status_code)
else:
e = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)
exception = "".join([ln for ln
in traceback.format_exception(*exc_info)])
if status_code == errors.ERROR_INTERNAL_SERVER_ERROR \
and not options.debug:
self.send_error_mail(exception)
if options.debug:
e.data["exception"] = exception
self.clear()
# always return 200 OK for API errors
self.set_status(errors.HTTP_OK)
self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(str(e))
except Exception:
logging.error(traceback.format_exc())
return super(APIHandler, self).write_error(status_code, **kwargs)
def send_error_mail(self, exception):
"""Override to implement custom error mail"""
pass
class ErrorHandler(BaseHandler):
"""Default 404: Not Found handler."""
def prepare(self):
super(ErrorHandler, self).prepare()
raise HTTPError(errors.ERROR_NOT_FOUND)
class APIErrorHandler(APIHandler):
"""Default API 404: Not Found handler."""
def prepare(self):
super(APIErrorHandler, self).prepare()
raise errors.HTTPAPIError(errors.ERROR_NOT_FOUND)
def authenticated(method):
"""Decorate methods with this to require that the user be logged in.
Just raise 401
or be avaliable
"""
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
if not self.current_user:
# raise HTTPError(401)
raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
return method(self, *args, **kwargs)
return wrapper
def authenticated_admin(method):
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
if int(self.current_user.role) != 1001:
# raise HTTPError(403)
raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")
return method(self, *args, **kwargs)
return wrapper
def operation_log(primary_module, secondary_module, operation_type, content, desc):
"""
Add logging to a function. level is the logging
level, name is the logger name, and message is the
log message. If name and message aren't specified,
they default to the function's module and name.
"""
def decorate(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
self.db_app.insert(
"insert into system_log(user, ip, first_module, second_module, op_type, op_content, description) "
"values(%s, %s, %s, %s, %s, %s, %s)",
self.current_user.name, self.request.headers["X-Forwarded-For"], primary_module, secondary_module, operation_type,
content, desc
)
return func(self, *args, **kwargs)
return wrapper
return decorate
def permission(codes):
def decorate(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
rows = self.db_app.query(
"select rp.permission from role_permission rp, user_role ur where rp.role=ur.role and ur.userid=%s",
self.current_user.id
)
permissions = [item["permission"] for item in rows]
for code in codes:
if code not in permissions:
raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")
return func(self, *args, **kwargs)
return wrapper
return decorate
def license_validate(codes):
def decorate(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
license_cache = self.r_app.get("system:license")
license_info = {}
if license_cache:
license_info = json.loads(self.tostr(license_cache))
else:
row = self.db_app.get("select syscode, expireat from license limit 1")
if row:
self.r_app.set("system:license", json.dumps({"syscode":row["syscode"], "expireat":row["expireat"]}))
license_info = row
license_status = get_license_status(license_info)
# logging.info("license status is : {}, need : {}".format(license_status, codes))
if license_status not in codes:
raise errors.HTTPAPIError(errors.ERROR_LICENSE_NOT_ACTIVE, "系统License未授权")
# if not license_info:
# raise errors.HTTPAPIError(errors.ERROR_LICENSE_NOT_ACTIVE, "License未授权")
# expireat = int(license_info["expireat"])
# local_time = int(time.time())
# if local_time >= expireat:
# raise errors.HTTPAPIError(errors.ERROR_LICENSE_NOT_ACTIVE, "License授权过期")
return func(self, *args, **kwargs)
return wrapper
return decorate
def userlog(method):
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
logging.info("[ip]:%s [user]:%s" % (self.request.remote_ip, self.current_user.id))
return method(self, *args, **kwargs)
return wrapper

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
class Model(object):
_dbs = {}
@classmethod
def setup_dbs(cls, dbs):
cls._dbs = dbs
@property
def db_app(self):
return self._dbs.get("db_app", None)
@property
def r_app(self):
return self._dbs.get("r_app", None)

@ -0,0 +1,253 @@
# -*- coding: utf-8 -*-
import logging
from sqlalchemy import text
from website import errors
from website import settings
from website import consts
from website import db
from website.util import shortuuid, aes
from website.service import enterprise
from website.handler import APIHandler, authenticated
class EntityIndexHandler(APIHandler):
"""首页"""
@authenticated
def post(self):
pageNo = self.get_int_argument("pageNo", 1)
pageSize = self.get_int_argument("pageSize", 10)
name = self.tostr(self.get_escaped_argument("name", ""))
with self.app_mysql.connect() as conn:
sql_text = "select id, name, industry, logo, create_at from enterprise "
param = {}
count_sql_text = "select count(id) from enterprise "
count_param = {}
if name:
sql_text += "where name like :name"
param["name"] = "%{}%".format(name)
count_sql_text += "where name like :name"
count_param["name"] = "%{}%".format(name)
sql_text += " order by id desc limit :pageSize offset :offset"
param["pageSize"] = pageSize
param["offset"] = (pageNo - 1) * pageSize
cur = conn.execute(text(sql_text), param)
result = db.to_json_list(cur)
count = conn.execute(text(count_sql_text), count_param).fetchone()[0]
logging.info(count)
logging.info(result)
data = []
for item in result:
modelCount = enterprise.get_enterprise_model_count(item["id"])
deviceCount = enterprise.get_enterprise_device_count(item["id"])
data.append(
{
"id": item["id"],
"name": item["name"],
"industry": item["industry"],
"modelCount": modelCount,
"deviceCount": deviceCount,
"logo": item["logo"],
"createTime": str(item["create_at"])
}
)
self.finish({"count": count, "data": data})
class EntityIndexBasecountHandler(APIHandler):
"""首页基础统计书记"""
@authenticated
def post(self):
entity = enterprise.get_enterprise_entity_count(self.app_mysql)
model = enterprise.get_enterprise_model_count()
device = enterprise.get_enterprise_device_count()
self.finish({"entity": entity, "model": model, "device": device})
class EntityAddHandler(APIHandler):
"""添加企业"""
@authenticated
def post(self):
name = self.tostr(self.get_escaped_argument("name", ""))
province = self.get_escaped_argument("province", "")
city = self.get_escaped_argument("city", "")
addr = self.get_escaped_argument("addr", "")
industry = self.get_int_argument("industry")
contact = self.get_escaped_argument("contact", "")
phone = self.get_escaped_argument("phone", "")
summary = self.get_escaped_argument("summary", "")
logo = self.get_escaped_argument("logo", "")
if not name or not province or not city or not addr or not industry or not contact or not phone or not summary:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
if industry not in consts.industry_map:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型")
if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制")
short_uid = shortuuid.ShortUUID().random(length=8)
pwd = aes.encrypt(settings.enterprise_aes_key, short_uid)
with self.app_mysql.connect() as conn:
conn.execute(
text(
"insert into enterprise(name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) "
"values(:name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)"
),
{
"name": name,
"province": province,
"city": city,
"addr": addr,
"industry": industry,
"contact": contact,
"phone": phone,
"summary": summary,
"logo": logo,
"account": "admin",
"pwd": pwd,
}
)
conn.commit()
# self.db_app.insert(
# "insert into enterprise(name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) "
# "values(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
# name, province, city, addr, industry, contact, phone, summary, logo, "admin", pwd,
# )
self.finish()
class EntityEditHandler(APIHandler):
"""编辑企业"""
@authenticated
def post(self):
eid = self.get_int_argument("id")
name = self.tostr(self.get_escaped_argument("name", ""))
province = self.get_escaped_argument("province", "")
city = self.get_escaped_argument("city", "")
addr = self.get_escaped_argument("addr", "")
industry = self.get_int_argument("industry")
contact = self.get_escaped_argument("contact", "")
phone = self.get_escaped_argument("phone", "")
summary = self.get_escaped_argument("summary", "")
logo = self.get_escaped_argument("logo", "")
account = self.get_escaped_argument("account", "")
if not name or not province or not city or not addr or not industry or not contact or not phone or not summary:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "参数缺失")
if industry not in consts.industry_map:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "清选择行业类型")
if logo and len(logo) * 0.75 / 1024 / 1024 > 1.2:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "Logo图标大小超出1M限制")
with self.app_mysql.connect() as conn:
conn.execute(
text(
# "insert into enterprise(name, province, city, addr, industry, contact, phone, summary, logo, account, pwd) "
# "values(:name, :province, :city, :addr, :industry, :contact, :phone, :summary, :logo, :account, :pwd)"
"update enterprise set name=:name, province=:province, city=:city, addr=:addr, industry=:industry, contact"
"=:contact, phone=:phone, summary=:summary, logo=:logo, account=:account where id=:id",
),
{
"name": name,
"province": province,
"city": city,
"addr": addr,
"industry": industry,
"contact": contact,
"phone": phone,
"summary": summary,
"logo": logo,
"account": account,
}
)
conn.commit()
self.finish()
class EntityInfoHandler(APIHandler):
"""企业信息"""
@authenticated
def post(self):
eid = self.get_int_argument("id")
row = {}
with self.app_mysql.connect() as conn:
cur = conn.execute(
text("select * from enterprise where id=:id"), {"id": eid}
)
# keys = list(cur.keys())
#
# one = cur.fetchone()
# row = dict(zip(keys, one))
# logging.info(db.Row(itertools.zip_longest(keys, one)))
row = db.to_json(cur)
cur.close()
if not row:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")
data = {
"name": row["name"],
"province": row["province"],
"city": row["city"],
"industry": row["industry"],
"contact": row["contact"],
"phone": row["phone"],
"summary": row["summary"],
"logo": row["logo"],
"createTime": str(row["create_at"]),
"account": row["account"], # 企业账号
}
self.finish(data)
class EntityPwdcheckHandler(APIHandler):
"""查看企业密码"""
@authenticated
def post(self):
eid = self.get_int_argument("id")
with self.app_mysql.connect() as conn:
cur = conn.execute(text("select pwd from enterprise where id=:id"), {"id": eid})
# row = cur.fetchone()
logging.info(cur)
row = db.to_json(cur)
if not row:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请求失败")
pwd = row["pwd"]
cur.close()
pwd_dcrypt = aes.decrypt(settings.enterprise_aes_key, pwd)
self.finish({"pwd": pwd_dcrypt})

@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
from website.handlers.enterprise import handler
handlers = [
("/enterprise/entity/index", handler.EntityIndexHandler),
("/enterprise/entity/index/basecount", handler.EntityIndexBasecountHandler),
("/enterprise/entity/add", handler.EntityAddHandler),
("/enterprise/entity/edit", handler.EntityEditHandler),
("/enterprise/entity/info", handler.EntityInfoHandler),
("/enterprise/entity/pwdcheck", handler.EntityPwdcheckHandler),
]
page_handlers = [
]

@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""系统信息"""
import logging
import uuid
import time
import re
import os
import json
import hashlib
import datetime
from website import errors
from website import settings
from website.handler import APIHandler, WebHandler, authenticated, operation_log, permission
from website.util import sysinfo, rsa
class VersionHandler(APIHandler):
@authenticated
# @permission([100014, 100016])
def post(self):
self.finish()
class IdentifycodeHandler(APIHandler):
@authenticated
# @permission([100014, 100015])
# @operation_log("资产管理中心", "系统激活", "查询", "查询本地识别码", "查询本地识别码")
def post(self):
code = sysinfo.get_identify_code()
self.finish({"result": code})
class LicenseUploadHandler(APIHandler):
@authenticated
# @permission([100014, 100015])
# @operation_log("资产管理中心", "系统激活", "导入", "上传license文件", "上传license文件")
def post(self):
file_metas = self.request.files.get('file', None)
if not file_metas:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请上传文件")
file = file_metas[0]
filename = file.filename
punctuation = """!"#$%&'()*+,/:;<=>?@[\]^`{|}~ """
regex = re.compile('[%s]' % re.escape(punctuation))
filename = regex.sub("", filename.replace('..', ''))
file_size = len(file.body)
if file_size > 10 * 1024 * 1024:
raise errors.HTTPAPIError(errors.ERROR_METHOD_NOT_ALLOWED, 'Exceed 10M size limit')
md5_str = hashlib.md5(file.body).hexdigest()
filepath = settings.rsa_license_file
try:
body = file['body']
public_key = rsa.load_pub_key_string(open(settings.rsa_public_file).read().strip('\n').encode('utf-8'))
plaintext = rsa.decrypt(public_key, body)
plaintext_json = json.loads(self.tostr(plaintext))
syscode = plaintext_json["syscode"]
expireat = plaintext_json["expireat"]
current_syscode = sysinfo.get_identify_code()
if syscode != current_syscode:
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "license激活失败请重新激活")
row = self.db_app.get("select id from license where syscode=%s", syscode)
if row:
self.db_app.update(
"update license set expireat=%s where syscode=%s", str(expireat), syscode
)
else:
self.db_app.insert(
"insert into license(syscode, expireat) values(%s, %s)",
syscode, expireat
)
self.r_app.set("system:license", json.dumps({"syscode":syscode, "expireat":expireat}))
with open(filepath, 'wb') as f:
f.write(file['body'])
logging.info(plaintext_json)
except Exception as e:
logging.info(e)
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "license激活失败请重新激活")
self.finish()
class ActivateInfoHandler(APIHandler):
@authenticated
# @permission([100014, 100015])
# @operation_log("资产管理中心", "系统激活", "查询", "查询系统激活信息", "查询系统激活信息")
def post(self):
license_str = ""
activate_at = ""
expire_at = ""
date_remain = 0
row = self.db_app.get(
"select create_at, expireat from license limit 1"
)
if row:
license_str = open(settings.rsa_license_file, 'r').read()
activate_at = str(row["create_at"])
expire_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(row["expireat"])))
now = datetime.datetime.now()
delta = (datetime.datetime.fromtimestamp(int(row["expireat"])).date() - now.date()).days
date_remain = delta if delta > 0 else 0
data = {
"system": settings.system_info[settings.system_type]["name"],
"license": license_str,
"activate_at": activate_at,
"expire_at": expire_at,
"date_remain": date_remain
}
self.finish(data)
class InfoHandler(APIHandler):
def post(self):
self.finish()

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from website.handlers.system import handler
handlers = [
# ("/", handler.Handler),
("/system/version", handler.VersionHandler),
("/system/identifycode", handler.IdentifycodeHandler),
("/system/license/upload", handler.LicenseUploadHandler),
("/system/activate/info", handler.ActivateInfoHandler),
("/system/info", handler.InfoHandler),
]
page_handlers = [
]

@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
import logging
import json
import base64
import tornado.web
import uuid
import time
import datetime
import itertools
from io import StringIO, BytesIO
from website import errors
from website import settings
from website import db
from website import consts
from website.handler import APIHandler
from website.util import aes
from website.util.captcha import create_validate_code
from website.service.license import get_license_status
from website.util import shortuuid
from sqlalchemy import text
import tornado.escape
class CaptchaHandler(APIHandler):
def get(self):
self.set_header("Content-Type", "image/png")
image, image_str = create_validate_code()
c = uuid.uuid4().hex
token = self.create_signed_value("logc", c)
self.r_app.set("logincaptcha:%s" % c, image_str, ex=120)
buffered = BytesIO()
# 保存验证码图片
image.save(buffered, 'png')
img_b64 = base64.b64encode(buffered.getvalue())
# for line in buffered.getvalue():
# self.write(line)
# output.close()
self.finish({"token": self.tostr(token), "captcha": self.tostr(img_b64)})
class LogoutHandler(APIHandler):
def get(self):
if self.current_user:
# self.db_app.insert(
# "insert into system_log(user, ip, first_module, second_module, op_type, op_content, description) "
# "values(%s, %s, %s, %s, %s, %s, %s)",
# self.current_user.name, self.request.remote_ip, "平台管理中心", "账号管理", "登出", "系统登出", "系统登出"
# )
self.r_app.delete(settings.session_key_prefix % self.current_user.uuid)
self.finish()
class LoginHandler(APIHandler):
def post(self):
suid = shortuuid.ShortUUID().random(10)
logging.info(suid)
username = self.get_escaped_argument("username")
password = self.get_escaped_argument("pwd")
# captcha = self.get_escaped_argument("captcha", "")
# captcha_token = self.get_escaped_argument("captcha_token", "")
# wrong_time_lock = self.r_app.get("pwd:wrong:time:%s:lock" % self.tostr(username))
# if wrong_time_lock:
# raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "账号处于冷却期,请稍后再试")
# return
logging.info(self.request.arguments)
logging.info(password)
logging.info("#########################")
# if not captcha:
# raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "请输入验证码")
# if not captcha_token:
# raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "缺少参数")
# c = tornado.web.decode_signed_value(
# settings.cookie_secret,
# "logc",
# self.tostr(captcha_token)
# )
# code = self.r_app.get("logincaptcha:%s" % self.tostr(c))
# 清除校验码缓存
# self.r_app.delete("logincaptcha:%s" % c)
# if not code:
# raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "验证码已过期")
# 判断验证码与缓存是否一致
# if self.tostr(captcha).lower() != self.tostr(code).lower():
# raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "验证码错误")
username = self.tostr(username)
password = self.tostr(password)
pwd_enc = aes.encrypt(settings.pwd_aes_key, password)
row = {}
with self.app_mysql.connect() as conn:
cur = conn.execute(
text("select id, uid, available from sys_user where name=:name and pwd=:pwd"),
{"name": username, "pwd": pwd_enc}
)
# keys = list(cur.keys())
#
# one = cur.fetchone()
# row = dict(zip(keys, one))
# logging.info(db.Row(itertools.zip_longest(keys, one)))
row = db.to_json(cur)
cur.close()
# data = [dict(zip(keys, res)) for res in cur.fetchall()]
if not row:
# wrong_time = self.r_app.get("pwd:wrong:time:%s" % username)
# logging.info(wrong_time)
# logging.info(settings.pwd_error_limit - 1)
# if wrong_time and int(wrong_time) > settings.pwd_error_limit - 1:
# self.r_app.set("pwd:wrong:time:%s:lock" % username, 1, ex=3600)
# self.r_app.delete("pwd:wrong:time:%s" % username)
# else:
# self.r_app.incr("pwd:wrong:time:%s" % username)
raise errors.HTTPAPIError(errors.ERROR_BAD_REQUEST, "用户名或者密码错误")
return
if row["available"] == 0:
raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "当前用户被禁用")
return
# row_role = self.db_app.get("select role from user_role where userid=%s", row["id"])
# user_role = row_role["role"]
userId = row["id"]
jsessionid = row["uid"]
# create sign value admin_login_sign
secure_cookie = self.create_signed_value(settings.secure_cookie_name, str(jsessionid))
self.r_app.set(
settings.session_key_prefix % jsessionid,
json.dumps({
"id": userId,
"name": username,
"uuid": row["uid"],
# "role": user_role
}),
ex=settings.session_ttl
)
# self.db_app.insert(
# "insert into system_log(user, ip, first_module, second_module, op_type, op_content, description) "
# "values(%s, %s, %s, %s, %s, %s, %s)",
# username, self.request.remote_ip, "平台管理中心", "账号管理", "登录", "系统登录", "系统登录"
# )
# license_row = self.db_app.get(
# "select expireat from license limit 1"
# )
# system_status = get_license_status(license_row)
render_data = {
"token": str(secure_cookie, encoding="utf-8"),
# "role": user_role,
"username": username,
# "system_status": system_status, # 9000/未激活, 9001/已激活, 9002/过期可查看, 9003/完全过期
}
self.finish(render_data)
class UserInfoHandler(APIHandler):
def post(self):
token = self.get_argument("token")
user = self.get_current_user(token_body=self.tostr(token))
if not user:
raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED)
self.finish({"name": user.name, "role": user.role})

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
from website.handlers.user import handler
handlers = [
# ("/user/list", handler.UserListHandler),
# ("/captcha", handler.CaptchaHandler),
# ("/bodyargument", handler.BodyHandler),
# ("/user/info", handler.UserInfoHandler),
("/login", handler.LoginHandler),
("/logout", handler.LogoutHandler),
]
page_handlers = [
]

@ -0,0 +1,31 @@
from website.handler import BaseHandler
from sqlalchemy import text
# 获取企业模型数量
def get_enterprise_model_count(id):
return 0
# 获取企业设备数量
def get_enterprise_device_count(id):
return 0
# 获取所有企业实体数量
def get_enterprise_entity_count(engine):
with engine.connect() as conn:
count_sql_text = "select count(id) from enterprise "
count = conn.execute(text(count_sql_text)).fetchone()
if count:
return count[0]
return 0
# 获取所有企业模型数量
def get_enterprise_model_count():
return 0
# 获取所有企业设备数量
def get_enterprise_device_count():
return 0

@ -0,0 +1,32 @@
import time
import datetime
import logging
from website import consts, settings
def get_license_status(license):
status = consts.system_status_not_active
# if not license:
# pass
if license:
now = datetime.datetime.now()
timestamp_now = int(now.timestamp())
expireat = int(license["expireat"])
expireat_datetime = datetime.datetime.fromtimestamp(expireat)
expireat_next30days = expireat_datetime + datetime.timedelta(days=30)
expireat_next30days_timestamp = int(expireat_next30days.timestamp())
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
# logging.info(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(expireat)))
# logging.info(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(expireat_next30days_timestamp)))
# logging.info(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp_now)))
if timestamp_now >= expireat_next30days_timestamp:
status = consts.system_status_expire_atall
elif timestamp_now >= expireat and timestamp_now < expireat_next30days_timestamp:
status = consts.system_status_expire_but_ok
elif timestamp_now < expireat:
status = consts.system_status_activated
return status

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# sqlalchemy
SQLALCHEMY_ECHO = 0
# db mysql
mysql_app = {
"host": "127.0.0.1:3306",
"database": "",
"user": "root",
"password": "",
"time_zone": "+8:00"
}
redis_app = ("127.0.0.1", 6382, 0, "")
redis_app_cluster_notes = [
{"host": "127.0.0.1", "port": "6379"},
{"host": "127.0.0.1", "port": "6380"}
]
redis_cluster_pwd = ""
redis_cluster = 0
# redis_sentinel = 1
# redis_sentinel_nodes = [
# ('192.168.0.1', 40500),
# ('192.168.0.2', 40500)
# ]
# redis_sentinel_master = ""
# redis_sentinel_pwd = ""
# session expire duration
session_ttl = 3600
session_key_prefix = "system:user:jid:%s"
# convert get to post
app_get_to_post = False
# system host
host = "https://gap.sst.com"
# api root
api_root = "https://gap.sst.com"
api_prefix = "/api/v1"
api_prefix_v2 = "/api/v2"
# api freq limit
api_count_in_ten_second = 20
login_url = "/login"
cookie_domain = ".sst.com"
# cookie_domain = "localhost"
# cookie_secret, generate method : base64.b64encode(uuid.uuid4().bytes + uuid.uuid4().bytes)
cookie_secret = "rn43LMFOQJu1w8lJXlN93Oc7GOqo3kiTvqOq4IrTDjk="
cookie_path = "/"
cookie_key = "system:jsessionid"
secure_cookie_name = "sst"
pwd_error_limit = 10
enable_curl_async_http_client = True
max_clients = 300
remote_request_timeout = 10.0 # 异步调用远程api超时时间 单位:Second
remote_connect_timeout = 10.0 # 异步调用远程api连接超时时间单位Second
# 系统密码加密秘钥
pwd_aes_key = "FquMBlcVoIkTAmL7"
enterprise_aes_key = "FquMBlcVoIkTAmL7"
rsa_public_file = "/data/gap/public"
rsa_license_file = "/data/gap/license"
# hashlib.sha256(base64.b64encode(uuid.uuid4().bytes + uuid.uuid4().bytes)).hexdigest()
# 线上配置信息使用settings_local.py
try:
from settings_local import *
except:
pass

@ -0,0 +1,17 @@
import os
import sys
import importlib
handlers = []
# handlers_v2 = []
page_handlers = []
handlers_path = os.path.join(os.getcwd(), "handlers")
sys.path.append(handlers_path)
handlers_dir = os.listdir(handlers_path)
for item in handlers_dir:
if os.path.isdir(os.path.join(handlers_path, item)):
hu = importlib.import_module("{}.url".format(item))
handlers.extend(hu.handlers)
page_handlers.extend(hu.page_handlers)

@ -0,0 +1,145 @@
#!/usr/bin/env python
# -*- coding=utf-8 -*-
"""
AES加密解密工具类
@author jzx
@date 2018/10/24
此工具类加密解密结果与 http://tool.chacuo.net/cryptaes 结果一致
数据块128位
key 为16位
iv 为16位且与key相等
字符集utf-8
输出为base64
AES加密模式 为cbc
填充 pkcs7padding
"""
import base64
from Crypto.Cipher import AES
import random
def pkcs7padding(text):
"""
明文使用PKCS7填充
最终调用AES加密方法时传入的是一个byte数组要求是16的整数倍因此需要对明文进行处理
:param text: 待加密内容(明文)
:return:
"""
bs = AES.block_size # 16
length = len(text)
bytes_length = len(bytes(text, encoding='utf-8'))
# tipsutf-8编码时英文占1个byte而中文占3个byte
padding_size = length if (bytes_length == length) else bytes_length
padding = bs - padding_size % bs
# tipschr(padding)看与其它语言的约定,有的会使用'\0'
padding_text = chr(padding) * padding
return text + padding_text
def pkcs7unpadding(text):
"""
处理使用PKCS7填充过的数据
:param text: 解密后的字符串
:return:
"""
length = len(text)
unpadding = ord(text[length - 1])
return text[0:length - unpadding]
def encrypt(key, content):
"""
AES加密
key,iv使用同一个
模式cbc
填充pkcs7
:param key: 密钥
:param content: 加密内容
:return:
"""
key_bytes = bytes(key, encoding='utf-8')
iv = key_bytes
cipher = AES.new(key_bytes, AES.MODE_CBC, iv)
# 处理明文
content_padding = pkcs7padding(content)
# 加密
encrypt_bytes = cipher.encrypt(bytes(content_padding, encoding='utf-8'))
# 重新编码
result = str(base64.b64encode(encrypt_bytes), encoding='utf-8')
return result
def decrypt(key, content):
"""
AES解密
key,iv使用同一个
模式cbc
去填充pkcs7
:param key:
:param content:
:return:
"""
key_bytes = bytes(key, encoding='utf-8')
iv = key_bytes
cipher = AES.new(key_bytes, AES.MODE_CBC, iv)
# base64解码
encrypt_bytes = base64.b64decode(content)
# 解密
decrypt_bytes = cipher.decrypt(encrypt_bytes)
# 重新编码
result = str(decrypt_bytes, encoding='utf-8')
# 去除填充内容
result = pkcs7unpadding(result)
return result
def get_key(n):
"""
获取密钥 n 密钥长度
:return:
"""
c_length = int(n)
source = 'ABCDEFGHJKMNPQRSTWXYZabcdefhijkmnprstwxyz2345678'
length = len(source) - 1
result = ''
for i in range(c_length):
result += source[random.randint(0, length)]
return result
if __name__ == "__main__":
# Test
# 非16字节的情况
aes_key = get_key(16)
print('aes_key:' + aes_key)
# 对英文加密
source_en = 'Hello!+world'
encrypt_en = encrypt(aes_key, source_en)
print(encrypt_en)
# 解密
decrypt_en = decrypt(aes_key, encrypt_en)
print(decrypt_en)
print(source_en == decrypt_en)
# 中英文混合加密
source_mixed = 'Hello, 韩- 梅 -梅'
encrypt_mixed = encrypt(aes_key, source_mixed)
print(encrypt_mixed)
decrypt_mixed = decrypt(aes_key, encrypt_mixed)
print(decrypt_mixed)
print(decrypt_mixed == source_mixed)
# 刚好16字节的情况
en_16 = 'abcdefgj10124567'
encrypt_en = encrypt(aes_key, en_16)
print(encrypt_en)
# 解密
decrypt_en = decrypt(aes_key, encrypt_en)
print(decrypt_en)
print(en_16 == decrypt_en)
mix_16 = 'abx张三丰12sa'
encrypt_mixed = encrypt(aes_key, mix_16)
print(encrypt_mixed)
decrypt_mixed = decrypt(aes_key, encrypt_mixed)
print(decrypt_mixed)
print(decrypt_mixed == mix_16)

@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
import logging
import os
import random
import settings
from PIL import Image, ImageDraw, ImageFont, ImageFilter
_letter_cases = "abcdefghjkmnpqrstuvwxy" # 小写字母,取出干扰
_upper_cases = _letter_cases.upper() # 大写字母
_numbers = ''.join(map(str, range(3, 10))) # 数字
init_chars = ''.join((_letter_cases, _upper_cases, _numbers))
def create_validate_code(
size=(140, 30),
chars=init_chars,
img_type="GIF",
mode="RGB",
bg_color=(255, 255, 255),
fg_color=(0, 0, 255),
font_size=18,
font_type="ttf/%s.ttf" % random.randint(1, 6),
length=4,
draw_lines=True,
n_line=(3, 8),
draw_points=True,
point_chance=20):
'''
@todo: 生成验证码图片
@param size: 图片的大小格式默认为(120, 30)
@param chars: 允许的字符集合格式字符串
@param img_type: 图片保存的格式默认为GIF可选的为GIFJPEGTIFFPNG
@param mode: 图片模式默认为RGB
@param bg_color: 背景颜色默认为白色
@param fg_color: 前景色验证码字符颜色默认为蓝色#0000FF
@param font_size: 验证码字体大小
@param font_type: 验证码字体默认为 ae_AlArabiya.ttf
@param length: 验证码字符个数
@param draw_lines: 是否划干扰线
@param n_lines: 干扰线的条数范围格式元组默认为(1, 2)只有draw_lines为True时有效
@param draw_points: 是否画干扰点
@param point_chance: 干扰点出现的概率大小范围[0, 100]
@return: [0]: PIL Image实例
@return: [1]: 验证码图片中的字符串
'''
width, height = size # 宽, 高
img = Image.new(mode, size, bg_color) # 创建图形
draw = ImageDraw.Draw(img) # 创建画笔
def get_chars():
'''生成给定长度的字符串,返回列表格式'''
return random.sample(chars, length)
def rndColor():
return (random.randint(64, 255), random.randint(64, 255), random.randint(64, 255))
# 随机颜色2:
def rndColor2():
return (random.randint(32, 127), random.randint(32, 127), random.randint(32, 127))
def create_lines():
'''绘制干扰线'''
line_num = random.randint(*n_line) # 干扰线条数
for i in range(line_num):
# 起始点
begin = (random.randint(0, size[0]), random.randint(0, size[1]))
# 结束点
end = (random.randint(0, size[0]), random.randint(0, size[1]))
draw.line([begin, end], fill=(0, 0, 0))
def create_points():
'''绘制干扰点'''
chance = min(100, max(0, int(point_chance))) # 大小限制在[0, 100]
# 填充每个像素:
# for x in range(width):
# for y in range(height):
# draw.point((x, y), fill=rndColor())
for w in range(width):
for h in range(height):
tmp = random.randint(0, 100)
if tmp > 100 - chance:
# draw.point((w, h), fill=(0, 0, 0))
draw.point((w, h), fill=rndColor())
def create_strs():
'''绘制验证码字符'''
c_chars = get_chars()
strs = ' %s ' % ' '.join(c_chars) # 每个字符前后以空格隔开
logging.info(os.getcwd())
# font_type = settings.captcha_font_path + "%s.ttf" % random.randint(1, 4)
font_type = os.getcwd() + "/util/ttf/%s.ttf" % random.randint(1, 4)
logging.info("font type is : %s " % font_type)
font = ImageFont.truetype(font_type, font_size)
font_width, font_height = font.getsize(strs)
draw.text(((width - font_width) / 3, (height - font_height) / 3),
strs, font=font, fill=fg_color)
return ''.join(c_chars)
if draw_lines:
create_lines()
if draw_points:
create_points()
strs = create_strs()
# 图形扭曲参数
params = [1 - float(random.randint(1, 2)) / 100,
0,
0,
0,
1 - float(random.randint(1, 10)) / 100,
float(random.randint(1, 2)) / 500,
0.001,
float(random.randint(1, 2)) / 500
]
img = img.transform(size, Image.PERSPECTIVE, params) # 创建扭曲
#img = img.filter(ImageFilter.BLUR) # 滤镜,边界加强(阈值更大)
#img = img.filter(ImageFilter.CONTOUR)
return img, strs
"""
#字体的位置,不同版本的系统会有不同
font_path = '/Library/Fonts/Arial.ttf'
#生成几位数的验证码
number = 4
#生成验证码图片的高度和宽度
size = (100,30)
#背景颜色,默认为白色
bgcolor = (255,255,255)
#字体颜色,默认为蓝色
fontcolor = (0,0,255)
#干扰线颜色。默认为红色
linecolor = (255,0,0)
#是否要加入干扰线
draw_line = True
#加入干扰线条数的上下限
line_number = (1,5)
#用来随机生成一个字符串
def gene_text():
source = list(string.letters)
for index in range(0,10):
source.append(str(index))
return ''.join(random.sample(source,number))#number是生成验证码的位数
#用来绘制干扰线
def gene_line(draw,width,height):
begin = (random.randint(0, width), random.randint(0, height))
end = (random.randint(0, width), random.randint(0, height))
draw.line([begin, end], fill = linecolor)
#生成验证码
def gene_code():
width,height = size #宽和高
image = Image.new('RGBA',(width,height),bgcolor) #创建图片
font = ImageFont.truetype(font_path,25) #验证码的字体
draw = ImageDraw.Draw(image) #创建画笔
text = gene_text() #生成字符串
font_width, font_height = font.getsize(text)
draw.text(((width - font_width) / number, (height - font_height) / number),text,
font= font,fill=fontcolor) #填充字符串
if draw_line:
gene_line(draw,width,height)
# image = image.transform((width+30,height+10), Image.AFFINE, (1,-0.3,0,-0.1,1,0),Image.BILINEAR) #创建扭曲
image = image.transform((width+20,height+10), Image.AFFINE, (1,-0.3,0,-0.1,1,0),Image.BILINEAR) #创建扭曲
image = image.filter(ImageFilter.EDGE_ENHANCE_MORE) #滤镜,边界加强
image.save('idencode.png') #保存验证码图片
"""

@ -0,0 +1,71 @@
from crontab import CronTab
from datetime import datetime
import json
import getpass
user = getpass.getuser()
class CronOpt:
#创建定时任务
def __init__(self,user=user):
self.initialize(user)
#初始化
def initialize(self,user=user):
self.cron = CronTab(user)
#查询列表
def select(self,reInit = False):
if reInit != False:
# 强制重新读取列表
self.initialize()
cronArray = []
for job in self.cron:
# print job.command
schedule = job.schedule(date_from=datetime.now())
cronDict = {
"task" : (job.command).replace(r'>/dev/null 2>&1', ''),
"next" : str(schedule.get_next()),
"prev" : str(schedule.get_prev()),
"comment": job.comment
}
cronArray.append(cronDict)
return cronArray
#新增定时任务
def add(self, command, timeStr, commentName):
# 创建任务
job = self.cron.new(command=command)
# 设置任务执行周期
job.setall(timeStr)
# 备注也是命令id
job.set_comment(commentName)
# 写入到定时任务
self.cron.write()
# 返回更新后的列表
return self.select(True)
#删除定时任务
def delCron(self,commentName):
for job in self.cron :
if job.comment == commentName:
self.cron.remove(job)
self.cron.write()
return self.select(True)
#更新定时任务
def update(self, command, timeStr, commentName):
# 先删除任务
self.delCron(commentName)
# 再创建任务,以达到更新的结果
return self.add(command, timeStr, commentName)
if __name__ == '__main__':
c = CronOpt()
#print(c.select())
# 新增定时任务
#print(c.add("ls /home","*/10 * * * *","测试"))
# 删除定时任务
print(c.delCron("测试"))

@ -0,0 +1,43 @@
#!/usr/bin/env python
import datetime
# 获取过去几天的日期
def get_last_days(delta):
days = []
now = datetime.datetime.now()
for i in range(1, delta):
last_date = now + datetime.timedelta(days=-i)
last_date_str = last_date.strftime('%Y-%m-%d') # 格式化输出
days.append(last_date_str)
days.sort()
return days
class DateNext(object):
def get_next_week_date(self, weekday):
# 获取下周几的日期
# weekday取值如下
# (MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY, SUNDAY) = range(7)
today = datetime.date.today()
onday = datetime.timedelta(days=1)
# m1 = calendar.MONDAY
while today.weekday() != weekday:
today += onday
# logging.info(datetime.date(today.year, today.month, today.day))
nextWeekDay = today.strftime('%Y-%m-%d')
return nextWeekDay
def get_next_month_date(self, day_of_the_month):
# 获取下个月几号的日期
today_date = datetime.datetime.today()
today = today_date.day
if today == day_of_the_month:
return today_date.strftime('%Y-%m-%d')
if today < day_of_the_month:
return datetime.date(today_date.year, today_date.month, day_of_the_month).strftime('%Y-%m-%d')
if today > day_of_the_month:
if today_date.month == 12:
return datetime.date(today_date.year+1, 1, day_of_the_month).strftime('%Y-%m-%d')
return datetime.date(today_date.year, today_date.month+1, day_of_the_month).strftime('%Y-%m-%d')

@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
from M2Crypto import RSA
from M2Crypto import BIO
from binascii import a2b_hex, b2a_hex
def load_pub_key_string(string):
bio = BIO.MemoryBuffer(string)
return RSA.load_pub_key_bio(bio)
def block_data(texts, block_size):
for i in range(0, len(texts), block_size):
yield texts[i:i + block_size]
def decrypt(publick_key, texts):
plaintext = b""
block_size = 256
for text in block_data(a2b_hex(texts), block_size):
current_text = publick_key.public_decrypt(text, RSA.pkcs1_padding)
plaintext += current_text
return plaintext

@ -0,0 +1,137 @@
"""Concise UUID generation."""
import math
import secrets
import uuid as _uu
from typing import List
from typing import Optional
def int_to_string(
number: int, alphabet: List[str], padding: Optional[int] = None
) -> str:
"""
Convert a number to a string, using the given alphabet.
The output has the most significant digit first.
"""
output = ""
alpha_len = len(alphabet)
while number:
number, digit = divmod(number, alpha_len)
output += alphabet[digit]
if padding:
remainder = max(padding - len(output), 0)
output = output + alphabet[0] * remainder
return output[::-1]
def string_to_int(string: str, alphabet: List[str]) -> int:
"""
Convert a string to a number, using the given alphabet.
The input is assumed to have the most significant digit first.
"""
number = 0
alpha_len = len(alphabet)
for char in string:
number = number * alpha_len + alphabet.index(char)
return number
class ShortUUID(object):
def __init__(self, alphabet: Optional[str] = None) -> None:
if alphabet is None:
alphabet = "23456789ABCDEFGHJKLMNPQRSTUVWXYZ" "abcdefghijkmnopqrstuvwxyz"
self.set_alphabet(alphabet)
@property
def _length(self) -> int:
"""Return the necessary length to fit the entire UUID given the current alphabet."""
return int(math.ceil(math.log(2**128, self._alpha_len)))
def encode(self, uuid: _uu.UUID, pad_length: Optional[int] = None) -> str:
"""
Encode a UUID into a string (LSB first) according to the alphabet.
If leftmost (MSB) bits are 0, the string might be shorter.
"""
if not isinstance(uuid, _uu.UUID):
raise ValueError("Input `uuid` must be a UUID object.")
if pad_length is None:
pad_length = self._length
return int_to_string(uuid.int, self._alphabet, padding=pad_length)
def decode(self, string: str, legacy: bool = False) -> _uu.UUID:
"""
Decode a string according to the current alphabet into a UUID.
Raises ValueError when encountering illegal characters or a too-long string.
If string too short, fills leftmost (MSB) bits with 0.
Pass `legacy=True` if your UUID was encoded with a ShortUUID version prior to
1.0.0.
"""
if not isinstance(string, str):
raise ValueError("Input `string` must be a str.")
if legacy:
string = string[::-1]
return _uu.UUID(int=string_to_int(string, self._alphabet))
def uuid(self, name: Optional[str] = None, pad_length: Optional[int] = None) -> str:
"""
Generate and return a UUID.
If the name parameter is provided, set the namespace to the provided
name and generate a UUID.
"""
if pad_length is None:
pad_length = self._length
# If no name is given, generate a random UUID.
if name is None:
u = _uu.uuid4()
elif name.lower().startswith(("http://", "https://")):
u = _uu.uuid5(_uu.NAMESPACE_URL, name)
else:
u = _uu.uuid5(_uu.NAMESPACE_DNS, name)
return self.encode(u, pad_length)
def random(self, length: Optional[int] = None) -> str:
"""Generate and return a cryptographically secure short random string of `length`."""
if length is None:
length = self._length
return "".join(secrets.choice(self._alphabet) for _ in range(length))
def get_alphabet(self) -> str:
"""Return the current alphabet used for new UUIDs."""
return "".join(self._alphabet)
def set_alphabet(self, alphabet: str) -> None:
"""Set the alphabet to be used for new UUIDs."""
# Turn the alphabet into a set and sort it to prevent duplicates
# and ensure reproducibility.
new_alphabet = list(sorted(set(alphabet)))
if len(new_alphabet) > 1:
self._alphabet = new_alphabet
self._alpha_len = len(self._alphabet)
else:
raise ValueError("Alphabet with more than " "one unique symbols required.")
def encoded_length(self, num_bytes: int = 16) -> int:
"""Return the string length of the shortened UUID."""
factor = math.log(256) / math.log(self._alpha_len)
return int(math.ceil(factor * num_bytes))
# For backwards compatibility
_global_instance = ShortUUID()
encode = _global_instance.encode
decode = _global_instance.decode
uuid = _global_instance.uuid
random = _global_instance.random
get_alphabet = _global_instance.get_alphabet
set_alphabet = _global_instance.set_alphabet

@ -0,0 +1,61 @@
import subprocess
import socket
import hashlib
import uuid
import logging
def get_cpu_id():
p = subprocess.Popen(["dmidecode -t 4 | grep ID"],
shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
data = p.stdout
lines = []
while True:
line = str(data.readline(), encoding="utf-8")
if line == '\n':
break
if line:
d = dict([line.strip().split(': ')])
lines.append(d)
else:
break
return lines
def get_board_serialnumber():
p = subprocess.Popen(["dmidecode -t 2 | grep Serial"],
shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
data = p.stdout
lines = []
while True:
line = str(data.readline(), encoding="utf-8")
if line == '\n':
break
if line:
d = dict([line.strip().split(': ')])
lines.append(d)
else:
break
return lines
def get_identify_code():
mac = uuid.UUID(int=uuid.getnode()).hex[-12:]
mac_addr = ":".join([mac[e:e + 2] for e in range(0, 11, 2)])
host_name = socket.getfqdn(socket.gethostname())
cpu_ids = get_cpu_id()
serialnumbers = get_board_serialnumber()
s = ""
if mac_addr:
s += mac_addr
if host_name:
s += host_name
if cpu_ids:
for cpu in cpu_ids:
s += cpu["ID"]
if serialnumbers:
for number in serialnumbers:
s += number["Serial Number"]
logging.info(s)
code = hashlib.new('md5', s.encode("utf8")).hexdigest()
return code

@ -0,0 +1,103 @@
# -*- coding:utf8 -*-
import io
import pandas as pd
import tornado.gen as gen
import xlwt
class Excel(object):
@gen.coroutine
def generate_excel(self, head, rows):
"""
head is a dict, eg: [(0, u"编号"), (1, u"地址")]
rows is detail list, eg: [[0, "XXX"], ...]
"""
workbook = xlwt.Workbook(encoding='utf-8')
worksheet = workbook.add_sheet("sheet1")
row_num = 0
# col_num = 0
for item in head:
worksheet.write(row_num, head.index(item), item[1])
for row in rows:
row_num += 1
col_num = 0
for col in row:
worksheet.write(row_num, col_num, col)
col_num += 1
sio = io.BytesIO()
workbook.save(sio)
raise gen.Return(sio)
@gen.coroutine
# def generate_excel_pd(self, index, data, columns):
def generate_excel_pd(self, pd_data):
"""
pandas 构建图表
"""
sio = io.StringIO()
writer = pd.ExcelWriter(sio, engine='xlsxwriter')
for data in pd_data:
df = pd.DataFrame(data=data["data"], index=data["index"], columns=data["columns"])
sheet_name = data["sheet_name"]
df.to_excel(writer, sheet_name=sheet_name)
workbook = writer.book
worksheet = writer.sheets[sheet_name]
chart = workbook.add_chart({'type': 'line'})
max_row = len(df) + 1
for i in range(len(data['columns'])):
col = i + 1
chart.add_series({
# 'name': ['Sheet1', 0, col],
'name': [sheet_name, 0, col],
'categories': [sheet_name, 1, 0, max_row, 0],
'values': [sheet_name, 1, col, max_row, col],
'line': {'width': 1.00},
})
chart.set_x_axis({'name': 'Date', 'date_axis': True})
chart.set_y_axis({'name': 'Statistics', 'major_gridlines': {'visible': False}})
chart.set_legend({'position': 'top'})
worksheet.insert_chart('H2', chart)
# df = pd.DataFrame(data=data, index=index, columns=columns)
"""
# ================ anothor method =================
# workbook.save(sio)
io = StringIO.StringIO()
# Use a temp filename to keep pandas happy.
writer = pd.ExcelWriter('temp.xls', engine='xlsxwriter')
# Set the filename/file handle in the xlsxwriter.workbook object.
writer.book.filename = io
#
# Write the data frame to the StringIO object.
df.to_excel(writer, sheet_name='Sheet1')
writer.save()
xlsx_data = io.getvalue()
# ================ anothor method =================
"""
# sheet_name = 'Sheet1'
writer.save()
raise gen.Return(sio)
Loading…
Cancel
Save