# -*- coding: utf-8 -*-
from binascii import a2b_hex, b2a_hex

from M2Crypto import BIO
from M2Crypto import RSA


def load_pub_key_string(string):
    """Build an M2Crypto RSA public-key object from in-memory PEM data.

    The data is wrapped in a memory BIO and handed to
    ``RSA.load_pub_key_bio``; whatever that call returns (or raises) is
    passed straight through to the caller.
    """
    return RSA.load_pub_key_bio(BIO.MemoryBuffer(string))


def block_data(texts, block_size):
    """Lazily split a sliceable sequence into consecutive chunks.

    Yields ``texts[0:block_size]``, ``texts[block_size:2*block_size]``, ...
    The final chunk is shorter when ``len(texts)`` is not a multiple of
    ``block_size``; empty input yields nothing.
    """
    offsets = range(0, len(texts), block_size)
    for offset in offsets:
        end = offset + block_size
        yield texts[offset:end]


def encrypt(texts, key=None):
    """RSA "private-encrypt" *texts* block by block and hex-encode the result.

    NOTE: this uses the RSA private-key operation (the raw signing
    primitive). The output is NOT confidential: anyone holding the matching
    public key can recover the plaintext with ``decrypt``. It only proves
    the data came from the private-key holder.

    :param texts: data to process; ``str`` is UTF-8 encoded first, ``bytes``
        is used as-is.
    :param key: M2Crypto RSA private key. When omitted, the module-level
        global ``pri_key`` is used (as the original implementation did).
    :returns: hex-encoded ciphertext, as produced by ``binascii.b2a_hex``.
    """
    if key is None:
        key = pri_key
    data = texts if isinstance(texts, bytes) else texts.encode('utf-8')
    # len(key) is the modulus size in bits; PKCS#1 v1.5 padding consumes
    # 11 bytes of every modulus-sized block (245 bytes for a 2048-bit key).
    block_size = len(key) // 8 - 11
    chunks = [key.private_encrypt(block, RSA.pkcs1_padding)
              for block in block_data(data, block_size)]
    # join avoids the quadratic cost of repeated bytes concatenation
    return b2a_hex(b"".join(chunks))


def decrypt(texts, key=None):
    """Recover data produced by ``encrypt`` using the matching public key.

    :param texts: hex-encoded ciphertext (``str`` or ``bytes``) as returned
        by ``encrypt``; it is decoded with ``binascii.a2b_hex``.
    :param key: M2Crypto RSA public key. When omitted, the module-level
        global ``pub_key`` is used (as the original implementation did).
    :returns: the recovered plaintext as raw ``bytes``. It is not decoded
        back to ``str``.
    Errors from ``a2b_hex`` (bad hex) or ``public_decrypt`` (bad block)
    propagate unchanged.
    """
    if key is None:
        key = pub_key
    # Every ciphertext block is exactly one modulus long: len(key) bits.
    block_size = len(key) // 8
    pieces = [key.public_decrypt(block, RSA.pkcs1_padding)
              for block in block_data(a2b_hex(texts), block_size)]
    # join avoids the quadratic cost of repeated bytes concatenation
    return b"".join(pieces)


if __name__ == '__main__':
    # Generate a 2048-bit RSA key pair; 65537 is the public exponent.
    key = RSA.gen_key(2048, 65537)
    # cipher=None saves the private key PEM unencrypted (demo only).
    key.save_key("private_key", None)
    key.save_pub_key("public_key")

    # Read the PEM files back; `with` guarantees the handles are closed.
    with open("private_key") as key_file:
        prikey = key_file.read()
    with open("public_key") as key_file:
        pubkey = key_file.read()

    # encrypt() and decrypt() look these names up as module globals.
    pri_key = RSA.load_key_string(prikey.strip('\n').encode('utf-8'))
    pub_key = load_pub_key_string(pubkey.strip('\n').encode('utf-8'))

    # Long enough to span more than one RSA block.
    texts = "hellohellohellohellohellohellohellohellohellohellohellohellohello" \
            "hellohellohellohellohellohellohellohellohellohellohellohellohello" \
            "hellohellohellohellohellohellohellohellohellohellohellohellohello" \
            "hellohellohellohellohellohellohellohellohellohellohello"

    ciphertext = encrypt(texts)
    print(ciphertext)
    plaintext = decrypt(ciphertext)
    print(plaintext)