# -*- coding: utf-8 -*-
import ast
import asyncio
import functools
import hashlib
import hmac
import json
import logging
import re
import time
import traceback
import urllib
from typing import Any
# import urlparse
from urllib.parse import parse_qs, unquote, urlencode

# from urllib import unquote
import tornado
from sqlalchemy import text
from tornado import escape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.options import options
# from tornado.web import RequestHandler as BaseRequestHandler, HTTPError, asynchronous
from tornado.web import RequestHandler as BaseRequestHandler, HTTPError

from website import errors
from website import settings
from website.service.license import get_license_status

# from torndb import Row

# Choose the process-wide AsyncHTTPClient implementation once, at import time.
if not settings.enable_curl_async_http_client:
    # Tornado's built-in simple_httpclient, capped at settings.max_clients.
    AsyncHTTPClient.configure(None, max_clients=settings.max_clients)
else:
    # libcurl-based client. NOTE(review): settings.max_clients is not passed
    # here, so the curl client keeps its own default limit — confirm intended.
    AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")

# Matches a path ending in "/" with at least one character before it, so the
# bare root "/" is never considered to carry a trailing slash.
REMOVE_SLASH_RE = re.compile(r".+/$")


class Row(dict):
    """Dictionary whose keys may also be read as attributes (``row.key``).

    Missing keys raise ``AttributeError`` (not ``KeyError``) on attribute
    access, so ``hasattr``/``getattr`` with a default behave normally.
    """

    def __getattr__(self, name):
        if name in self:
            return self[name]
        raise AttributeError(name)


def _callback_wrapper(callback):
    """Wrap an HTTP-client *callback* so failed responses become API errors.

    The returned function calls *callback* with the response only when
    ``response.error`` is unset. Otherwise it logs the failure and raises
    ``errors.HTTPAPIError``: carrying the HTTP status code when the error is a
    ``tornado.httpclient.HTTPError``, or ``errors.ERROR_UNKNOWN_ERROR`` for any
    other error. The original error is chained as ``__cause__`` so it shows up
    in formatted tracebacks.
    """

    def _wrapper(response):
        error = response.error
        if not error:
            callback(response)
            return
        logging.error("call remote api err: %s", response)
        if isinstance(error, tornado.httpclient.HTTPError):
            raise errors.HTTPAPIError(error.code, "网络连接失败") from error
        raise errors.HTTPAPIError(errors.ERROR_UNKNOWN_ERROR, "未知错误") from error

    return _wrapper


class BaseHandler(BaseRequestHandler):
    """Common base class for this application's request handlers.

    Adds permissive CORS headers, optional GET-to-POST delegation, session
    lookup from the ``token`` header, merging of JSON request bodies into
    ``request.arguments``, trailing-slash redirects, typed argument getters
    and a helper for forwarding requests to remote HTTP APIs.
    """

    # Shared "no default supplied" marker for get_argument/get_json_argument.
    # Tornado only raises MissingArgumentError when it receives *no* default,
    # so this marker is translated into "omit the default" instead of being
    # passed through (a fresh object() default used to leak out to callers).
    _NO_DEFAULT = object()

    def set_default_headers(self):
        """Attach CORS headers allowing any origin to every response."""
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers",
                        "DNT,token,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,"
                        "Content-Type")
        self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')

    def get(self, *args, **kwargs):
        """Serve GET through ``post`` when ``settings.app_get_to_post`` is enabled.

        Raises:
            HTTPError: 405 when GET delegation is disabled.
        """
        if not settings.app_get_to_post:
            raise HTTPError(405)
        # Hand the result back so Tornado awaits ``post`` when it is a
        # coroutine; discarding it meant an async ``post`` never ran.
        return self.post(*args, **kwargs)

    def render(self, template_name, **kwargs):
        """Render *template_name*, injecting the logged-in user's details.

        ``username``, ``current_userid`` and ``role`` are filled from
        ``self.current_user`` unless the caller already supplied them; the
        response is marked ``Cache-control: no-cache``.
        """
        user = self.current_user
        if user:
            if 'username' not in kwargs:
                kwargs["username"] = user.name
            if 'current_userid' not in kwargs:
                kwargs["current_userid"] = user.id
            if 'role' not in kwargs:
                kwargs["role"] = user.role
        self.set_header("Cache-control", "no-cache")
        return super(BaseHandler, self).render(template_name, **kwargs)

    @property
    def app_mysql(self):
        """Database handle shared through the application object."""
        return self.application.app_mysql

    @property
    def r_app(self):
        """Redis client shared through the application object."""
        return self.application.r_app

    @property
    def kafka_producer(self):
        """Kafka producer shared through the application object."""
        return self.application.kafka_producer

    @property
    def es(self):
        """Elasticsearch client shared through the application object."""
        return self.application.es

    def _call_api(self, url, headers, body, callback, method, callback_wrapper=_callback_wrapper):
        """Send an HTTP request to *url* and pass the response to *callback*.

        *callback* is first wrapped with *callback_wrapper* (default
        ``_callback_wrapper``, which turns failed responses into
        ``errors.HTTPAPIError``). The request is scheduled immediately and an
        ``asyncio`` task is returned; awaiting it waits for the callback and
        propagates any exception the callback raises.
        """
        if callback_wrapper:
            callback = callback_wrapper(callback)
        request = HTTPRequest(url=url,
                              method=method,
                              body=body,
                              headers=headers,
                              allow_nonstandard_methods=True,
                              connect_timeout=settings.remote_connect_timeout,
                              request_timeout=settings.remote_request_timeout,
                              follow_redirects=False)
        # Schedule eagerly so callers that ignore the return value still send
        # the request (as the old fire-and-forget version did).
        return asyncio.ensure_future(self._fetch_and_dispatch(request, callback))

    async def _fetch_and_dispatch(self, request, callback):
        """Fetch *request* and call *callback* with the resulting response.

        Tornado 6's ``fetch`` no longer accepts a callback and raises transport
        failures (timeouts, refused connections, ...) instead of returning
        them. Such failures are converted into an ``HTTPResponse`` carrying the
        error, so *callback* sees every failure through ``response.error``.
        """
        start = time.time()
        try:
            # raise_error=False: non-2xx responses are returned, not raised.
            response = await AsyncHTTPClient().fetch(request, raise_error=False)
        except Exception as exc:
            elapsed = time.time() - start
            logging.error("request to %s failed after %.1f ms: %s", request.url, elapsed * 1000, exc)
            is_http_error = isinstance(exc, tornado.httpclient.HTTPError)
            response = (exc.response if is_http_error else None) or tornado.httpclient.HTTPResponse(
                request, exc.code if is_http_error else 599, error=exc, request_time=elapsed)
        callback(response)

    async def call_api(self, url, headers=None, body=None, callback=None, method="POST"):
        """Forward a request to the remote API at *url* and await its callback.

        Defaults: *callback* -> ``self.call_api_callback``; *headers* -> the
        incoming request's headers; *body* -> the incoming request's body.
        When an explicit *body* is given, ``Content-Type`` is forced to
        ``application/x-www-form-urlencoded``.
        """
        if callback is None:
            callback = self.call_api_callback
        # Work on a copy: neither the caller's mapping nor the incoming
        # request's own headers may be altered, and Tornado's HTTP client may
        # add headers (e.g. Content-Length) to the mapping it receives.
        headers = (self.request.headers if headers is None else headers).copy()
        if body is None:
            body = self.request.body
        else:
            # make sure it is a form post request
            headers["Content-Type"] = "application/x-www-form-urlencoded"
        await self._call_api(url, headers, body, callback, method)

    def get_current_user(self, token_body=None):
        """Resolve the logged-in user from a signed session token.

        The token comes from *token_body* when given, else from the ``token``
        request header. Its signature is verified with
        ``settings.cookie_secret``, and the session is read from redis under
        ``settings.session_key_prefix``. Unless the path contains
        ``/user/info``, the session TTL is refreshed to
        ``settings.session_ttl``.

        Returns:
            A ``Row`` built from the stored session, or None when there is no
            token, the signature is invalid, or the session has expired.
        """
        token = token_body or self.request.headers.get("token")
        if not token:
            return None
        jid = tornado.web.decode_signed_value(
            settings.cookie_secret,
            settings.secure_cookie_name,
            unquote(token)
        )
        if not jid:
            # Missing, tampered or expired signature: do not look up a session.
            return None
        key = settings.session_key_prefix % str(jid, encoding="utf-8")
        user = self.r_app.get(key)
        if not user:
            return None
        if "/user/info" not in self.request.path:
            # /user/info requests deliberately do not extend the session.
            self.r_app.expire(key, settings.session_ttl)
        # Sessions are stored as a Python literal (dict repr); literal_eval
        # only accepts literals, never arbitrary expressions.
        return Row(ast.literal_eval(self.tostr(user)))

    def md5compare(self, name):
        """Validate a ``"<a>|<b>|<digest>"`` token whose digest is MD5(a + b).

        *name* may be URL-quoted and may be ``str`` or UTF-8 ``bytes``.

        Returns:
            True only if the third field equals the upper-case hex MD5 of the
            first two fields concatenated; False for any other input,
            including tokens with fewer than three fields.
        """
        parts = unquote(self.tostr(name)).split("|")
        if len(parts) < 3:
            return False
        expected = hashlib.md5((parts[0] + parts[1]).encode("utf-8")).hexdigest().upper()
        # Constant-time comparison so the digest cannot be probed byte by byte.
        return hmac.compare_digest(expected.encode("ascii"), parts[2].encode("utf-8"))

    def tostr(self, src):
        """Decode UTF-8 ``bytes`` to ``str``; return any other value unchanged."""
        return str(src, encoding='utf8') if isinstance(src, bytes) else src

    def prepare(self):
        """Per-request setup: slash redirect, context hook, JSON-body merge."""
        self.remove_slash()
        if self._finished:
            # remove_slash already answered with a redirect.
            return
        self.prepare_context()
        self.set_default_jsonbody()
        # Rate limiting (traffic_threshold) is currently not enabled here.

    def _json_body(self):
        """Decode the request body as a JSON object.

        Returns an empty dict for an empty body or a JSON value that is not an
        object.

        Raises:
            HTTPError: 400 when the body is not valid (UTF-8) JSON.
        """
        if not self.request.body:
            return {}
        try:
            json_body = escape.json_decode(self.request.body)
        except ValueError:
            raise HTTPError(400, "malformed JSON body")
        return json_body if isinstance(json_body, dict) else {}

    def set_default_jsonbody(self):
        """Merge top-level keys of a JSON request body into ``request.arguments``.

        Only applies when the media type is ``application/json`` (any
        parameters such as ``charset`` are ignored). ``null`` values are
        skipped, lists extend the existing argument list, dicts replace it,
        and scalars are appended as UTF-8 encoded ``str(value)``.
        """
        content_type = self.request.headers.get('Content-Type', '')
        media_type = content_type.split(';', 1)[0].strip().lower()
        if media_type != 'application/json' or not self.request.body:
            return
        for key, value in self._json_body().items():
            if value is None:
                continue
            if isinstance(value, list):
                self.request.arguments.setdefault(key, []).extend(value)
            elif isinstance(value, dict):
                self.request.arguments[key] = value
            else:
                self.request.arguments.setdefault(key, []).append(bytes(str(value), 'utf-8'))

    def traffic_threshold(self):
        """Rate-limit requests per user (client IP when anonymous).

        Requests are counted in 10-second windows; more than
        ``settings.api_count_in_ten_second`` in one window raises
        ``errors.HTTPAPIError``, and the count is copied into the next window
        so the client stays blocked for that window too.
        """
        user_id = self.current_user.id if self.current_user else self.request.remote_ip
        # Integer division: "/" produced a float key that changed every
        # second, shrinking the window to one second.
        window = int(time.time()) // 10
        freq_key = "API:FREQ:%s:%s" % (user_id, window)
        send_count = self.r_app.incr(freq_key)
        if send_count == 1:
            self.r_app.expire(freq_key, 10)
        if send_count > settings.api_count_in_ten_second:
            next_key = "API:FREQ:%s:%s" % (user_id, window + 1)
            # Keyword arguments: the positional order of setex differs between
            # redis-py versions (the old call stored 10 with a TTL of send_count).
            self.r_app.setex(name=next_key, time=10, value=send_count)
            raise errors.HTTPAPIError(
                errors.ERROR_METHOD_NOT_ALLOWED, "请勿频繁操作")

    def prepare_context(self):
        """Hook for per-request context setup; a no-op in this base class."""
        pass

    def remove_slash(self):
        """Redirect a GET whose path ends in "/" to the same path without it.

        The query string is preserved; the bare root "/" is left alone.
        """
        if self.request.method != "GET" or not REMOVE_SLASH_RE.match(self.request.path):
            return
        # Also collapse leading "/" and "\": "//evil.com/" would otherwise
        # become the protocol-relative URL "//evil.com" (an open redirect),
        # and browsers treat "/\" like "//".
        uri = "/" + self.request.path.rstrip("/").lstrip("/\\")
        if self.request.query:
            uri += "?" + self.request.query
        self.redirect(uri)

    def get_json_argument(self, name, default=_NO_DEFAULT):
        """Return key *name* from the JSON request body.

        Text values are returned UTF-8 encoded as ``bytes`` (like
        ``get_argument``).

        Raises:
            tornado.web.MissingArgumentError: key absent and no *default*.
            HTTPError: 400 when the body is not valid JSON.
        """
        value = self._json_body().get(name, default)
        if value is self._NO_DEFAULT:
            raise tornado.web.MissingArgumentError(name)
        return escape.utf8(value) if isinstance(value, str) else value

    def get_int_json_argument(self, name, default=0):
        """Return JSON body key *name* as ``int``, or *default* if not convertible."""
        try:
            return int(self.get_json_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_escaped_json_argument(self, name, default=None):
        """Return JSON body key *name*, XHTML-escaped when a *default* is given."""
        if default is not None:
            return self.escape_string(self.get_json_argument(name, default))
        return self.get_json_argument(name, default)

    def get_argument(self, name, default=_NO_DEFAULT, strip=True):
        """Return request argument *name*, UTF-8 encoded to ``bytes`` if text.

        Raises:
            tornado.web.MissingArgumentError: argument absent and no *default*.
        """
        if default is self._NO_DEFAULT:
            value = super(BaseHandler, self).get_argument(name, strip=strip)
        else:
            value = super(BaseHandler, self).get_argument(name, default, strip)
        return escape.utf8(value) if isinstance(value, str) else value

    def get_int_argument(self, name: Any, default: int = 0) -> int:
        """Return argument *name* as ``int``, or *default* if not convertible."""
        try:
            return int(self.get_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_float_argument(self, name, default=0.0):
        """Return argument *name* as ``float``, or *default* if not convertible."""
        try:
            return float(self.get_argument(name, default))
        except (TypeError, ValueError):
            return default

    def get_uint_arg(self, name, default=0):
        """Return the absolute ``int`` value of argument *name*, or *default*."""
        try:
            return abs(int(self.get_argument(name, default)))
        except (TypeError, ValueError):
            return default

    def unescape_string(self, s):
        """Reverse XHTML escaping of *s*."""
        return escape.xhtml_unescape(s)

    def escape_string(self, s):
        """XHTML-escape *s*."""
        return escape.xhtml_escape(s)

    def get_escaped_argument(self, name, default=None):
        """Return argument *name*, XHTML-escaped when a *default* is given."""
        if default is not None:
            return self.escape_string(self.get_argument(name, default))
        return self.get_argument(name, default)

    def get_page_url(self, page, form_id=None, tab=None):
        """Build a link to pagination page *page*.

        With *form_id*, returns a ``javascript:goto_page(...)`` URL. Otherwise
        returns the current path with its query string, keeping the first
        value of each parameter and overriding ``page`` (and ``tab`` if given).
        """
        if form_id:
            return "javascript:goto_page('%s',%s);" % (form_id.strip(), page)
        qdict = {k: (v[0] if v else '') for k, v in parse_qs(self.request.query).items()}
        qdict['page'] = page
        if tab:
            qdict['tab'] = tab
        # urllib.urlencode does not exist in Python 3; urlencode is in urllib.parse.
        return self.request.path + '?' + urlencode(qdict)

    def find_all(self, target, substring):
        """Yield the start index of every non-overlapping *substring* in *target*.

        An empty *substring* yields nothing; without this guard ``find`` would
        keep returning the same index forever.
        """
        if not substring:
            return
        step = len(substring)
        pos = target.find(substring)
        while pos != -1:
            yield pos
            pos = target.find(substring, pos + step)


class WebHandler(BaseHandler):
    """Handler for web endpoints.

    Supports JSONP through a ``callback`` argument and reports errors as a
    JSON ``errors.HTTPAPIError`` body with HTTP 200, except that a
    ``tornado.web.HTTPError`` 401 redirects to ``/``.
    """

    # Only dotted JavaScript identifiers are accepted as JSONP callback names,
    # because the value is reflected verbatim into a script response.
    _JSONP_CALLBACK_RE = re.compile(rb"[A-Za-z_$][\w$]*(?:\.[A-Za-z_$][\w$]*)*")

    def _jsonp_callback(self):
        """Return the validated ``callback`` argument as bytes, or None."""
        callback = escape.utf8(self.get_argument("callback", None))
        if callback and self._JSONP_CALLBACK_RE.fullmatch(callback):
            return callback
        return None

    def finish(self, chunk=None, message=None):
        """Finish the response, wrapping it as JSONP when a valid callback is given.

        A dict *chunk* is JSON-encoded for JSONP; *message* is ignored here.
        Returns Tornado's ``finish`` future.
        """
        callback = self._jsonp_callback()
        if not callback:
            self.set_header("Cache-control", "no-cache")
            return super(WebHandler, self).finish(chunk)

        self.set_header("Content-Type", "application/x-javascript")
        if isinstance(chunk, dict):
            chunk = escape.json_encode(chunk)
        # Tornado joins _write_buffer as bytes, so every piece must be bytes;
        # mixing in str raised TypeError.
        self._write_buffer = [callback, b"(", escape.utf8(chunk), b")"] if chunk else []
        return super(WebHandler, self).finish()

    def write_error(self, status_code, **kwargs):
        """Report an error as a JSON ``HTTPAPIError`` body with HTTP status 200.

        An ``HTTPError`` 401 instead redirects (temporarily) to ``/``. Server
        errors trigger ``send_error_mail`` outside debug mode; in debug mode the
        traceback is included in the error data. If building the JSON error
        fails, Tornado's default error page is written instead.
        """
        try:
            # Read rather than pop, so the fallback below still gets exc_info.
            exc_info = kwargs["exc_info"]
            e = exc_info[1]

            if not isinstance(e, errors.HTTPAPIError):
                if isinstance(e, HTTPError):
                    if e.status_code == 401:
                        # Temporary redirect: a cached permanent redirect would
                        # keep sending the browser to "/" even after login.
                        self.redirect("/")
                        return
                    e = errors.HTTPAPIError(e.status_code)
                else:
                    e = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)

            exception = "".join(traceback.format_exception(*exc_info))

            if status_code == errors.ERROR_INTERNAL_SERVER_ERROR and not options.debug:
                self.send_error_mail(exception)

            if options.debug:
                e.data["exception"] = exception

            self.clear()
            # always return 200 OK for Web errors; the error lives in the body
            self.set_status(errors.HTTP_OK)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.finish(str(e))
        except Exception:
            logging.error(traceback.format_exc())
            return super(WebHandler, self).write_error(status_code, **kwargs)

    def send_error_mail(self, exception):
        """Hook to notify about a server error; a no-op by default."""
        pass


class APIHandler(BaseHandler):
    """Handler for JSON API endpoints.

    Successful responses use the envelope
    ``{"meta": {"code": errors.HTTP_OK}, "data": ...}``; errors are reported
    as a JSON ``errors.HTTPAPIError`` body with HTTP 200. A valid ``callback``
    argument turns the response into JSONP.
    """

    # Only dotted JavaScript identifiers are accepted as JSONP callback names,
    # because the value is reflected verbatim into a script response.
    _JSONP_CALLBACK_RE = re.compile(rb"[A-Za-z_$][\w$]*(?:\.[A-Za-z_$][\w$]*)*")

    def _jsonp_callback(self):
        """Return the validated ``callback`` argument as bytes, or None."""
        callback = escape.utf8(self.get_argument("callback", None))
        if callback and self._JSONP_CALLBACK_RE.fullmatch(callback):
            return callback
        return None

    def finish(self, chunk=None, message=None):
        """Finish the response inside the standard envelope.

        ``None``, dict and list payloads become ``data`` (with ``message`` added
        when given); other payloads, such as the serialized error passed by
        ``write_error``, are sent unchanged. Returns Tornado's ``finish`` future.
        """
        if chunk is None:
            chunk = {}

        # Lists are wrapped too: Tornado refuses to write a bare list.
        if isinstance(chunk, (dict, list)):
            chunk = {"meta": {"code": errors.HTTP_OK}, "data": chunk}
            if message:
                chunk["message"] = message

        callback = self._jsonp_callback()
        if callback:
            self.set_header("Content-Type", "application/x-javascript")
            if isinstance(chunk, dict):
                chunk = escape.json_encode(chunk)
            # Tornado joins _write_buffer as bytes, so every piece must be
            # bytes; mixing in str raised TypeError.
            self._write_buffer = [callback, b"(", escape.utf8(chunk), b")"] if chunk else []
            return super(APIHandler, self).finish()

        self.set_header("Content-Type", "application/json; charset=UTF-8")
        return super(APIHandler, self).finish(chunk)

    def write_error(self, status_code, **kwargs):
        """Report an error as a JSON ``HTTPAPIError`` body with HTTP status 200.

        Server errors trigger ``send_error_mail`` outside debug mode; in debug
        mode the traceback is included in the error data. If building the JSON
        error fails, Tornado's default error page is written instead.
        """
        try:
            # Read rather than pop, so the fallback below still gets exc_info.
            exc_info = kwargs["exc_info"]
            e = exc_info[1]

            if not isinstance(e, errors.HTTPAPIError):
                if isinstance(e, HTTPError):
                    e = errors.HTTPAPIError(e.status_code)
                else:
                    e = errors.HTTPAPIError(errors.ERROR_INTERNAL_SERVER_ERROR)

            exception = "".join(traceback.format_exception(*exc_info))

            if status_code == errors.ERROR_INTERNAL_SERVER_ERROR and not options.debug:
                self.send_error_mail(exception)

            if options.debug:
                e.data["exception"] = exception

            self.clear()
            # always return 200 OK for API errors; the error lives in the body
            self.set_status(errors.HTTP_OK)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.finish(str(e))
        except Exception:
            logging.error(traceback.format_exc())
            return super(APIHandler, self).write_error(status_code, **kwargs)

    def send_error_mail(self, exception):
        """Override to implement custom error mail"""
        pass


class ErrorHandler(BaseHandler):
    """Catch-all handler that answers every request with 404 Not Found."""

    def prepare(self):
        # Run the shared per-request setup, then reject unconditionally.
        super().prepare()
        raise HTTPError(errors.ERROR_NOT_FOUND)


class APIErrorHandler(APIHandler):
    """Catch-all API handler that reports 404 Not Found as an API error."""

    def prepare(self):
        # Shared setup first; every request then fails with ERROR_NOT_FOUND.
        super().prepare()
        raise errors.HTTPAPIError(errors.ERROR_NOT_FOUND)


def authenticated(method):
    """Require a logged-in user before running the decorated handler method.

    When ``self.current_user`` is falsy, ``errors.HTTPAPIError`` with
    ``errors.ERROR_UNAUTHORIZED`` is raised instead of calling *method*.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.current_user:
            return method(self, *args, **kwargs)
        raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")

    return wrapper


def authenticated_admin(method):
    """Require the current user's role to be ``1001`` before running *method*.

    Raises ``errors.HTTPAPIError``:
      * ``ERROR_UNAUTHORIZED`` when no user is logged in (this used to be an
        AttributeError, i.e. a 500);
      * ``ERROR_FORBIDDEN`` when the role is not ``1001`` or is missing or
        non-numeric.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        user = self.current_user
        if not user:
            raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
        try:
            is_admin = int(user.role) == 1001
        except (AttributeError, TypeError, ValueError):
            # Missing or malformed role: deny instead of crashing with a 500.
            is_admin = False
        if not is_admin:
            raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")
        return method(self, *args, **kwargs)

    return wrapper


def operation_log(primary_menu, sub_menu, ope_type, content, comment=""):
    """Insert an audit row into ``sys_log`` before running the decorated method.

    The row records the current user's name, the client IP (the
    ``X-Forwarded-For`` header when present, else ``request.remote_ip``) and
    the given *primary_menu*, *sub_menu*, *ope_type*, *content* and *comment*,
    written through ``self.app_mysql``. Auditing is best effort: any failure
    is logged with its traceback and the method still runs. The row is written
    before the method executes, so it exists even if the method then fails.
    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            try:
                headers = self.request.headers
                ip = headers["X-Forwarded-For"] if "X-Forwarded-For" in headers else self.request.remote_ip
                with self.app_mysql.connect() as conn:
                    conn.execute(
                        text(
                            "insert into sys_log(user, ip, primary_menu, sub_menu, op_type, content, comment) "
                            "values(:user, :ip, :primary_menu, :sub_menu, :op_type, :content, :comment)"
                        ),
                        {
                            "user": self.current_user.name,
                            "ip": ip,
                            "primary_menu": primary_menu,
                            "sub_menu": sub_menu,
                            "op_type": ope_type,
                            "content": content,
                            "comment": comment,
                        },
                    )
                    conn.commit()
            except Exception as e:
                # Never block the operation, but make a lost audit row visible.
                logging.exception("operation log error: %s", e)

            return func(self, *args, **kwargs)

        return wrapper

    return decorate


def permission(codes):
    """Require the current user's roles to grant every permission in *codes*.

    Granted permissions are read from ``role_permission`` joined with
    ``user_role`` through ``self.app_mysql``. BaseHandler exposes no
    ``db_app`` attribute, so the former ``self.db_app`` lookup failed with
    AttributeError on every call.

    Raises ``errors.HTTPAPIError``: ``ERROR_UNAUTHORIZED`` when no user is
    logged in, ``ERROR_FORBIDDEN`` when any code is missing.
    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            user = self.current_user
            if not user:
                raise errors.HTTPAPIError(errors.ERROR_UNAUTHORIZED, "登录失效")
            if codes:
                with self.app_mysql.connect() as conn:
                    result = conn.execute(
                        text(
                            "select rp.permission from role_permission rp, user_role ur "
                            "where rp.role=ur.role and ur.userid=:userid"
                        ),
                        {"userid": user.id},
                    )
                    granted = set(result.scalars())
                for code in codes:
                    if code not in granted:
                        raise errors.HTTPAPIError(errors.ERROR_FORBIDDEN, "permission denied")

            return func(self, *args, **kwargs)

        return wrapper

    return decorate


def license_validate(codes):
    """Require the system license status to be one of *codes*.

    The license record (``syscode``, ``expireat``) is read from the redis key
    ``system:license``. On a cache miss it is loaded from the ``license``
    table through ``self.app_mysql`` (BaseHandler exposes no ``db_app``, so the
    former lookup raised AttributeError on every miss) and cached without an
    expiry. ``get_license_status`` maps the record to a status. If that status
    is not in *codes*, ``errors.HTTPAPIError`` with
    ``ERROR_LICENSE_NOT_ACTIVE`` is raised.
    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            license_info = {}
            license_cache = self.r_app.get("system:license")
            if license_cache:
                license_info = json.loads(self.tostr(license_cache))
            else:
                with self.app_mysql.connect() as conn:
                    row = conn.execute(
                        text("select syscode, expireat from license limit 1")
                    ).mappings().first()
                if row:
                    license_info = {"syscode": row["syscode"], "expireat": row["expireat"]}
                    self.r_app.set("system:license", json.dumps(license_info))

            if get_license_status(license_info) not in codes:
                raise errors.HTTPAPIError(errors.ERROR_LICENSE_NOT_ACTIVE, "系统License未授权")

            return func(self, *args, **kwargs)

        return wrapper

    return decorate


def userlog(method):
    """Log the client IP and the current user's id, then run *method*.

    Anonymous requests are logged with user ``None``. Previously they raised
    AttributeError, which turned the request into a 500.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        user = self.current_user
        logging.info("[ip]:%s [user]:%s", self.request.remote_ip, user.id if user else None)
        return method(self, *args, **kwargs)

    return wrapper