"""Byte-at-a-time ECB decryption against an oracle that encrypts
prefix || attacker_input || target with AES-128-ECB under a fixed key."""

import base64
from math import ceil
from typing import List, Tuple

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

# NOTE(review): wildcard import; nothing visible in this file uses it —
# confirm what it provides and replace with explicit names.
from utils import *

backend = default_backend()


def split_bytes_in_blocks(x: bytes, block_size: int) -> List[bytes]:
    """Splits a byte string into a list of consecutive blocks.

    Args:
        x (bytes): The byte string to split.
        block_size (int): The size of each block in bytes. Must be positive.

    Returns:
        List[bytes]: A list of byte strings, each of length block_size,
        except for the last one which may be shorter. An empty input yields
        an empty list.

    Raises:
        ValueError: If block_size is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    # Integer stepping instead of ceil(len(x) / block_size): no float division.
    return [x[start : start + block_size] for start in range(0, len(x), block_size)]


def pkcs7_padding(message: bytes, block_size: int) -> bytes:
    """Applies PKCS#7 padding to a byte string.

    Args:
        message (bytes): The byte string to pad.
        block_size (int): The size of the block in bytes (1..255).

    Returns:
        bytes: A byte string that is a multiple of block_size in length, with
        padding bytes added at the end. The value of each padding byte is
        equal to the number of padding bytes added. A message that is already
        aligned receives a full block of padding.

    Raises:
        ValueError: If block_size is outside 1..255 (PKCS#7 stores the
            padding length in a single byte).
    """
    if not 0 < block_size < 256:
        raise ValueError(f"block_size must be in 1..255, got {block_size}")
    # Always in 1..block_size, so no special case for aligned messages.
    padding_length = block_size - len(message) % block_size
    return message + bytes([padding_length]) * padding_length


def pkcs7_strip(data: bytes) -> bytes:
    """Validates and removes PKCS#7 padding from a byte string.

    Args:
        data (bytes): The padded byte string.

    Returns:
        bytes: A byte string with the padding bytes removed from the end.

    Raises:
        ValueError: If data is empty or does not end with valid PKCS#7
            padding (padding length of 0, longer than data, or padding bytes
            that do not all equal the padding length).
    """
    if not data:
        raise ValueError("cannot strip PKCS#7 padding from empty data")
    padding_length = data[-1]
    # A 0 length would make data[:-0] return b"" and silently drop everything.
    if (
        padding_length == 0
        or padding_length > len(data)
        or data[-padding_length:] != bytes([padding_length]) * padding_length
    ):
        raise ValueError("invalid PKCS#7 padding")
    return data[:-padding_length]


def encrypt_aes_128_ecb(plaintext: bytes, key: bytes) -> bytes:
    """Encrypts a byte string using AES-128 in ECB mode.

    Args:
        plaintext (bytes): The byte string to encrypt. It will be padded
            using PKCS#7.
        key (bytes): The encryption key. It must be 16 bytes in length.

    Returns:
        bytes: A byte string that is the encrypted version of plaintext.
    """
    padded_msg = pkcs7_padding(plaintext, block_size=16)
    cipher = Cipher(algorithms.AES(key), modes.ECB(), backend=backend)
    encryptor = cipher.encryptor()
    return encryptor.update(padded_msg) + encryptor.finalize()


def decrypt_aes_128_ecb(ciphertext: bytes, key: bytes) -> bytes:
    """Decrypts a byte string using AES-128 in ECB mode.

    Args:
        ciphertext (bytes): The byte string to decrypt. It must be a multiple
            of 16 bytes in length.
        key (bytes): The decryption key. It must be 16 bytes in length.

    Returns:
        bytes: A byte string that is the decrypted version of ciphertext.
        The PKCS#7 padding will be removed.

    Raises:
        ValueError: If the decrypted data does not end with valid PKCS#7
            padding (raised by pkcs7_strip).
    """
    cipher = Cipher(algorithms.AES(key), modes.ECB(), backend=backend)
    decryptor = cipher.decryptor()
    decrypted_data = decryptor.update(ciphertext) + decryptor.finalize()
    return pkcs7_strip(decrypted_data)


class Oracle:
    """A class that simulates an encryption oracle using AES-128 in ECB mode.

    You are not supposed to see this: the attack below only calls encrypt().
    """

    def __init__(self) -> None:
        self.key = "Mambo NumberFive".encode()
        self.prefix = "PREF".encode()
        # You are supposed to break this
        self.target = base64.b64decode(
            "RG8gbm90IGxheSB1cCBmb3IgeW91cnNlbHZlcyB0cmVhc3VyZXMgb24gZWFydGgsI"
            "HdoZXJlIG1vdGggYW5kIHJ1c3QgZGVzdHJveSBhbmQgd2hlcmUgdGhpZXZlcyBicm"
            "VhayBpbiBhbmQgc3RlYWwsCmJ1dCBsYXkgdXAgZm9yIHlvdXJzZWx2ZXMgdHJlYXN"
            "1cmVzIGluIGhlYXZlbiwgd2hlcmUgbmVpdGhlciBtb3RoIG5vciBydXN0IGRlc3Ry"
            "b3lzIGFuZCB3aGVyZSB0aGlldmVzIGRvIG5vdCBicmVhayBpbiBhbmQgc3RlYWwuC"
            "kZvciB3aGVyZSB5b3VyIHRyZWFzdXJlIGlzLCB0aGVyZSB5b3VyIGhlYXJ0IHdpbG"
            "wgYmUgYWxzby4="
        )

    def encrypt(self, message: bytes) -> bytes:
        """Returns AES-128-ECB(prefix || message || target) under the fixed key."""
        return encrypt_aes_128_ecb(
            self.prefix + message + self.target,
            self.key,
        )


# Task 1
def find_block_size() -> Tuple[int, int, int]:
    """Discovers the block size by feeding identical bytes one at a time.

    The input b"X" * i is grown until the ciphertext length changes. Because
    PKCS#7 always pads, the length jumps by exactly one block at the moment
    prefix + input + target becomes block-aligned.

    Returns:
        Tuple[int, int, int]: (block_size, size_of_prefix_target_padding,
        minimum_size_to_align_plaintext), where
        size_of_prefix_target_padding is the ciphertext length for an empty
        input (prefix + target + padding), and
        minimum_size_to_align_plaintext is the input length at which the
        ciphertext first grows — equal to the PKCS#7 padding length of the
        empty-input plaintext. Hence
        target_size = size_of_prefix_target_padding
                      - minimum_size_to_align_plaintext - prefix_size.
    """
    oracle = Oracle()
    initial_length = len(oracle.encrypt(b""))
    input_length = 0
    while True:
        input_length += 1
        length = len(oracle.encrypt(b"X" * input_length))
        if length != initial_length:
            # A whole padding block was just appended, so the jump is one block.
            return length - initial_length, initial_length, input_length


# Task 2
def find_prefix_size(block_size: int) -> int:
    """Determines the length of the oracle's secret prefix.

    1. Encrypt two inputs that differ only in their single byte. The first
       ciphertext block that differs is the block containing the first
       attacker-controlled byte, i.e. the block in which the prefix ends.
    2. Grow a b"X" filler and append one varying byte. The smallest filler
       length for which that block stops depending on the varying byte is the
       number of free bytes left in it; its complement is the part of the
       prefix inside that block. Comparing two probes (rather than watching
       a single block stop changing) avoids false positives when the secret
       target itself starts with filler bytes.

    Args:
        block_size (int): Cipher block size as returned by find_block_size.

    Returns:
        int: Length of the secret prefix in bytes.

    Raises:
        RuntimeError: If the oracle does not behave like ECB encryption.
    """
    oracle = Oracle()

    # Step 1: locate the block where attacker input starts.
    blocks_a = split_bytes_in_blocks(oracle.encrypt(b"A"), block_size)
    blocks_b = split_bytes_in_blocks(oracle.encrypt(b"B"), block_size)
    prefix_block_index = next(
        (index for index, (a, b) in enumerate(zip(blocks_a, blocks_b)) if a != b),
        None,
    )
    if prefix_block_index is None:
        raise RuntimeError("attacker input does not influence the ciphertext")

    # Step 2: find how many filler bytes complete that block.
    for filler_length in range(block_size + 1):
        filler = b"X" * filler_length
        probe_a = split_bytes_in_blocks(oracle.encrypt(filler + b"A"), block_size)
        probe_b = split_bytes_in_blocks(oracle.encrypt(filler + b"B"), block_size)
        if probe_a[prefix_block_index] == probe_b[prefix_block_index]:
            prefix_size_in_block = block_size - filler_length
            break
    else:
        raise RuntimeError("could not align attacker input to a block boundary")

    return prefix_block_index * block_size + prefix_size_in_block


# Task 3
def recover_one_byte_at_a_time(
    block_size: int,
    prefix_size: int,
    target_size: int,
) -> str:
    """Recovers the oracle's secret target one byte at a time.

    For each unknown byte, a filler is chosen so that the byte lands in the
    last position of a block whose other bytes are all prefix, filler, or
    already-recovered target. Encrypting the same block with every candidate
    last byte and matching ciphertext blocks (ECB maps equal plaintext blocks
    to equal ciphertext blocks) reveals the byte.

    Args:
        block_size (int): Cipher block size.
        prefix_size (int): Length of the secret prefix.
        target_size (int): Length of the secret target.

    Returns:
        str: The recovered target decoded as UTF-8.

    Raises:
        RuntimeError: If no candidate byte matches at some position.
    """
    oracle = Oracle()
    known_target_bytes = b""
    for _ in range(target_size):
        # prefix_size + padding_length + known_len + 1 = 0 mod block_size
        known_len = len(known_target_bytes)
        padding_length = (-known_len - 1 - prefix_size) % block_size
        padding = b"X" * padding_length

        # The next unknown target byte sits at this absolute position, which
        # by the choice of padding_length is the last byte of its block.
        target_block_index = (prefix_size + padding_length + known_len) // block_size
        start = target_block_index * block_size
        end = start + block_size
        target_block = oracle.encrypt(padding)[start:end]

        probe_prefix = padding + known_target_bytes
        for candidate in range(256):
            guess = bytes([candidate])
            if oracle.encrypt(probe_prefix + guess)[start:end] == target_block:
                known_target_bytes += guess
                break
        else:
            raise RuntimeError(f"could not recover target byte {known_len}")

    return known_target_bytes.decode()


def main() -> None:
    """Runs the three attack stages and prints the recovered target."""
    # Find block size, prefix size, and length of plaintext size to align blocks
    (
        block_size,
        size_of_prefix_target_padding,
        minimum_size_to_align_plaintext,
    ) = find_block_size()
    print(f"Block size:\t\t\t\t{block_size}")
    print(
        "Size of prefix, target, and padding:"
        f"\t{size_of_prefix_target_padding}"
    )
    print(f"Pad needed to align:\t\t\t{minimum_size_to_align_plaintext}")

    # Find size of the prefix
    prefix_size = find_prefix_size(block_size)
    print(f"\nPrefix Size:\t{prefix_size}")

    # Size of the target
    target_size = (
        size_of_prefix_target_padding - minimum_size_to_align_plaintext - prefix_size
    )

    # Recover the target
    recovered_target = recover_one_byte_at_a_time(
        block_size,
        prefix_size,
        target_size,
    )
    print(f"\nTarget: {recovered_target}")


if __name__ == "__main__":
    main()