# -*- coding: utf-8 -*-
import argparse
import base64
import json
import os
import sys

from Crypto.Cipher import PKCS1_OAEP
from Crypto.PublicKey import RSA

"""
RSA 加密中,有两种常见的填充方式:PKCS1_v1_5 和 PKCS1_OAEP。这两种填充方式在安全性和性能方面都有一些差异。

PKCS1_v1_5 填充方式:
    这是较早的 RSA 填充方式,相对简单且性能较好。
    但是它存在安全隐患:解密端可能遭受选择密文攻击(Chosen Ciphertext Attack, CCA),典型例子是 Bleichenbacher 填充预言攻击(padding oracle attack)。
PKCS1_OAEP 填充方式:
    PKCS1_OAEP 是一种更加安全的填充方式,它用基于哈希的掩码生成函数(MGF1)把随机种子与消息混合编码。(PKCS1_v1_5 同样含有随机填充字节,两者的区别不在于是否随机,而在于 OAEP 的结构可证明抵御选择密文攻击。)
    PKCS1_OAEP 可以抵御选择密文攻击(CCA)和其他一些攻击方式,因此被认为更加安全。
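    但是,PKCS1_OAEP 需要额外的哈希运算(与 RSA 模幂运算相比开销很小),并且单次可加密的明文更短:PKCS1_v1_5 最多 k-11 字节,PKCS1_OAEP 最多 k-2·hLen-2 字节(k 为模数字节数,hLen 为哈希输出字节数;2048 位密钥配合 SHA-1 时为 214 字节)。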
"""


# Generate an RSA key pair.
def generate_keys(bits=2048):
    """Generate a fresh RSA key pair.

    Args:
        bits: size of the RSA modulus in bits. Defaults to 2048, the value this
            function always used before it was parameterized. PyCryptodome
            rejects sizes below 1024 with ``ValueError``.

    Returns:
        A ``(private_key, public_key)`` tuple. Both are ``bytes`` as returned by
        ``export_key()`` (PEM by default), suitable for writing to ``.pem`` files
        or passing back to ``RSA.import_key``.
    """
    key = RSA.generate(bits)
    # The public half is derived from the private key object, so holding the
    # private key is enough to reconstruct the public key as well.
    return key.export_key(), key.publickey().export_key()


# Encrypt a message with the public key (RSA with PKCS#1 OAEP padding).
def encrypt_message_pub(public_key, message):
    """Encrypt *message* with an RSA public key using PKCS#1 OAEP.

    Args:
        public_key: the public key in any form ``RSA.import_key`` accepts,
            e.g. the PEM ``bytes`` produced by ``generate_keys``.
        message: the plaintext as ``bytes``. A ``str`` is also accepted and is
            encoded as UTF-8 first, mirroring ``decrypt_message_pri`` which
            returns UTF-8-decoded text.

    Returns:
        The ciphertext, base64-encoded, as ``bytes``.

    Raises:
        ValueError: if the key cannot be imported, or if the message does not fit
            in a single OAEP block. With PyCryptodome's default SHA-1 hash the
            limit is ``modulus_bytes - 42`` bytes, i.e. 214 bytes for a 2048-bit
            key.
    """
    if isinstance(message, str):
        message = message.encode("utf-8")
    cipher = PKCS1_OAEP.new(RSA.import_key(public_key))
    # Base64 keeps the binary ciphertext safe to store in text files.
    return base64.b64encode(cipher.encrypt(message))


# Decrypt a message with the private key (RSA with PKCS#1 OAEP padding).
def decrypt_message_pri(private_key, encrypted_message):
    """Reverse ``encrypt_message_pub``: base64-decode, then RSA-OAEP decrypt.

    Args:
        private_key: the private key in any form ``RSA.import_key`` accepts,
            e.g. the PEM ``bytes`` produced by ``generate_keys``.
        encrypted_message: base64-encoded ciphertext.

    Returns:
        The recovered plaintext, decoded as UTF-8 text (``str``).
    """
    rsa_key = RSA.import_key(private_key)
    oaep = PKCS1_OAEP.new(rsa_key)
    raw_ciphertext = base64.b64decode(encrypted_message)
    plaintext = oaep.decrypt(raw_ciphertext)
    return plaintext.decode("utf-8")


def test():
    """Demonstrate an RSA-OAEP round trip with a freshly generated key pair.

    Prints both PEM keys, the base64 ciphertext and the decrypted text, then
    verifies that decryption recovered the original message.

    Raises:
        RuntimeError: if the decrypted text differs from the original message.
    """
    # Generate a throwaway key pair for this run only.
    private_key, public_key = generate_keys()

    # export_key() returns bytes; decode so the PEM text prints readably.
    print("Private Key:")
    print(private_key.decode())
    print("Public Key:")
    print(public_key.decode())

    # Message to encrypt (OAEP operates on bytes).
    plain_text = "Hello, RSA!"
    message = plain_text.encode()

    encrypted_message = encrypt_message_pub(public_key, message)
    print("Encrypted Message:")
    print(encrypted_message)

    decrypted_message = decrypt_message_pri(private_key, encrypted_message)
    print("Decrypted Message:")
    print(decrypted_message)

    # Check the round trip explicitly instead of relying on reading the output.
    if decrypted_message != plain_text:
        raise RuntimeError(
            "RSA round trip failed: %r != %r" % (decrypted_message, plain_text)
        )


# File names used by the command-line tool (relative to the working directory).
_PRIVATE_KEY_FILE = "private_key.pem"
_PUBLIC_KEY_FILE = "public_key.pem"
_LICENSE_FILE = "license"


def _read_key(path):
    """Return the key stored at *path* as UTF-8 bytes with surrounding whitespace removed."""
    # Stripping keeps hand-edited key files with stray leading/trailing blank
    # lines loadable, as the original newline strip did.
    with open(path, "r") as f:
        return f.read().strip().encode("utf-8")


def _write_private_key(path, data):
    """Write private key *data* to *path*, readable and writable by the owner only."""
    # Create the file with mode 0600 so the key is never world-readable, even
    # briefly. O_BINARY (Windows only) prevents newline translation.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(path, flags, 0o600)
    with os.fdopen(fd, "wb") as f:
        f.write(data)
    # os.open's mode only applies to newly created files; tighten an existing one too.
    os.chmod(path, 0o600)


def _generate_key_files():
    """Generate a new key pair and write it to the private/public key files (overwriting them)."""
    private_key, public_key = generate_keys()
    _write_private_key(_PRIVATE_KEY_FILE, private_key)
    with open(_PUBLIC_KEY_FILE, "wb") as f:
        f.write(public_key)


def _sign_license(code, expire):
    """Encrypt the {"sys_code", "expire_at"} payload with the public key into the license file."""
    payload = json.dumps({"sys_code": code, "expire_at": expire}).encode("utf-8")
    license_blob = encrypt_message_pub(_read_key(_PUBLIC_KEY_FILE), payload)
    with open(_LICENSE_FILE, "wb") as f:
        f.write(license_blob)


def _decrypt_license():
    """Decrypt the license file with the private key and print the JSON payload."""
    private_key = _read_key(_PRIVATE_KEY_FILE)
    with open(_LICENSE_FILE, "rb") as f:
        license_blob = f.read()
    body = decrypt_message_pri(private_key, license_blob)
    print(json.loads(body))


# Main program.
def main(argv=None):
    """Run the license tool's command line and return a process exit status.

    Operations (selected with ``--o``):
      t  run the in-process round-trip demo (``test``)
      g  generate a key pair into private_key.pem / public_key.pem
      s  encrypt {"sys_code", "expire_at"} with public_key.pem into ``license``
      d  decrypt ``license`` with private_key.pem and print the JSON payload

    NOTE(review): licenses are *encrypted* with the public key and *decrypted*
    with the private key. Anyone able to decrypt a license therefore holds the
    private key, from which the public key can be derived (``generate_keys`` does
    exactly that), so that party can also mint new licenses. If licenses must be
    unforgeable, the usual design is a signature: sign with the private key and
    verify with the public key. TODO confirm the intended threat model.
    """
    parser = argparse.ArgumentParser(description='manual to sign enterprise license')
    # Only these four operations exist. `choices` makes argparse reject anything
    # else; previously an unknown value silently did nothing and exited 0.
    parser.add_argument("--o", type=str, help="operation type", required=True,
                        choices=("t", "g", "s", "d"))  # t/test, s/sign, g/generate key, d/decrypt
    parser.add_argument("--c", type=str, help="enterprise's sys code")  # code
    parser.add_argument("--e", type=str, help="expire date")  # expire
    args = parser.parse_args(argv)

    try:
        if args.o == "t":
            test()
        elif args.o == "g":
            _generate_key_files()
        elif args.o == "s":
            if not args.c or not args.e:
                print("sys code and expire date are required", file=sys.stderr)
                return 1
            _sign_license(args.c, args.e)
        else:  # "d"
            _decrypt_license()
    except (OSError, ValueError) as exc:
        # OSError covers missing or unreadable key/license files. ValueError
        # covers bad keys, corrupt or foreign licenses (base64 / OAEP / UTF-8 /
        # JSON errors) and payloads too long for a single RSA-OAEP block.
        print("error:", exc, file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())