You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
559 lines
20 KiB
Python
559 lines
20 KiB
Python
# -*- coding: utf-8 -*-
|
|
import ast
|
|
import functools
|
|
import hashlib
|
|
import logging
|
|
import re
|
|
import time
|
|
import traceback
|
|
import urllib
|
|
import json
|
|
# import urlparse
|
|
from urllib.parse import parse_qs, unquote
|
|
from typing import Any
|
|
# from urllib import unquote
|
|
|
|
import tornado
|
|
from tornado import escape
|
|
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
|
|
from tornado.options import options
|
|
# from tornado.web import RequestHandler as BaseRequestHandler, HTTPError, asynchronous
|
|
from tornado.web import RequestHandler as BaseRequestHandler, HTTPError
|
|
from torndb import Row
|
|
|
|
from website import errors
|
|
from website import settings
|
|
from website.service.license import get_license_status
|
|
|
|
# Select the process-wide AsyncHTTPClient implementation once, at import time.
if not settings.enable_curl_async_http_client:
    # tornado's default implementation, capped at settings.max_clients
    # concurrent requests.
    AsyncHTTPClient.configure(None, max_clients=settings.max_clients)
else:
    # libcurl-backed implementation.
    AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")

# Matches a path with at least one character before a trailing "/".
REMOVE_SLASH_RE = re.compile(".+/$")
|
|
|
|
|
|
def _callback_wrapper(callback):
|
|
"""A wrapper to handling basic callback error"""
|
|
|
|
def _wrapper(response):
|
|
if response.error:
|
|
logging.error("call remote api err: %s" % response)
|
|
if isinstance(response.error, tornado.httpclient.HTTPError):
|
|
raise errors.HTTPAPIError(response.error.code, "网络连接失败")
|
|
else:
|
|
raise errors.HTTPAPIError(errors.ERROR_UNKNOWN_ERROR, "未知错误")
|
|
else:
|
|
callback(response)
|
|
|
|
return _wrapper
|
|
|
|
|
|
class BaseHandler(BaseRequestHandler):
    """Common base for the handlers in this module.

    Provides CORS headers, token + Redis based ``current_user`` lookup,
    helpers for reading form/query or JSON-body arguments, and a helper for
    calling remote HTTP APIs.
    """

    # Private "no default supplied" marker. When an argument getter receives
    # it, a missing argument raises MissingArgumentError (HTTP 400), which is
    # tornado's documented contract for get_argument without a default.
    _NO_DEFAULT = object()

    def set_default_headers(self):
        """Send permissive CORS headers on every response."""
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers",
                        "DNT,token,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,"
                        "Content-Type")
        self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')

    def get(self, *args, **kwargs):
        """Serve GET by delegating to ``post`` when ``settings.app_get_to_post`` is set.

        The result of ``post`` is returned, so a coroutine ``post`` is
        awaited by tornado instead of being silently dropped.

        Raises:
            HTTPError: 405 when GET-to-POST delegation is disabled.
        """
        if settings.app_get_to_post:
            return self.post(*args, **kwargs)
        raise HTTPError(405)

    def render(self, template_name, **kwargs):
        """Render a template, injecting the logged-in user's name, id and role.

        Values the caller passes explicitly are never overwritten. The
        response is marked ``Cache-control: no-cache``.
        """
        if self.current_user:
            if 'username' not in kwargs:
                kwargs["username"] = self.current_user.name
            if 'current_userid' not in kwargs:
                kwargs["current_userid"] = self.current_user.id
            if 'role' not in kwargs:
                kwargs["role"] = self.current_user.role
        self.set_header("Cache-control", "no-cache")
        return super(BaseHandler, self).render(template_name, **kwargs)

    @property
    def app_mysql(self):
        """MySQL connection shared on the application object."""
        return self.application.app_mysql

    @property
    def r_app(self):
        """Redis client shared on the application object."""
        return self.application.r_app

    @property
    def kafka_producer(self):
        """Kafka producer shared on the application object."""
        return self.application.kafka_producer

    @property
    def es(self):
        """Elasticsearch client shared on the application object."""
        return self.application.es

    def _call_api(self, url, headers, body, callback, method, callback_wrapper=_callback_wrapper):
        """Start fetching *url* and pass the response to *callback*.

        Args:
            url: remote endpoint.
            headers: request headers to send.
            body: request body.
            callback: receives the ``HTTPResponse``.
            method: HTTP method.
            callback_wrapper: optional decorator applied to *callback*. The
                default turns failed responses into ``errors.HTTPAPIError``.

        Returns:
            An already-scheduled asyncio task that completes after *callback*
            has run. Awaiting it propagates anything the callback raises.
        """
        import asyncio

        if callback_wrapper:
            callback = callback_wrapper(callback)
        request = HTTPRequest(url=url,
                              method=method,
                              body=body,
                              headers=headers,
                              allow_nonstandard_methods=True,
                              connect_timeout=settings.remote_connect_timeout,
                              request_timeout=settings.remote_request_timeout,
                              follow_redirects=False)
        # Tornado 6 removed fetch()'s callback parameter; the second positional
        # argument is now ``raise_error``. Drive the coroutine API ourselves.
        return asyncio.ensure_future(self._fetch_and_dispatch(url, request, callback))

    async def _fetch_and_dispatch(self, url, request, callback):
        """Run *request* and always hand the outcome to *callback* as a response.

        As with tornado's old callback-style fetch, failures arrive as a
        response whose ``error`` is set rather than as an exception:

        * non-2xx codes come back as responses because ``raise_error=False``;
        * transport failures (timeouts, refused connections, ...) are
          wrapped in a 599 response.
        """
        start = time.time()
        try:
            response = await AsyncHTTPClient().fetch(request, raise_error=False)
        except Exception as exc:
            elapsed = time.time() - start
            logging.error("request to %s failed after %.1f ms: %r", url, elapsed * 1000, exc)
            response = getattr(exc, "response", None)
            if response is None:
                response = tornado.httpclient.HTTPResponse(request, 599, error=exc, request_time=elapsed)
        callback(response)

    async def call_api(self, url, headers=None, body=None, callback=None, method="POST"):
        """Send a request to *url* and wait until *callback* has handled the reply.

        Args:
            url: remote endpoint.
            headers: headers to send. Defaults to the incoming request's
                headers. A copy is always sent, so neither the caller's
                mapping nor ``self.request.headers`` is modified.
            body: body to send. Defaults to the incoming request body. When an
                explicit body is given, Content-Type is forced to
                ``application/x-www-form-urlencoded``.
            callback: receives the ``HTTPResponse``. Defaults to
                ``self.call_api_callback`` (not defined in this class;
                presumably provided by subclasses).
            method: HTTP method, ``"POST"`` by default.
        """
        if callback is None:
            callback = self.call_api_callback

        # Copy before modifying: we set Content-Type below, and the HTTP client
        # may add its own headers to the mapping it is given.
        headers = (self.request.headers if headers is None else headers).copy()

        if body is None:
            body = self.request.body
        else:
            # make sure it is a post request
            headers["Content-Type"] = "application/x-www-form-urlencoded"

        await self._call_api(url, headers, body, callback, method)

    def get_current_user(self, token_body=None):
        """Resolve the logged-in user from a signed session token.

        The token comes from *token_body* when given, otherwise from the
        ``token`` request header. It is URL-unquoted and verified with
        ``decode_signed_value``. The resulting session id selects a Redis key
        (``settings.session_key_prefix % jid``) that holds the user as a
        Python-literal dict.

        Returns:
            A ``Row`` built from the cached user dict, or None when there is
            no token, the signature does not verify, or the session is gone.
        """
        token = token_body or self.request.headers.get("token")
        if not token:
            return None

        jid = tornado.web.decode_signed_value(
            settings.cookie_secret,
            settings.secure_cookie_name,
            unquote(token)
        )
        if not jid:
            # Forged/expired tokens must not fall through to a shared
            # ``session_key_prefix % ""`` lookup.
            return None

        key = settings.session_key_prefix % str(jid, encoding="utf-8")
        user = self.r_app.get(key)
        if not user:
            return None

        if "/user/info" not in self.request.path:
            # Refresh the session TTL on every request except /user/info ones.
            self.r_app.expire(key, settings.session_ttl)
        # literal_eval accepts Python literals only, never arbitrary code.
        return Row(ast.literal_eval(self.tostr(user)))

    def md5compare(self, name):
        """Validate a URL-quoted ``"a|b|SIGNATURE"`` string.

        Returns:
            True when SIGNATURE equals the upper-case hex MD5 of the UTF-8
            bytes of ``a + b``. False otherwise, including when fewer than
            three ``|``-separated parts are present.
        """
        parts = unquote(name).split("|")
        if len(parts) < 3:
            return False
        # hashlib needs bytes on Python 3; hashing the str raised TypeError.
        digest = hashlib.md5((parts[0] + parts[1]).encode("utf-8")).hexdigest().upper()
        return digest == parts[2]

    def tostr(self, src):
        """Decode *src* as UTF-8 if it is bytes; return any other value unchanged."""
        return str(src, encoding='utf8') if isinstance(src, bytes) else src

    def prepare(self):
        """Per-request setup.

        Steps: trailing-slash redirect for GET, context hook, then merge a
        JSON body into the request arguments. Rate limiting
        (``traffic_threshold``) is currently disabled.
        """
        self.remove_slash()
        self.prepare_context()
        self.set_default_jsonbody()

    def set_default_jsonbody(self):
        """Merge a JSON-object body into ``self.request.arguments``.

        Only runs when Content-Type is exactly ``application/json`` and the
        body is non-empty. ``null`` values are skipped. Lists extend the
        argument's value list, and dicts replace the entry as-is. Other
        scalars are appended as their ``str()`` encoded to UTF-8 bytes.
        """
        if self.request.headers.get('Content-Type') != 'application/json' or not self.request.body:
            return
        json_body = tornado.escape.json_decode(self.request.body)
        if not isinstance(json_body, dict):
            # A JSON array/scalar has no argument names to merge
            # (previously an AttributeError -> 500).
            return
        arguments = self.request.arguments
        for key, value in json_body.items():
            if value is None:
                continue
            if isinstance(value, list):
                arguments.setdefault(key, []).extend(value)
            elif isinstance(value, dict):
                arguments[key] = value
            else:
                arguments.setdefault(key, []).append(bytes(str(value), 'utf-8'))

    def traffic_threshold(self):
        """Rate-limit by user id (or client IP when anonymous) in 10-second windows.

        Raises:
            errors.HTTPAPIError: ERROR_METHOD_NOT_ALLOWED ("请勿频繁操作",
                i.e. "please don't operate too frequently") once the window's
                counter exceeds ``settings.api_count_in_ten_second``.
        """
        user = self.current_user
        user_id = user.id if user else self.request.remote_ip

        # Integer division: with "/" the window id was a float that changed
        # every second, so the 10-second window never accumulated.
        window = int(time.time()) // 10
        freq_key = "API:FREQ:%s:%s" % (user_id, window)
        send_count = self.r_app.incr(freq_key)
        if send_count > settings.api_count_in_ten_second:
            next_key = "API:FREQ:%s:%s" % (user_id, window + 1)
            # NOTE(review): redis-py >= 3 uses setex(name, time, value); confirm
            # this argument order matches the client in use.
            self.r_app.setex(next_key, send_count, 10)
            raise errors.HTTPAPIError(
                errors.ERROR_METHOD_NOT_ALLOWED, "请勿频繁操作")
        if send_count == 1:
            self.r_app.expire(freq_key, 10)

    def prepare_context(self):
        """Hook for per-request context setup; currently a no-op."""
        pass

    def remove_slash(self):
        """Redirect a GET with a trailing "/" to the same path without it.

        The query string is preserved. Leading slashes are collapsed as
        well, so a path such as ``//evil.com/`` cannot turn into the
        protocol-relative (off-site) target ``//evil.com``.
        """
        if self.request.method != "GET" or not REMOVE_SLASH_RE.match(self.request.path):
            return
        uri = "/" + self.request.path.strip("/")
        if self.request.query:
            uri += "?" + self.request.query
        self.redirect(uri)

    def get_json_argument(self, name, default=_NO_DEFAULT):
        """Return top-level key *name* from the JSON request body.

        String values are returned UTF-8 encoded (bytes); other values are
        returned unchanged.

        Raises:
            tornado.web.MissingArgumentError: *name* is absent and no
                *default* was given.
        """
        json_body = tornado.escape.json_decode(self.request.body)
        value = json_body.get(name, default)
        if value is BaseHandler._NO_DEFAULT:
            raise tornado.web.MissingArgumentError(name)
        return escape.utf8(value) if isinstance(value, str) else value

    def get_int_json_argument(self, name, default=0):
        """JSON body argument as int; *default* if missing, null or not numeric."""
        try:
            return int(self.get_json_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_escaped_json_argument(self, name, default=None):
        """JSON body argument, XHTML-escaped when a non-None *default* is given."""
        if default is not None:
            return self.escape_string(self.get_json_argument(name, default))
        else:
            return self.get_json_argument(name, default)

    def get_argument(self, name, default=_NO_DEFAULT, strip=True):
        """tornado ``get_argument`` that returns str values UTF-8 encoded (bytes).

        Raises:
            tornado.web.MissingArgumentError: *name* is absent and no
                *default* was given (previously a bare ``object()`` was
                returned instead).
        """
        if default is BaseHandler._NO_DEFAULT:
            value = super(BaseHandler, self).get_argument(name, strip=strip)
        else:
            value = super(BaseHandler, self).get_argument(name, default, strip)
        return escape.utf8(value) if isinstance(value, str) else value

    def get_int_argument(self, name: Any, default: int = 0) -> int:
        """Argument as int; *default* if missing or not an integer."""
        try:
            return int(self.get_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_float_argument(self, name, default=0.0):
        """Argument as float; *default* if missing or not a number."""
        try:
            return float(self.get_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_uint_arg(self, name, default=0):
        """Argument as a non-negative int (absolute value); *default* on failure."""
        try:
            return abs(int(self.get_argument(name, default)))
        except (TypeError, ValueError):
            return default

    def unescape_string(self, s):
        """Undo XHTML escaping."""
        return escape.xhtml_unescape(s)

    def escape_string(self, s):
        """XHTML-escape *s*."""
        return escape.xhtml_escape(s)

    def get_escaped_argument(self, name, default=None):
        """Argument, XHTML-escaped when a non-None *default* is given."""
        if default is not None:
            return self.escape_string(self.get_argument(name, default))
        else:
            return self.get_argument(name, default)

    def get_page_url(self, page, form_id=None, tab=None):
        """Build the URL for pagination link *page*.

        With *form_id*, returns a ``javascript:goto_page(...)`` URL.
        Otherwise returns the current path with the current query string.
        Only the first value of each parameter is kept, and ``page`` (and
        ``tab`` when given) are overridden.
        """
        from urllib.parse import urlencode

        if form_id:
            return "javascript:goto_page('%s',%s);" % (form_id.strip(), page)
        qdict = {k: (v[0] if v else '') for k, v in parse_qs(self.request.query).items()}
        qdict['page'] = page
        if tab:
            qdict['tab'] = tab
        # urllib.urlencode does not exist on Python 3; use urllib.parse.urlencode.
        return self.request.path + '?' + urlencode(qdict)

    def find_all(self, target, substring):
        """Yield the start index of each non-overlapping *substring* in *target*.

        An empty *substring* yields every index 0..len(target) (matching
        ``str.count``) instead of looping forever at index 0.
        """
        step = len(substring) or 1
        current_pos = target.find(substring)
        while current_pos != -1:
            yield current_pos
            current_pos = target.find(substring, current_pos + step)
|
|
|
|
class WebHandler(BaseHandler):
    """Browser-facing handler: JSONP-aware finish, errors rendered as JSON with HTTP 200."""

    # Only plain (optionally dotted) JavaScript identifiers are accepted as
    # JSONP callback names. Anything else is attacker-controlled text that
    # would be reflected verbatim into a script response.
    _JSONP_CALLBACK_RE = re.compile(rb"[A-Za-z_$][\w$.]{0,127}")

    def _get_jsonp_callback(self):
        """Return the ``callback`` argument as bytes if it is a safe JSONP name, else None."""
        callback = escape.utf8(self.get_argument("callback", None))
        if callback and self._JSONP_CALLBACK_RE.fullmatch(callback):
            return callback
        return None

    def finish(self, chunk=None, message=None):
        """Finish the request, wrapping the body as ``callback(...)`` for JSONP.

        Args:
            chunk: response body. A dict is JSON-encoded in the JSONP
                branch; in the plain branch it is handed to tornado as-is.
            message: accepted for signature compatibility with
                ``APIHandler.finish``; unused here.
        """
        callback = self._get_jsonp_callback()
        if callback:
            self.set_header("Content-Type", "application/x-javascript")

            if isinstance(chunk, dict):
                chunk = escape.json_encode(chunk)

            # Replace any buffered output. tornado joins the buffer with
            # b"".join(), so every piece must be bytes (mixing in str raised
            # TypeError at flush time).
            self._write_buffer = [callback, b"(", escape.utf8(chunk), b")"] if chunk else []
            return super(WebHandler, self).finish()
        self.set_header("Cache-control", "no-cache")
        return super(WebHandler, self).finish(chunk)

    def write_error(self, status_code, **kwargs):
        """Render an error as a JSON ``errors.HTTPAPIError`` body with HTTP 200.

        Error mapping:

        * ``errors.HTTPAPIError`` is used as-is.
        * A tornado ``HTTPError`` (or ``send_error`` without an exception)
          maps to its status code. 401 instead redirects to "/" using a
          temporary redirect.
        * Anything else becomes ``ERROR_INTERNAL_SERVER_ERROR``.

        A 500 triggers ``send_error_mail`` outside debug mode. In debug mode
        the traceback is added to the error data. If rendering fails, the
        error is logged and tornado's default ``write_error`` is used.
        """
        try:
            exc_info = kwargs.get('exc_info')
            e = exc_info[1] if exc_info else None

            if isinstance(e, errors.HTTPAPIError):
                api_error = e
            elif e is None or isinstance(e, HTTPError):
                code = status_code if e is None else e.status_code
                if code == 401:
                    # Temporary redirect: a permanent (301) one is cached by
                    # browsers and keeps bouncing this URL to "/" after login.
                    self.redirect("/")
                    return
                api_error = errors.HTTPAPIError(code)
            else:
                api_error = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)

            exception = "".join(traceback.format_exception(*exc_info)) if exc_info else ""

            if status_code == errors.ERROR_INTERNAL_SERVER_ERROR and not options.debug:
                self.send_error_mail(exception)

            if options.debug:
                api_error.data["exception"] = exception

            self.clear()
            # always return 200 OK for Web errors
            self.set_status(errors.HTTP_OK)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.finish(str(api_error))
        except Exception:
            logging.error(traceback.format_exc())
            return super(WebHandler, self).write_error(status_code, **kwargs)

    def send_error_mail(self, exception):
        """Hook called with the formatted traceback on 500s outside debug mode; no-op here."""
        pass
|
|
|
|
|
|
class APIHandler(BaseHandler):
    """JSON API handler.

    Dict bodies are wrapped in a ``{"meta": ..., "data": ...}`` envelope,
    JSONP is supported, and errors are returned as JSON with HTTP 200.
    """

    # Only plain (optionally dotted) JavaScript identifiers are accepted as
    # JSONP callback names. Anything else is attacker-controlled text that
    # would be reflected verbatim into a script response.
    _JSONP_CALLBACK_RE = re.compile(rb"[A-Za-z_$][\w$.]{0,127}")

    def _get_jsonp_callback(self):
        """Return the ``callback`` argument as bytes if it is a safe JSONP name, else None."""
        callback = escape.utf8(self.get_argument("callback", None))
        if callback and self._JSONP_CALLBACK_RE.fullmatch(callback):
            return callback
        return None

    def finish(self, chunk=None, message=None):
        """Finish the request with the API envelope, or a JSONP wrapper.

        Args:
            chunk: response data. None becomes ``{}``. A dict is wrapped as
                ``{"meta": {"code": HTTP_OK}, "data": chunk}``. Any other
                value (e.g. a pre-encoded JSON string from ``write_error``)
                is sent unchanged.
            message: optional top-level ``"message"`` for the envelope. It
                only applies when the body is a dict.
        """
        if chunk is None:
            chunk = {}

        if isinstance(chunk, dict):
            chunk = {"meta": {"code": errors.HTTP_OK}, "data": chunk}
            # Only a dict envelope can carry a message (on a string chunk
            # this assignment used to raise TypeError).
            if message:
                chunk["message"] = message

        callback = self._get_jsonp_callback()
        if callback:
            self.set_header("Content-Type", "application/x-javascript")

            if isinstance(chunk, dict):
                chunk = escape.json_encode(chunk)

            # Replace any buffered output. tornado joins the buffer with
            # b"".join(), so every piece must be bytes (mixing in str raised
            # TypeError at flush time).
            self._write_buffer = [callback, b"(", escape.utf8(chunk), b")"] if chunk else []
            return super(APIHandler, self).finish()
        self.set_header("Content-Type", "application/json; charset=UTF-8")
        return super(APIHandler, self).finish(chunk)

    def write_error(self, status_code, **kwargs):
        """Render an error as a JSON ``errors.HTTPAPIError`` body with HTTP 200.

        Error mapping:

        * ``errors.HTTPAPIError`` is used as-is.
        * A tornado ``HTTPError`` (or ``send_error`` without an exception)
          maps to its status code.
        * Anything else becomes ``ERROR_INTERNAL_SERVER_ERROR``.

        A 500 triggers ``send_error_mail`` outside debug mode. In debug mode
        the traceback is added to the error data. If rendering fails, the
        error is logged and tornado's default ``write_error`` is used.
        """
        try:
            exc_info = kwargs.get('exc_info')
            e = exc_info[1] if exc_info else None

            if isinstance(e, errors.HTTPAPIError):
                api_error = e
            elif e is None:
                # send_error() without an exception: describe it by status code.
                api_error = errors.HTTPAPIError(status_code)
            elif isinstance(e, HTTPError):
                api_error = errors.HTTPAPIError(e.status_code)
            else:
                api_error = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)

            exception = "".join(traceback.format_exception(*exc_info)) if exc_info else ""

            if status_code == errors.ERROR_INTERNAL_SERVER_ERROR and not options.debug:
                self.send_error_mail(exception)

            if options.debug:
                api_error.data["exception"] = exception

            self.clear()
            # always return 200 OK for API errors
            self.set_status(errors.HTTP_OK)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.finish(str(api_error))
        except Exception:
            logging.error(traceback.format_exc())
            return super(APIHandler, self).write_error(status_code, **kwargs)

    def send_error_mail(self, exception):
        """Override to implement custom error mail"""
        pass
|
|
|
|
|
|
class ErrorHandler(BaseHandler):
    """Catch-all handler that answers every request with 404 Not Found."""

    def prepare(self):
        # Run the shared per-request setup, then reject the request.
        super().prepare()
        raise HTTPError(errors.ERROR_NOT_FOUND)
|
|
|
|
|
|
class APIErrorHandler(APIHandler):
    """Catch-all API handler: every request fails with an API-style 404."""

    def prepare(self):
        # Shared setup first; the raised HTTPAPIError is rendered by
        # APIHandler.write_error.
        super().prepare()
        raise errors.HTTPAPIError(errors.ERROR_NOT_FOUND)
|
|
|
|
|
|
def authenticated(method):
    """Require a logged-in user before running the decorated handler method.

    When ``self.current_user`` is falsy, raises ``errors.HTTPAPIError`` with
    ``errors.ERROR_UNAUTHORIZED`` and the message "登录失效" ("login
    expired/invalid"). Otherwise the method runs with its original
    arguments.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.current_user:
            return method(self, *args, **kwargs)
        raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")

    return wrapper
|
|
|
|
|
|
def authenticated_admin(method):
    """Run the decorated handler method only for users whose role is 1001.

    Raises:
        errors.HTTPAPIError: ERROR_UNAUTHORIZED ("登录失效", "login
            expired") when nobody is logged in. Previously this was an
            AttributeError on ``None.role`` that surfaced as a 500.
            ERROR_FORBIDDEN ("permission denied") when
            ``int(current_user.role) != 1001``.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        user = self.current_user
        if not user:
            raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
        # 1001 is the role code this decorator treats as administrator.
        if int(user.role) != 1001:
            raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")
        return method(self, *args, **kwargs)

    return wrapper
|
|
|
|
|
|
def operation_log(primary_module, secondary_module, operation_type, content, desc):
    """Record an audit row in ``system_log`` before running the decorated method.

    Args:
        primary_module: value stored in ``first_module``.
        secondary_module: value stored in ``second_module``.
        operation_type: value stored in ``op_type``.
        content: value stored in ``op_content``.
        desc: value stored in ``description``.

    The row also stores ``self.current_user.name`` and the client IP, taken
    from ``X-Forwarded-For`` with a fallback to ``request.remote_ip``. It is
    inserted *before* the wrapped method runs, so it is written even if the
    method then fails.

    NOTE(review): this uses ``self.db_app``, but BaseHandler's ``db_app``
    property is commented out in this file. Confirm the handlers using this
    decorator provide it.
    """
    def decorate(func):

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # X-Forwarded-For only exists behind a proxy; indexing it directly
            # raised KeyError (HTTP 500) for direct connections.
            client_ip = self.request.headers.get("X-Forwarded-For", self.request.remote_ip)
            self.db_app.insert(
                "insert into system_log(user, ip, first_module, second_module, op_type, op_content, description) "
                "values(%s, %s, %s, %s, %s, %s, %s)",
                self.current_user.name, client_ip, primary_module, secondary_module, operation_type,
                content, desc
            )

            return func(self, *args, **kwargs)
        return wrapper
    return decorate
|
|
|
|
def permission(codes):
    """Require the current user to hold every permission code in *codes*.

    Args:
        codes: an iterable of permission codes, or a single string code.
            It is materialised once, so a generator is not exhausted after
            the first request. A string is treated as one code rather than
            being iterated character by character.

    Granted permissions are queried on every call from ``role_permission``
    joined with ``user_role`` for ``self.current_user.id``.

    Raises:
        errors.HTTPAPIError: ERROR_FORBIDDEN ("permission denied") if any
            code is missing.
    """
    codes = (codes,) if isinstance(codes, str) else tuple(codes)

    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            rows = self.db_app.query(
                "select rp.permission from role_permission rp, user_role ur where rp.role=ur.role and ur.userid=%s",
                self.current_user.id
            )
            granted = {item["permission"] for item in rows}
            if not all(code in granted for code in codes):
                raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")

            return func(self, *args, **kwargs)
        return wrapper
    return decorate
|
|
|
|
def license_validate(codes):
    """Allow the decorated handler method only for accepted license statuses.

    License data (``syscode``, ``expireat``) is read from the Redis key
    ``system:license`` when it is present. Otherwise it is loaded from the
    ``license`` table and cached under that key, with no TTL set. That data
    (or ``{}`` when no row exists) is passed to ``get_license_status``.

    Raises:
        errors.HTTPAPIError: ERROR_LICENSE_NOT_ACTIVE ("系统License未授权",
            i.e. "system license not authorized") when the resulting status
            is not in *codes*.
    """
    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            cached = self.r_app.get("system:license")
            if cached:
                info = json.loads(self.tostr(cached))
            else:
                info = {}
                row = self.db_app.get("select syscode, expireat from license limit 1")
                if row:
                    snapshot = {"syscode": row["syscode"], "expireat": row["expireat"]}
                    self.r_app.set("system:license", json.dumps(snapshot))
                    info = row

            status = get_license_status(info)
            if status not in codes:
                raise errors.HTTPAPIError(errors.ERROR_LICENSE_NOT_ACTIVE, "系统License未授权")

            return func(self, *args, **kwargs)
        return wrapper
    return decorate
|
|
|
|
def userlog(method):
    """Log the client IP and user id, then run the decorated handler method.

    Anonymous requests are logged with user ``None``. Previously
    ``self.current_user.id`` raised AttributeError and failed the request.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        user = self.current_user
        user_id = user.id if user else None
        # Lazy %-args: formatting only happens if INFO is enabled.
        logging.info("[ip]:%s [user]:%s", self.request.remote_ip, user_id)
        return method(self, *args, **kwargs)

    return wrapper
|