# -*- coding: utf-8 -*-
"""Base request handlers, decorators and helpers shared by the web/API layer."""
import ast
import functools
import hashlib
import hmac
import json
import logging
import re
import time
import traceback
import urllib
from typing import Any
from urllib.parse import parse_qs, unquote, urlencode

import tornado
from sqlalchemy import text
from tornado import escape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPResponse
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.options import options
from tornado.web import RequestHandler as BaseRequestHandler, HTTPError

from website import errors
from website import settings
from website.service.license import get_license_status

# Select the implementation used by every AsyncHTTPClient() created below.
if settings.enable_curl_async_http_client:
    AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
else:
    AsyncHTTPClient.configure(None, max_clients=settings.max_clients)

# Matches any path of at least two characters that ends with "/".
REMOVE_SLASH_RE = re.compile(".+/$")

# Accepted JSONP callback names: a JavaScript identifier, optionally dotted
# (e.g. ``cb`` or ``jQuery123.handler``). Any other ``callback`` value is
# ignored so it cannot inject script into the response.
JSONP_CALLBACK_RE = re.compile(rb"[A-Za-z_$][\w$.]*")


class Row(dict):
    """A dict that allows for object-like property access syntax."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


def _callback_wrapper(callback):
    """Wrap ``callback`` so failed remote responses become API errors.

    The returned function logs a failed response and raises
    ``errors.HTTPAPIError`` instead of calling ``callback``. The error carries
    the HTTP code for Tornado HTTP errors and ``errors.ERROR_UNKNOWN_ERROR``
    otherwise. Successful responses are passed to ``callback`` unchanged.
    """

    def _wrapper(response):
        if response.error:
            logging.error("call remote api err: %s", response)
            if isinstance(response.error, tornado.httpclient.HTTPError):
                raise errors.HTTPAPIError(response.error.code, "网络连接失败")
            raise errors.HTTPAPIError(errors.ERROR_UNKNOWN_ERROR, "未知错误")
        callback(response)

    return _wrapper


class BaseHandler(BaseRequestHandler):
    """Common base for all handlers.

    Provides CORS headers, session lookup, argument helpers, JSONP helpers
    and access to shared application resources.
    """

    def set_default_headers(self):
        # Permissive CORS: any origin may call these endpoints.
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers",
                        "DNT,token,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,"
                        "Content-Type")
        self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')

    def get(self, *args, **kwargs):
        """Delegate GET to ``post`` when ``settings.app_get_to_post`` is enabled.

        The result of ``post`` is returned so that Tornado awaits it when
        ``post`` is a coroutine. Previously the coroutine was dropped and
        never ran.

        Raises:
            HTTPError: 405 when GET delegation is disabled.
        """
        if not settings.app_get_to_post:
            raise HTTPError(405)
        return self.post(*args, **kwargs)

    def render(self, template_name, **kwargs):
        """Render a template, injecting the current user's name/id/role.

        Values passed explicitly by the caller are never overridden. The
        response is marked ``no-cache``.
        """
        if self.current_user:
            if 'username' not in kwargs:
                kwargs["username"] = self.current_user.name
            if 'current_userid' not in kwargs:
                kwargs["current_userid"] = self.current_user.id
            if 'role' not in kwargs:
                kwargs["role"] = self.current_user.role
        self.set_header("Cache-control", "no-cache")
        return super(BaseHandler, self).render(template_name, **kwargs)

    @property
    def app_mysql(self):
        return self.application.app_mysql

    @property
    def r_app(self):
        return self.application.r_app

    @property
    def kafka_producer(self):
        return self.application.kafka_producer

    # NOTE: a plain method, not a property; callers use ``self.es()``.
    def es(self):
        return self.application.es

    def _call_api(self, url, headers, body, callback, method, callback_wrapper=_callback_wrapper):
        """Send an HTTP request asynchronously and pass the response to ``callback``.

        Tornado 6 removed the ``callback`` argument of ``AsyncHTTPClient.fetch``.
        The old positional call bound the callback to ``raise_error``, so it
        never ran. The returned future is now bridged to ``callback`` through
        the IOLoop.

        Failures that Tornado raises instead of returning (timeouts, connection
        errors) are wrapped in a 599 ``HTTPResponse``. As a result, ``callback``
        always receives a response whose ``.error`` describes the failure.

        Args:
            url: request URL.
            headers: request headers.
            body: request body.
            method: HTTP method.
            callback: called with the ``HTTPResponse`` on the IOLoop.
            callback_wrapper: optional decorator applied to ``callback``.
                Defaults to ``_callback_wrapper``, which raises API errors for
                failed responses.

        Returns:
            The ``Future`` returned by ``fetch``. Callers may ignore it.
        """
        if callback_wrapper:
            callback = callback_wrapper(callback)
        request = HTTPRequest(url=url, method=method, body=body, headers=headers,
                              allow_nonstandard_methods=True,
                              connect_timeout=settings.remote_connect_timeout,
                              request_timeout=settings.remote_request_timeout,
                              follow_redirects=False)
        start = time.time()
        # raise_error=False: non-2xx answers are delivered as responses.
        future = AsyncHTTPClient().fetch(request, raise_error=False)

        def _on_done(fut):
            try:
                response = fut.result()
            except Exception as exc:
                elapsed = time.time() - start
                logging.error("request to %s failed after %.1f ms: %s", url, elapsed * 1000, exc)
                response = HTTPResponse(request, 599, error=exc, request_time=elapsed)
            callback(response)

        IOLoop.current().add_future(future, _on_done)
        return future

    async def call_api(self, url, headers=None, body=None, callback=None, method="POST"):
        """Proxy a request to ``url`` and deliver the response to ``callback``.

        Defaults:
            ``callback`` is ``self.call_api_callback``; ``headers`` and
            ``body`` are those of the incoming request.

        When an explicit ``body`` is given, it is sent as a form post. Headers
        are copied first, so neither the incoming request's headers nor a
        caller-supplied mapping is mutated (previously both were).

        This coroutine returns once the request is scheduled. The callback
        runs later on the IOLoop.
        """
        if callback is None:
            callback = self.call_api_callback
        headers = HTTPHeaders(self.request.headers if headers is None else headers)
        if body is None:
            body = self.request.body
        else:
            # make sure it is a post request
            headers["Content-Type"] = "application/x-www-form-urlencoded"
        self._call_api(url, headers, body, callback, method)

    def get_current_user(self, token_body=None):
        """Resolve the logged-in user from the signed session token.

        The token is taken from ``token_body`` if given, otherwise from the
        ``token`` request header. It is URL-unquoted and verified with
        ``settings.cookie_secret``. The decoded id then selects the redis
        session key. The session TTL is refreshed on every request except
        paths containing ``/user/info``.

        Returns:
            Row | None: the session record, or None when the token is
            missing or invalid, or no session exists.
        """
        token = token_body or self.request.headers.get("token")
        if not token:
            return None
        jid = tornado.web.decode_signed_value(
            settings.cookie_secret, settings.secure_cookie_name, unquote(self.tostr(token))
        )
        if not jid:
            # Bad or expired signature: no session can match.
            return None
        key = settings.session_key_prefix % str(jid, encoding="utf-8")
        user = self.r_app.get(key)
        if not user:
            return None
        if "/user/info" not in self.request.path:
            self.r_app.expire(key, settings.session_ttl)
        # The session is stored as a Python literal. literal_eval only parses
        # literals and never executes code.
        return Row(ast.literal_eval(self.tostr(user)))

    def md5compare(self, name):
        """Validate a ``"a|b|digest"`` string.

        Returns True when ``digest`` equals the upper-case hex MD5 of the
        UTF-8 bytes of ``a + b``. ``name`` may be URL-encoded, and may be str
        or bytes. Returns False, instead of raising, when fewer than three
        ``|``-separated parts are present.
        """
        parts = unquote(self.tostr(name)).split("|")
        if len(parts) < 3:
            return False
        num1, num2, num3 = parts[:3]
        # hashlib needs bytes on Python 3; passing a str raised TypeError.
        md5string = hashlib.md5((num1 + num2).encode("utf-8")).hexdigest().upper()
        # Constant-time comparison avoids leaking the digest through timing.
        return hmac.compare_digest(md5string.encode("utf-8"), num3.encode("utf-8"))

    def tostr(self, src):
        """Decode ``bytes`` as UTF-8; return any other value unchanged."""
        return str(src, encoding='utf8') if isinstance(src, bytes) else src

    def prepare(self):
        self.remove_slash()
        if self._finished:
            # remove_slash() already answered with a redirect.
            return
        self.prepare_context()
        self.set_default_jsonbody()

    def set_default_jsonbody(self):
        """Merge a JSON-object request body into ``self.request.arguments``.

        Applies only when Content-Type is exactly ``application/json`` and the
        body is non-empty. A body that is not a JSON object is ignored.

        Per value:
            - ``None`` is skipped.
            - A list extends the argument list as-is.
            - A dict replaces the entry.
            - A scalar is appended as the UTF-8 bytes of ``str(value)``.
        """
        if self.request.headers.get('Content-Type') != 'application/json' or not self.request.body:
            return
        json_body = tornado.escape.json_decode(self.request.body)
        if not isinstance(json_body, dict):
            return
        for key, value in json_body.items():
            if value is None:
                continue
            if isinstance(value, list):
                self.request.arguments.setdefault(key, []).extend(value)
            elif isinstance(value, dict):
                self.request.arguments[key] = value
            else:
                self.request.arguments.setdefault(key, []).extend([bytes(str(value), 'utf-8')])

    def traffic_threshold(self):
        """Rate-limit requests per user id, or per remote IP when anonymous.

        Requests are counted in 10-second redis buckets. Once a bucket exceeds
        ``settings.api_count_in_ten_second``, the next bucket is pre-filled
        with the current count and the request is rejected.

        Raises:
            errors.HTTPAPIError: ERROR_METHOD_NOT_ALLOWED when over the limit.
        """
        user_id = self.current_user.id if self.current_user else self.request.remote_ip
        # Floor division; true division produced a new bucket every second.
        bucket = int(time.time()) // 10
        freq_key = "API:FREQ:%s:%s" % (user_id, bucket)
        send_count = self.r_app.incr(freq_key)
        if send_count > settings.api_count_in_ten_second:
            # Keyword arguments: with redis-py >= 3 the positional order is
            # (name, time, value), so the old call set TTL=send_count, value=10.
            self.r_app.setex(name="API:FREQ:%s:%s" % (user_id, bucket + 1), time=10, value=send_count)
            raise errors.HTTPAPIError(
                errors.ERROR_METHOD_NOT_ALLOWED, "请勿频繁操作")
        if send_count == 1:
            self.r_app.expire(freq_key, 10)

    def prepare_context(self):
        """Hook for per-request setup; currently a no-op."""
        pass

    def remove_slash(self):
        """Redirect a GET whose path ends in "/" to the same path without it.

        The query string is preserved. A result that would start with "//"
        or "/\\" (or be empty) is collapsed to a single leading "/". This
        stops a path such as ``//evil.example/`` from becoming a
        protocol-relative, off-site redirect.
        """
        if self.request.method != "GET" or not REMOVE_SLASH_RE.match(self.request.path):
            return
        uri = self.request.path.rstrip("/")
        if not uri or uri.startswith(("//", "/\\")):
            uri = "/" + uri.lstrip("/\\")
        if self.request.query:
            uri += "?" + self.request.query
        self.redirect(uri)

    # NOTE: the ``object()`` defaults below are sentinels. When the argument
    # is missing and no default is supplied, that sentinel object is returned
    # rather than raising MissingArgumentError.
    def get_json_argument(self, name, default=object()):
        """Return ``name`` from the JSON body; str values are UTF-8 encoded."""
        json_body = tornado.escape.json_decode(self.request.body)
        value = json_body.get(name, default)
        return escape.utf8(value) if isinstance(value, str) else value

    def get_int_json_argument(self, name, default=0):
        """``get_json_argument`` as int, or ``default`` if not convertible."""
        try:
            return int(self.get_json_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_escaped_json_argument(self, name, default=None):
        if default is not None:
            return self.escape_string(self.get_json_argument(name, default))
        else:
            return self.get_json_argument(name, default)

    def get_argument(self, name, default=object(), strip=True):
        """Tornado's ``get_argument``, returning str values as UTF-8 bytes."""
        value = super(BaseHandler, self).get_argument(name, default, strip)
        return escape.utf8(value) if isinstance(value, str) else value

    def get_int_argument(self, name: Any, default: int = 0) -> int:
        """Argument as int, or ``default`` if missing or not convertible."""
        try:
            return int(self.get_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_float_argument(self, name, default=0.0):
        """Argument as float, or ``default`` if missing or not convertible."""
        try:
            return float(self.get_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_uint_arg(self, name, default=0):
        """Absolute int value of the argument, or ``default`` if not convertible."""
        try:
            return abs(int(self.get_argument(name, default)))
        except (TypeError, ValueError):
            return default

    def unescape_string(self, s):
        return escape.xhtml_unescape(s)

    def escape_string(self, s):
        return escape.xhtml_escape(s)

    def get_escaped_argument(self, name, default=None):
        if default is not None:
            return self.escape_string(self.get_argument(name, default))
        else:
            return self.get_argument(name, default)

    def get_page_url(self, page, form_id=None, tab=None):
        """Build a pagination link.

        With ``form_id``, returns a ``javascript:goto_page(...)`` URL.
        Otherwise returns the current path with the first value of each query
        parameter, ``page`` set and, when given, ``tab`` set.
        """
        if form_id:
            return "javascript:goto_page('%s',%s);" % (form_id.strip(), page)
        qdict = {k: (v[0] if v else '') for k, v in parse_qs(self.request.query).items()}
        qdict['page'] = page
        if tab:
            qdict['tab'] = tab
        # urllib.urlencode does not exist on Python 3; use urllib.parse's.
        return self.request.path + '?' + urlencode(qdict)

    def find_all(self, target, substring):
        """Yield the start index of each non-overlapping ``substring`` in ``target``."""
        current_pos = target.find(substring)
        while current_pos != -1:
            yield current_pos
            current_pos += len(substring)
            current_pos = target.find(substring, current_pos)

    def _get_jsonp_callback(self):
        """Return the ``callback`` argument as bytes if it is a safe JSONP name.

        Returns None otherwise, in which case the response is sent without
        JSONP wrapping.
        """
        callback = escape.utf8(self.get_argument("callback", None))
        if callback and JSONP_CALLBACK_RE.fullmatch(callback):
            return callback
        return None

    def _set_jsonp_body(self, callback, chunk):
        """Replace the output buffer with ``callback(chunk)`` as JavaScript.

        Dict chunks are JSON-encoded, and an empty chunk clears the buffer.
        Tornado joins ``_write_buffer`` as bytes, so every piece must be
        bytes; the old code mixed in str and raised TypeError.
        """
        self.set_header("Content-Type", "application/x-javascript")
        if isinstance(chunk, dict):
            chunk = escape.json_encode(chunk)
        self._write_buffer = [callback, b"(", escape.utf8(chunk), b")"] if chunk else []


class WebHandler(BaseHandler):
    """Handler for web endpoints: JSONP-aware ``finish`` and JSON error bodies."""

    def finish(self, chunk=None, message=None):
        callback = self._get_jsonp_callback()
        if callback:
            self._set_jsonp_body(callback, chunk)
            return super(WebHandler, self).finish()
        self.set_header("Cache-control", "no-cache")
        return super(WebHandler, self).finish(chunk)

    def write_error(self, status_code, **kwargs):
        """Send errors as a 200 response whose body is ``str`` of an HTTPAPIError.

        Behaviour by exception type:
            - A plain 401 ``HTTPError`` redirects to "/" instead.
            - Any other ``HTTPError`` becomes an HTTPAPIError with its status.
            - Non-HTTP exceptions become ERROR_INTERNAL_SERVER_ERROR.

        In debug mode the traceback is attached to ``e.data``. Otherwise,
        ERROR_INTERNAL_SERVER_ERROR errors are passed to ``send_error_mail``.
        If no exception info is available, or this method itself fails,
        Tornado's default error page is used.
        """
        exc_info = kwargs.pop("exc_info", None)
        if exc_info is None:
            return super(WebHandler, self).write_error(status_code, **kwargs)
        try:
            e = exc_info[1]
            if isinstance(e, errors.HTTPAPIError):
                pass
            elif isinstance(e, HTTPError):
                if e.status_code == 401:
                    self.redirect("/", permanent=True)
                    return
                e = errors.HTTPAPIError(e.status_code)
            else:
                e = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)

            exception = "".join(traceback.format_exception(*exc_info))

            if status_code == errors.ERROR_INTERNAL_SERVER_ERROR and not options.debug:
                self.send_error_mail(exception)

            if options.debug:
                e.data["exception"] = exception

            self.clear()
            # always return 200 OK for Web errors
            self.set_status(errors.HTTP_OK)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.finish(str(e))
        except Exception:
            logging.error(traceback.format_exc())
            return super(WebHandler, self).write_error(status_code, **kwargs)

    def send_error_mail(self, exception):
        """Override to implement custom error mail."""
        pass


class APIHandler(BaseHandler):
    """Handler for JSON API endpoints.

    Dict responses are wrapped as ``{"meta": {"code": HTTP_OK}, "data": ...}``.
    """

    def finish(self, chunk=None, message=None):
        if chunk is None:
            chunk = {}
        if isinstance(chunk, dict):
            chunk = {"meta": {"code": errors.HTTP_OK}, "data": chunk}
            # Only a dict envelope can carry a message.
            if message:
                chunk["message"] = message
        callback = self._get_jsonp_callback()
        if callback:
            self._set_jsonp_body(callback, chunk)
            return super(APIHandler, self).finish()
        self.set_header("Content-Type", "application/json; charset=UTF-8")
        return super(APIHandler, self).finish(chunk)

    def write_error(self, status_code, **kwargs):
        """Send errors as a 200 response whose body is ``str`` of an HTTPAPIError.

        ``HTTPError`` keeps its status code. Any other non-API exception
        becomes ERROR_INTERNAL_SERVER_ERROR. In debug mode the traceback is
        attached to ``e.data``. Otherwise, ERROR_INTERNAL_SERVER_ERROR errors
        are passed to ``send_error_mail``. If no exception info is available,
        or this method itself fails, Tornado's default error page is used.
        """
        exc_info = kwargs.pop("exc_info", None)
        if exc_info is None:
            return super(APIHandler, self).write_error(status_code, **kwargs)
        try:
            e = exc_info[1]
            if isinstance(e, errors.HTTPAPIError):
                pass
            elif isinstance(e, HTTPError):
                e = errors.HTTPAPIError(e.status_code)
            else:
                e = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)

            exception = "".join(traceback.format_exception(*exc_info))

            if status_code == errors.ERROR_INTERNAL_SERVER_ERROR and not options.debug:
                self.send_error_mail(exception)

            if options.debug:
                e.data["exception"] = exception

            self.clear()
            # always return 200 OK for API errors
            self.set_status(errors.HTTP_OK)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.finish(str(e))
        except Exception:
            logging.error(traceback.format_exc())
            return super(APIHandler, self).write_error(status_code, **kwargs)

    def send_error_mail(self, exception):
        """Override to implement custom error mail."""
        pass


class ErrorHandler(BaseHandler):
    """Default 404: Not Found handler."""

    def prepare(self):
        super(ErrorHandler, self).prepare()
        raise HTTPError(errors.ERROR_NOT_FOUND)


class APIErrorHandler(APIHandler):
    """Default API 404: Not Found handler."""

    def prepare(self):
        super(APIErrorHandler, self).prepare()
        raise errors.HTTPAPIError(errors.ERROR_NOT_FOUND)


def authenticated(method):
    """Require a logged-in user.

    Raises:
        errors.HTTPAPIError: ERROR_UNAUTHORIZED when there is no current user.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.current_user:
            raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
        return method(self, *args, **kwargs)

    return wrapper


def authenticated_admin(method):
    """Require the current user's role to be 1001 (the admin role, per this name).

    Raises:
        errors.HTTPAPIError: ERROR_UNAUTHORIZED when nobody is logged in
            (previously an AttributeError, i.e. a 500). ERROR_FORBIDDEN for
            any other role.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.current_user:
            raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
        if int(self.current_user.role) != 1001:
            raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")
        return method(self, *args, **kwargs)

    return wrapper


def operation_log(primary_menu, sub_menu, ope_type, content, comment=""):
    """Decorator factory that records the operation in ``sys_log``.

    Before the handler runs, a row is inserted containing:
        - the current user's name;
        - the client IP, taken from ``X-Forwarded-For`` when present,
          otherwise ``remote_ip``;
        - the given menu, operation type, content and comment.

    Logging is best-effort: a failure is logged and the handler still runs.
    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            try:
                with self.app_mysql.connect() as conn:
                    conn.execute(text(
                        "insert into sys_log(user, ip, primary_menu, sub_menu, op_type, content, comment) "
                        "values(:user, :ip, :primary_menu, :sub_menu, :op_type, :content, :comment)"
                    ), {"user": self.current_user.name,
                        "ip": self.request.headers.get("X-Forwarded-For", self.request.remote_ip),
                        "primary_menu": primary_menu,
                        "sub_menu": sub_menu,
                        "op_type": ope_type,
                        "content": content,
                        "comment": comment
                        }
                    )
                    conn.commit()
            except Exception as e:
                logging.warning("operation log error: %s", e)
            return func(self, *args, **kwargs)

        return wrapper

    return decorate


def permission(codes):
    """Require the current user to hold every permission code in ``codes``.

    A user's permissions are the ``role_permission`` entries of all roles
    linked to them in ``user_role``.

    NOTE(review): this uses ``self.db_app``, but the ``db_app`` property is
    commented out of ``BaseHandler``. It only works if a subclass provides
    it — confirm.

    Raises:
        errors.HTTPAPIError: ERROR_FORBIDDEN when any code is missing.
    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            rows = self.db_app.query(
                "select rp.permission from role_permission rp, user_role ur where rp.role=ur.role and ur.userid=%s",
                self.current_user.id
            )
            granted = {item["permission"] for item in rows}
            if not all(code in granted for code in codes):
                raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")
            return func(self, *args, **kwargs)

        return wrapper

    return decorate


def license_validate(codes):
    """Require the system license status to be one of ``codes``.

    License info is read from the redis key ``system:license``. On a cache
    miss it is read from the ``license`` table and cached with no expiry.

    NOTE(review): like ``permission``, this relies on ``self.db_app``, whose
    property is commented out of ``BaseHandler`` — confirm a subclass
    provides it.

    Raises:
        errors.HTTPAPIError: ERROR_LICENSE_NOT_ACTIVE when the status returned
            by ``get_license_status`` is not in ``codes``.
    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            license_cache = self.r_app.get("system:license")
            license_info = {}
            if license_cache:
                license_info = json.loads(self.tostr(license_cache))
            else:
                row = self.db_app.get("select syscode, expireat from license limit 1")
                if row:
                    self.r_app.set("system:license",
                                   json.dumps({"syscode": row["syscode"], "expireat": row["expireat"]}))
                    license_info = row
            license_status = get_license_status(license_info)
            if license_status not in codes:
                raise errors.HTTPAPIError(errors.ERROR_LICENSE_NOT_ACTIVE, "系统License未授权")
            return func(self, *args, **kwargs)

        return wrapper

    return decorate


def userlog(method):
    """Log the client IP and current user id before calling ``method``.

    The user id is logged as None when nobody is logged in; previously this
    case raised AttributeError.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        logging.info("[ip]:%s [user]:%s" % (self.request.remote_ip, getattr(self.current_user, "id", None)))
        return method(self, *args, **kwargs)

    return wrapper