You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

583 lines
20 KiB
Python

# -*- coding: utf-8 -*-
import ast
import functools
import hashlib
import json
import logging
import re
import time
import traceback
import urllib
from typing import Any
# import urlparse
from urllib.parse import parse_qs, unquote
# from urllib import unquote
import tornado
from sqlalchemy import text
from tornado import escape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.options import options
# from tornado.web import RequestHandler as BaseRequestHandler, HTTPError, asynchronous
from tornado.web import RequestHandler as BaseRequestHandler, HTTPError
from torndb import Row
from website import errors
from website import settings
from website.service.license import get_license_status
# Pick the process-wide AsyncHTTPClient implementation from settings.
if not settings.enable_curl_async_http_client:
    # Default (simple_httpclient) implementation, capped at
    # settings.max_clients concurrent requests.
    AsyncHTTPClient.configure(None, max_clients=settings.max_clients)
else:
    # libcurl-backed implementation.
    AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")

# Matches a path with at least one character before a trailing "/"
# (so the bare root "/" is not matched).
REMOVE_SLASH_RE = re.compile(r".+/$")
def _callback_wrapper(callback):
"""A wrapper to handling basic callback error"""
def _wrapper(response):
if response.error:
logging.error("call remote api err: %s" % response)
if isinstance(response.error, tornado.httpclient.HTTPError):
raise errors.HTTPAPIError(response.error.code, "网络连接失败")
else:
raise errors.HTTPAPIError(errors.ERROR_UNKNOWN_ERROR, "未知错误")
else:
callback(response)
return _wrapper
class BaseHandler(BaseRequestHandler):
    """Common base class for the site's Tornado request handlers.

    Provides:
    - permissive CORS headers;
    - session lookup from a signed ``token`` header, backed by Redis (``r_app``);
    - merging of JSON request bodies into ``request.arguments``;
    - argument helpers that return UTF-8 ``bytes``;
    - a helper that forwards requests to a remote HTTP API.
    """

    # Private sentinel meaning "the caller supplied no default" in the argument
    # helpers. With it, a missing argument raises MissingArgumentError (HTTP 400)
    # instead of returning an opaque placeholder object.
    _ARG_MISSING = object()

    def set_default_headers(self):
        """Attach permissive CORS headers to every response."""
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header(
            "Access-Control-Allow-Headers",
            "DNT,token,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,"
            "Content-Type",
        )
        self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')

    def get(self, *args, **kwargs):
        """Delegate GET to ``post`` when ``settings.app_get_to_post`` is enabled.

        Returns whatever ``post`` returns. Tornado awaits a returned coroutine,
        so an ``async def post`` now actually runs. Previously its coroutine
        was discarded.

        Raises:
            HTTPError(405): when GET delegation is disabled.
        """
        if settings.app_get_to_post:
            return self.post(*args, **kwargs)
        raise HTTPError(405)

    def render(self, template_name, **kwargs):
        """Render ``template_name`` with user context and caching disabled.

        When a user is logged in, ``username``, ``current_userid`` and ``role``
        are taken from ``self.current_user`` unless the caller already supplied
        them.
        """
        if self.current_user:
            if 'username' not in kwargs:
                kwargs["username"] = self.current_user.name
            if 'current_userid' not in kwargs:
                kwargs["current_userid"] = self.current_user.id
            if 'role' not in kwargs:
                kwargs["role"] = self.current_user.role
        self.set_header("Cache-control", "no-cache")
        return super().render(template_name, **kwargs)

    @property
    def app_mysql(self):
        """Shared database handle from the application (used via ``connect()``)."""
        return self.application.app_mysql

    @property
    def r_app(self):
        """Shared Redis client from the application."""
        return self.application.r_app

    @property
    def kafka_producer(self):
        """Shared Kafka producer from the application."""
        return self.application.kafka_producer

    @property
    def es(self):
        """Shared Elasticsearch client from the application."""
        return self.application.es

    def _call_api(self, url, headers, body, callback, method, callback_wrapper=_callback_wrapper):
        """Start an HTTP request to ``url`` and pass the response to ``callback``.

        Tornado 6 removed the ``callback`` argument of ``AsyncHTTPClient.fetch``.
        The old positional call silently set ``raise_error`` instead, so the
        callback never ran. The request is now driven by a scheduled task.
        Connection errors and timeouts (which Tornado 6 raises even with
        ``raise_error=False``) are turned into a 599 ``HTTPResponse`` carrying
        the error. This way ``callback`` always receives a response object, as
        with the old callback API.

        Args:
            url, headers, body, method: forwarded to ``HTTPRequest``.
            callback: called with the ``HTTPResponse``.
            callback_wrapper: optional decorator applied to ``callback``. The
                default raises ``errors.HTTPAPIError`` for failed responses.

        Returns:
            An ``asyncio`` task that completes after ``callback`` has run. Await
            it so that exceptions raised by ``callback`` reach the handler.
        """
        import asyncio  # local: only needed to schedule the fetch task

        if callback_wrapper:
            callback = callback_wrapper(callback)
        request = HTTPRequest(url=url,
                              method=method,
                              body=body,
                              headers=headers,
                              allow_nonstandard_methods=True,
                              connect_timeout=settings.remote_connect_timeout,
                              request_timeout=settings.remote_request_timeout,
                              follow_redirects=False)
        start = time.time()

        async def _fetch_and_dispatch():
            try:
                # Non-2xx answers come back as responses with ``.error`` set.
                response = await AsyncHTTPClient().fetch(request, raise_error=False)
            except Exception as exc:
                logging.error("request to %s failed after %.1f ms: %s",
                              url, (time.time() - start) * 1000, exc)
                response = tornado.httpclient.HTTPResponse(request, 599, error=exc)
            callback(response)

        return asyncio.ensure_future(_fetch_and_dispatch())

    async def call_api(self, url, headers=None, body=None, callback=None, method="POST"):
        """Forward a request to ``url`` and wait until ``callback`` has handled it.

        Args:
            url: remote endpoint.
            headers: request headers. Defaults to a copy of the incoming request's
                headers. The value is always copied, so neither
                ``self.request.headers`` nor the caller's mapping is modified.
            body: request body. Defaults to the incoming request's body. When a
                body is given, Content-Type is forced to form-urlencoded.
            callback: receives the response. Defaults to
                ``self.call_api_callback`` (presumably defined by subclasses;
                it is not defined here).
            method: HTTP method, ``"POST"`` by default.
        """
        if callback is None:
            callback = self.call_api_callback
        headers = tornado.httputil.HTTPHeaders(self.request.headers if headers is None else headers)
        if body is None:
            body = self.request.body
        else:
            # make sure it is a post request
            headers["Content-Type"] = "application/x-www-form-urlencoded"
        await self._call_api(url, headers, body, callback, method)

    def get_current_user(self, token_body=None):
        """Resolve the logged-in user from a signed session token.

        The token comes from ``token_body`` when it is truthy. Otherwise it comes
        from the ``token`` request header. It is URL-unquoted and its signature
        is verified with ``settings.cookie_secret``. The session is then loaded
        from Redis and parsed with ``ast.literal_eval``. On success the
        session's TTL is refreshed, except for requests whose path contains
        ``/user/info``.

        Returns:
            A ``Row`` built from the stored session, or ``None`` when the token
            is missing, invalid, or has no live session.
        """
        token = token_body or self.request.headers.get("token")
        if not token:
            return None
        jid = tornado.web.decode_signed_value(
            settings.cookie_secret,
            settings.secure_cookie_name,
            unquote(token),
        )
        if not jid:
            # Bad or expired signature. Never fall back to the "empty jid"
            # session key.
            return None
        key = settings.session_key_prefix % str(jid, encoding="utf-8")
        user = self.r_app.get(key)
        if not user:
            return None
        if "/user/info" not in self.request.path:
            self.r_app.expire(key, settings.session_ttl)
        return Row(ast.literal_eval(self.tostr(user)))

    def md5compare(self, name):
        """Validate an ``a|b|digest`` string.

        Returns True when ``digest`` equals the upper-case hex MD5 of ``a + b``
        (UTF-8 encoded). The input is URL-unquoted first. Fields after the
        third are ignored. Input with fewer than three fields returns False
        instead of raising IndexError. Hashing now encodes to bytes:
        ``hashlib.md5`` rejects ``str`` on Python 3.
        """
        parts = unquote(name).split("|")
        if len(parts) < 3:
            return False
        first, second, expected = parts[:3]
        digest = hashlib.md5((first + second).encode("utf-8")).hexdigest().upper()
        return digest == expected

    def tostr(self, src):
        """Decode ``bytes`` as UTF-8. Return any other value unchanged."""
        return str(src, encoding='utf8') if isinstance(src, bytes) else src

    def prepare(self):
        """Normalize trailing slashes and merge a JSON body into the arguments."""
        self.remove_slash()
        self.prepare_context()
        self.set_default_jsonbody()
        # self.traffic_threshold()  # rate limiting is currently disabled

    def set_default_jsonbody(self):
        """Merge a JSON-object request body into ``self.request.arguments``.

        Applies when the media type is ``application/json``. Parameters such
        as ``; charset=UTF-8`` are now accepted. ``None`` values are skipped.
        Lists extend the argument list, dicts are stored as-is, and everything
        else is appended as ``str(value)`` encoded to UTF-8. A JSON body that is
        not an object is left alone.
        """
        content_type = self.request.headers.get('Content-Type', '')
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type != 'application/json' or not self.request.body:
            return
        json_body = tornado.escape.json_decode(self.request.body)
        if not isinstance(json_body, dict):
            return
        for key, value in json_body.items():
            if value is None:
                continue
            if isinstance(value, list):
                self.request.arguments.setdefault(key, []).extend(value)
            elif isinstance(value, dict):
                self.request.arguments[key] = value
            else:
                self.request.arguments.setdefault(key, []).append(bytes(str(value), 'utf-8'))

    def traffic_threshold(self):
        """Throttle callers to ``settings.api_count_in_ten_second`` requests per 10 s.

        Callers are identified by user id, or by remote IP when anonymous. Once a
        caller exceeds the limit, the current count is copied into the next
        10-second window as well, so the caller stays blocked there.

        Raises:
            errors.HTTPAPIError(ERROR_METHOD_NOT_ALLOWED): when over the limit.
        """
        user_id = self.current_user.id if self.current_user else self.request.remote_ip
        # Integer division: "/" produced float keys, i.e. 1-second buckets.
        window = int(time.time()) // 10
        freq_key = "API:FREQ:%s:%s" % (user_id, window)
        send_count = self.r_app.incr(freq_key)
        if send_count > settings.api_count_in_ten_second:
            # Keyword args: redis-py >= 3 takes setex(name, time, value), so the
            # old positional (key, count, 10) call swapped TTL and value.
            self.r_app.setex("API:FREQ:%s:%s" % (user_id, window + 1), time=10, value=send_count)
            raise errors.HTTPAPIError(
                errors.ERROR_METHOD_NOT_ALLOWED, "请勿频繁操作")
        if send_count == 1:
            self.r_app.expire(freq_key, 10)

    def prepare_context(self):
        """Hook for per-request context setup. Currently a no-op."""
        pass

    def remove_slash(self):
        """Redirect GET requests for ``/path/`` to ``/path``, keeping the query."""
        if self.request.method != "GET" or not REMOVE_SLASH_RE.match(self.request.path):
            return
        uri = self.request.path.rstrip("/")
        if self.request.query:
            uri += "?" + self.request.query
        self.redirect(uri)

    def get_json_argument(self, name, default=_ARG_MISSING):
        """Return ``name`` from the JSON request body (``str`` results as UTF-8 bytes).

        A ``str`` default is also returned UTF-8 encoded.

        Raises:
            tornado.web.MissingArgumentError: when ``name`` is absent and no
                default was supplied.
        """
        json_body = tornado.escape.json_decode(self.request.body)
        value = json_body.get(name, default)
        if value is self._ARG_MISSING:
            raise tornado.web.MissingArgumentError(name)
        return escape.utf8(value) if isinstance(value, str) else value

    def get_int_json_argument(self, name, default=0):
        """Return JSON body field ``name`` as ``int``, or ``default`` if not convertible."""
        try:
            return int(self.get_json_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_escaped_json_argument(self, name, default=None):
        """Return JSON body field ``name``.

        The value is XHTML-escaped only when ``default`` is not None.
        """
        if default is not None:
            return self.escape_string(self.get_json_argument(name, default))
        return self.get_json_argument(name, default)

    def get_argument(self, name, default=_ARG_MISSING, strip=True):
        """Tornado ``get_argument``, but ``str`` results are returned as UTF-8 bytes.

        Raises:
            tornado.web.MissingArgumentError: when ``name`` is absent and no
                default was supplied.
        """
        if default is self._ARG_MISSING:
            value = super().get_argument(name, strip=strip)
        else:
            value = super().get_argument(name, default, strip)
        return escape.utf8(value) if isinstance(value, str) else value

    def get_int_argument(self, name: Any, default: int = 0) -> int:
        """Return argument ``name`` as ``int``, or ``default`` if not convertible."""
        try:
            return int(self.get_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_float_argument(self, name, default=0.0):
        """Return argument ``name`` as ``float``, or ``default`` if not convertible."""
        try:
            return float(self.get_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_uint_arg(self, name, default=0):
        """Return the absolute integer value of ``name``, or ``default`` if not convertible."""
        try:
            return abs(int(self.get_argument(name, default)))
        except (TypeError, ValueError):
            return default

    def unescape_string(self, s):
        """Undo XHTML escaping."""
        return escape.xhtml_unescape(s)

    def escape_string(self, s):
        """XHTML-escape ``s``."""
        return escape.xhtml_escape(s)

    def get_escaped_argument(self, name, default=None):
        """Return argument ``name``.

        The value is XHTML-escaped only when ``default`` is not None.
        """
        if default is not None:
            return self.escape_string(self.get_argument(name, default))
        return self.get_argument(name, default)

    def get_page_url(self, page, form_id=None, tab=None):
        """Build a link to ``page`` for pagination.

        With ``form_id``, returns a ``javascript:goto_page(...)`` URL. Otherwise
        returns the current path with the query string rewritten: the first
        value of each parameter is kept, and ``page`` (and ``tab`` if given)
        are overridden.
        """
        if form_id:
            return "javascript:goto_page('%s',%s);" % (form_id.strip(), page)
        qdict = {k: (v[0] if v else '') for k, v in parse_qs(self.request.query).items()}
        qdict['page'] = page
        if tab:
            qdict['tab'] = tab
        # Python 3: urlencode lives in urllib.parse (urllib.urlencode is gone).
        return self.request.path + '?' + urllib.parse.urlencode(qdict)

    def find_all(self, target, substring):
        """Yield the start index of each non-overlapping ``substring`` in ``target``.

        An empty ``substring`` yields nothing. The previous loop never advanced
        in that case and spun forever.
        """
        if not substring:
            return
        step = len(substring)
        pos = target.find(substring)
        while pos != -1:
            yield pos
            pos = target.find(substring, pos + step)
class WebHandler(BaseHandler):
    """Handler for browser-facing endpoints.

    Errors are returned as JSON ``HTTPAPIError`` bodies with HTTP 200. A 401
    ``HTTPError`` redirects to ``/`` instead.
    """

    # A JSONP callback name is echoed into executable script. Only plain
    # (dotted) JavaScript identifiers are accepted, to prevent reflected XSS.
    _JSONP_CALLBACK_RE = re.compile(rb"[A-Za-z_$][\w$.]*")

    def finish(self, chunk=None, message=None):
        """Finish the response, wrapping it as JSONP when ``?callback=`` is valid.

        Args:
            chunk: body to send. A dict is JSON-encoded in the JSONP case.
            message: accepted for signature compatibility with
                ``APIHandler.finish``. Unused here.

        In JSONP mode the buffered output is replaced (not appended to) by
        ``callback(chunk)``, or by nothing when ``chunk`` is empty. The
        buffer is built from bytes because Tornado joins it with ``b""``.
        An invalid callback name falls back to a plain response.
        """
        callback = escape.utf8(self.get_argument("callback", None))
        if callback and self._JSONP_CALLBACK_RE.fullmatch(callback):
            self.set_header("Content-Type", "application/x-javascript")
            if isinstance(chunk, dict):
                chunk = escape.json_encode(chunk)
            self._write_buffer = [callback + b"(" + escape.utf8(chunk) + b")"] if chunk else []
            super().finish()
        else:
            self.set_header("Cache-control", "no-cache")
            super().finish(chunk)

    def write_error(self, status_code, **kwargs):
        """Render an error as a JSON ``HTTPAPIError`` with HTTP status 200.

        - ``HTTPAPIError`` is sent as-is.
        - A 401 ``HTTPError`` redirects to ``/`` with a temporary redirect.
          A permanent 301 would be cached by browsers and keep bouncing the
          URL after login.
        - Other ``HTTPError``s keep their status code.
        - ``send_error`` without an exception uses ``status_code``.
        - Anything else becomes ``ERROR_INTERNAL_SERVER_ERROR``.

        With an exception, the traceback is mailed on 500 outside debug mode,
        and attached to the error data in debug mode. If rendering fails, the
        failure is logged and Tornado's default error page is used.
        """
        try:
            exc_info = kwargs.pop('exc_info', None)
            e = exc_info[1] if exc_info else None
            if isinstance(e, errors.HTTPAPIError):
                pass
            elif isinstance(e, HTTPError):
                if e.status_code == 401:
                    self.redirect("/")
                    return
                e = errors.HTTPAPIError(e.status_code)
            elif e is None:
                e = errors.HTTPAPIError(status_code)
            else:
                e = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)
            if exc_info:
                exception = "".join(traceback.format_exception(*exc_info))
                if status_code == errors.ERROR_INTERNAL_SERVER_ERROR and not options.debug:
                    self.send_error_mail(exception)
                if options.debug:
                    e.data["exception"] = exception
            self.clear()
            # always return 200 OK for Web errors
            self.set_status(errors.HTTP_OK)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.finish(str(e))
        except Exception:
            logging.error(traceback.format_exc())
            return super().write_error(status_code, **kwargs)

    def send_error_mail(self, exception):
        """Hook for mailing a formatted traceback. No-op by default."""
        pass
class APIHandler(BaseHandler):
    """Handler for JSON API endpoints.

    Dict responses are wrapped as ``{"meta": {"code": HTTP_OK}, "data": ...}``.
    Errors are returned as JSON ``HTTPAPIError`` bodies with HTTP 200.
    """

    # A JSONP callback name is echoed into executable script. Only plain
    # (dotted) JavaScript identifiers are accepted, to prevent reflected XSS.
    _JSONP_CALLBACK_RE = re.compile(rb"[A-Za-z_$][\w$.]*")

    def finish(self, chunk=None, message=None):
        """Finish the response in the API envelope, optionally as JSONP.

        Args:
            chunk: ``None`` or a dict is wrapped in the meta/data envelope.
                Other values (for example an error string) are sent as-is.
            message: added as a top-level ``"message"`` key. Only applied when
                the chunk is a dict; previously a non-dict chunk raised
                TypeError here.

        In JSONP mode (valid ``?callback=``) the buffered output is replaced by
        ``callback(chunk)``. The buffer is built from bytes because Tornado joins
        it with ``b""``. An invalid callback name falls back to plain JSON.
        """
        if chunk is None:
            chunk = {}
        if isinstance(chunk, dict):
            chunk = {"meta": {"code": errors.HTTP_OK}, "data": chunk}
            if message:
                chunk["message"] = message
        callback = escape.utf8(self.get_argument("callback", None))
        if callback and self._JSONP_CALLBACK_RE.fullmatch(callback):
            self.set_header("Content-Type", "application/x-javascript")
            if isinstance(chunk, dict):
                chunk = escape.json_encode(chunk)
            self._write_buffer = [callback + b"(" + escape.utf8(chunk) + b")"] if chunk else []
            super().finish()
        else:
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            super().finish(chunk)

    def write_error(self, status_code, **kwargs):
        """Render an error as a JSON ``HTTPAPIError`` with HTTP status 200.

        - ``HTTPAPIError`` is sent as-is.
        - Other ``HTTPError``s keep their status code.
        - ``send_error`` without an exception uses ``status_code``. Previously
          that case raised KeyError and fell back to an HTML page.
        - Anything else becomes ``ERROR_INTERNAL_SERVER_ERROR``.

        With an exception, the traceback is mailed on 500 outside debug mode,
        and attached to the error data in debug mode. If rendering fails, the
        failure is logged and Tornado's default error page is used.
        """
        try:
            exc_info = kwargs.pop('exc_info', None)
            e = exc_info[1] if exc_info else None
            if isinstance(e, errors.HTTPAPIError):
                pass
            elif isinstance(e, HTTPError):
                e = errors.HTTPAPIError(e.status_code)
            elif e is None:
                e = errors.HTTPAPIError(status_code)
            else:
                e = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)
            if exc_info:
                exception = "".join(traceback.format_exception(*exc_info))
                if status_code == errors.ERROR_INTERNAL_SERVER_ERROR and not options.debug:
                    self.send_error_mail(exception)
                if options.debug:
                    e.data["exception"] = exception
            self.clear()
            # always return 200 OK for API errors
            self.set_status(errors.HTTP_OK)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.finish(str(e))
        except Exception:
            logging.error(traceback.format_exc())
            return super().write_error(status_code, **kwargs)

    def send_error_mail(self, exception):
        """Override to implement custom error mail"""
        pass
class ErrorHandler(BaseHandler):
    """Fallback handler that answers every request with 404 Not Found."""

    def prepare(self):
        # Run the shared preparation first, then refuse the request.
        super().prepare()
        raise HTTPError(errors.ERROR_NOT_FOUND)
class APIErrorHandler(APIHandler):
    """Fallback API handler that answers every request with a JSON 404 error."""

    def prepare(self):
        # Run the shared preparation first, then raise the API-style 404.
        super().prepare()
        raise errors.HTTPAPIError(errors.ERROR_NOT_FOUND)
def authenticated(method):
    """Decorate handler methods that require a logged-in user.

    When ``self.current_user`` is falsy, raises ``errors.HTTPAPIError`` with
    ``ERROR_UNAUTHORIZED`` ("登录失效") instead of redirecting to a login page.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.current_user:
            return method(self, *args, **kwargs)
        raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
    return wrapper
def authenticated_admin(method):
    """Decorate handler methods that only administrators may call.

    A user is an administrator when ``int(current_user.role) == 1001``.

    Raises:
        errors.HTTPAPIError(ERROR_UNAUTHORIZED): when no user is logged in.
            Previously ``None.role`` raised AttributeError, which surfaced
            as a 500.
        errors.HTTPAPIError(ERROR_FORBIDDEN): when the user is not an admin.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        user = self.current_user
        if not user:
            raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
        if int(user.role) != 1001:
            raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")
        return method(self, *args, **kwargs)
    return wrapper
def operation_log(primary_menu, sub_menu, ope_type, content, comment):
    """Decorator factory that writes an audit row to ``sys_log`` per call.

    Before the decorated handler method runs, one row is inserted and
    committed through ``self.app_mysql``. The row holds the current user's
    name, the client IP, and the given menu/type/content/comment values.
    Because the row is committed first, it is recorded even if the method
    then fails.

    Args:
        primary_menu: value for the ``primary_menu`` column.
        sub_menu: value for the ``sub_menu`` column.
        ope_type: value for the ``op_type`` column.
        content: value for the ``content`` column.
        comment: value for the ``comment`` column.

    Requires ``self.current_user`` to be set (place ``@authenticated`` above
    this decorator). Otherwise ``None.name`` raises AttributeError.
    """
    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # NOTE(review): X-Forwarded-For is client-supplied unless a trusted
            # proxy overwrites it, so the logged IP may be spoofable.
            client_ip = self.request.headers.get("X-Forwarded-For", self.request.remote_ip)
            with self.app_mysql.connect() as conn:
                conn.execute(
                    text(
                        "insert into sys_log(user, ip, primary_menu, sub_menu, op_type, content, comment) "
                        "values(:user, :ip, :primary_menu, :sub_menu, :op_type, :content, :comment)"
                    ),
                    {
                        "user": self.current_user.name,
                        "ip": client_ip,
                        "primary_menu": primary_menu,
                        "sub_menu": sub_menu,
                        "op_type": ope_type,
                        "content": content,
                        "comment": comment,
                    },
                )
                conn.commit()
            return func(self, *args, **kwargs)
        return wrapper
    return decorate
def permission(codes):
    """Decorator factory requiring the current user to hold every code in ``codes``.

    The user's permissions are queried on each call: all
    ``role_permission.permission`` values for roles linked to the user in
    ``user_role``. The query goes through ``self.app_mysql`` (as
    ``operation_log`` does), because the ``db_app`` property on
    ``BaseHandler`` is commented out and ``self.db_app`` would raise
    AttributeError.

    Raises:
        errors.HTTPAPIError(ERROR_UNAUTHORIZED): when no user is logged in.
        errors.HTTPAPIError(ERROR_FORBIDDEN): when any code is missing.
    """
    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not self.current_user:
                raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
            with self.app_mysql.connect() as conn:
                granted = set(conn.execute(
                    text(
                        "select rp.permission from role_permission rp, user_role ur "
                        "where rp.role=ur.role and ur.userid=:userid"
                    ),
                    {"userid": self.current_user.id},
                ).scalars().all())
            if any(code not in granted for code in codes):
                raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")
            return func(self, *args, **kwargs)
        return wrapper
    return decorate
def license_validate(codes):
    """Decorator factory allowing the call only when the license status is in ``codes``.

    License data (``syscode``, ``expireat``) is read from the Redis key
    ``system:license``. On a cache miss it is read from the first row of the
    ``license`` table through ``self.app_mysql`` and written back to Redis as
    JSON. ``self.db_app``, used previously, is no longer defined on
    ``BaseHandler``. The data (or ``{}`` when none exists) is passed to
    ``get_license_status``.

    Raises:
        errors.HTTPAPIError(ERROR_LICENSE_NOT_ACTIVE): when the status is not
            in ``codes``.
    """
    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            license_cache = self.r_app.get("system:license")
            license_info = {}
            if license_cache:
                license_info = json.loads(self.tostr(license_cache))
            else:
                with self.app_mysql.connect() as conn:
                    row = conn.execute(
                        text("select syscode, expireat from license limit 1")
                    ).mappings().first()
                if row:
                    license_info = {"syscode": row["syscode"], "expireat": row["expireat"]}
                    self.r_app.set("system:license", json.dumps(license_info))
            license_status = get_license_status(license_info)
            if license_status not in codes:
                raise errors.HTTPAPIError(errors.ERROR_LICENSE_NOT_ACTIVE, "系统License未授权")
            return func(self, *args, **kwargs)
        return wrapper
    return decorate
def userlog(method):
    """Decorate a handler method to log the client IP and user id before it runs.

    Anonymous requests log ``None`` as the user instead of crashing on
    ``None.id``. Formatting is deferred to the logging module.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        user = self.current_user
        logging.info("[ip]:%s [user]:%s", self.request.remote_ip, user.id if user else None)
        return method(self, *args, **kwargs)
    return wrapper