# Source code for cascada.primitives.aes_masked

"""SMT-based experiments from `A Low-Randomness Second-Order Masked AES
<https://eprint.iacr.org/2021/1111>`_.

To run the experiments, call :func:`find_best_key_schedule_characteristic` or :func:`find_best_encryption_characteristic`.
"""
import functools
from decimal import Decimal

from cascada.bitvector.core import Constant, Term
from cascada.bitvector.operation import BvOr, BvComp, Concat, BvIdentity
from cascada.bitvector.secondaryop import LutOperation, PopCount
from cascada.linear.mask import LinearMask
from cascada.linear.opmodel import get_weak_model
from cascada.linear.chmodel import ChModel
from cascada.smt.chsearch import (
    ChModelAssertType, ChFinder, PrintingMode, _get_smart_print
)


from cascada.primitives import aes


class SboxA(aes.SboxLut):
    """Represent a generic S-box with the following absolute correlations:

    - zero -> zero: :math:`1`
    - non-zero -> non-zero: :math:`2^{-3}`
    - zero -> non-zero: :math:`2^{-3.8}`

    Everything except the linear model is inherited unchanged from
    ``aes.SboxLut``; the correlations above are those of the weak linear
    model assigned to ``SboxA.linear_model`` in this module.
    """
class SboxB(aes.SboxLut):
    """Represent a generic S-box with the following absolute correlations:

    - zero -> zero: :math:`1`
    - non-zero -> non-zero: :math:`2^{-2.6}`
    - zero -> non-zero: :math:`2^{-4}`

    Everything except the linear model is inherited unchanged from
    ``aes.SboxLut``; the correlations above are those of the weak linear
    model assigned to ``SboxB.linear_model`` in this module.
    """
class SboxZeroWeight(aes.SboxLut):
    """Represent a generic S-box with the following absolute correlations:

    - zero -> zero: :math:`1`
    - non-zero -> non-zero: :math:`1`
    - zero -> non-zero: :math:`1`

    All listed transitions have weight zero, so this S-box adds nothing to
    the weight of a characteristic. Everything except the linear model is
    inherited unchanged from ``aes.SboxLut``.
    """
# Attach weak linear models to the generic S-boxes. A weight w corresponds to
# an absolute correlation of 2^{-w} (see the S-box docstrings above).
# The fractional weights are given as strings: Decimal(3.8) would store the
# nearest binary float (3.7999...), not the intended exact value 3.8.
SboxA.linear_model = get_weak_model(
    SboxA, Decimal(3), zero2nonzero_weight=Decimal("3.8"), precision=4)
SboxB.linear_model = get_weak_model(
    SboxB, Decimal("2.6"), zero2nonzero_weight=Decimal(4), precision=4)
SboxZeroWeight.linear_model = get_weak_model(
    SboxZeroWeight, 0, zero2nonzero_weight=0)
class AESMaskedKeySchedule(aes.AESKeySchedule):
    """Key schedule of masked AES-128.

    The inputs are the 16 master-key cells followed by ``num_extra_cells``
    extra cells (two per round). Each call to :meth:`sub_word` applies two
    masked S-box layers, the first consuming extra cell ``2*r`` and the
    second extra cell ``2*r + 1``, where ``r`` counts the calls made so far.
    :meth:`eval` asserts that ``sub_word`` was called exactly ``num_rounds``
    times, and it returns only the last 16 cells produced by the parent
    key schedule.
    """

    num_rounds = 10
    input_widths = [8] * 16 + [8] * (2 * 10)
    output_widths = [8] * 16
    logging_mode = aes.LoggingMode.Silent
    num_extra_cells = 2 * 10
    # If True, the first S-box layer of the first round and the second S-box
    # layer of the last round use SboxZeroWeight (correlation 1, weight 0).
    ignore_first_last_sbox_weights = True

    @classmethod
    def set_num_rounds(cls, new_num_rounds):
        """Set the number of rounds and resize the extra-cell inputs (two per round)."""
        cls.num_rounds = new_num_rounds
        cls.num_extra_cells = 2 * cls.num_rounds
        cls.input_widths = [8] * 16 + [8] * cls.num_extra_cells

    @classmethod
    def _sub_word(cls, my_word, extra_cell, my_sbox):
        """Apply one masked S-box layer to a 4-cell word.

        Cell ``i`` goes through ``my_sbox`` and is then XORed with cell
        ``i + 1`` of the input word ("changing of the guards"). The last
        cell is XORed with ``extra_cell`` instead.

        Returns:
            The new 4-cell word as a list.
        """
        a, b, c, d = my_word
        sbox_inputs = [a, b, c, d]
        # BvIdentity creates new copies with different names, so the logged
        # XOR operands are distinct from the S-box inputs.
        xor_values = [BvIdentity(xv) for xv in [b, c, d, extra_cell]]
        sbox_outputs = [my_sbox(x) for x in sbox_inputs]
        # changing of the guards
        new_word = [s_out ^ xv for s_out, xv in zip(sbox_outputs, xor_values)]

        if cls.logging_mode == aes.LoggingMode.Debug:
            name_sbox = my_sbox.__name__
            for args in zip(sbox_inputs, sbox_outputs, xor_values, new_word):
                cls.log_msg(f" - {name_sbox}({{}}) -> ({{}} XOR {{}}) -> {{}}", list(args))

        return new_word

    @classmethod
    def sub_word(cls, my_word):
        """Apply the two masked S-box layers of the current call to ``my_word``.

        The first layer uses SboxA and the second SboxB, except that, when
        ``ignore_first_last_sbox_weights`` is set, the first layer of the first
        call and the second layer of call ``num_rounds - 1`` use SboxZeroWeight.
        """
        r = cls._current_sub_word_round
        ignore = cls.ignore_first_last_sbox_weights

        first_sbox = SboxZeroWeight if (r == 0 and ignore) else SboxA
        output_word = cls._sub_word(my_word, cls._list_extra_cells[2 * r], first_sbox)

        second_sbox = SboxZeroWeight if (r == cls.num_rounds - 1 and ignore) else SboxB
        output_word = cls._sub_word(output_word, cls._list_extra_cells[2 * r + 1], second_sbox)

        cls._current_sub_word_round += 1
        return output_word

    @classmethod
    def eval(cls, *master_key):
        """Evaluate the masked key schedule.

        The last ``num_extra_cells`` arguments are the extra cells consumed by
        :meth:`sub_word`; all arguments are forwarded to the parent ``eval``.

        Returns:
            The last 16 cells of the parent's result.
        """
        cls._current_sub_word_round = 0
        assert cls.num_extra_cells > 0
        cls._list_extra_cells = master_key[-cls.num_extra_cells:]

        if cls.logging_mode != aes.LoggingMode.Silent:
            format_string = "extra cells:\n(" + ",".join(["{}"] * cls.num_extra_cells) + ")"
            cls.log_msg(format_string, list(cls._list_extra_cells))

        result = super().eval(*master_key)
        # sub_word must have been called exactly once per round
        assert cls._current_sub_word_round == cls.num_rounds
        return result[-16:]
def get_key_schedule_constraints(ch_model, verbose=False, filename=None):
    """Get the following initial constraints:

    - extra cells not active
    - (at most 2 adjacent cells in the last column of the input) OR
      (at most 1 row active in the input)
    - (at most 2 adjacent cells in the last column of the output) OR
      (at most 1 row active in the output)

    "At most 2 adjacent cells in the last column" means that all columns but
    the last are inactive and the active cells of the last column lie within
    two cyclically adjacent rows. States are ordered by columns, so the last
    4 cells form the last column.

    Args:
        ch_model: characteristic model whose ``func`` has a positive
            ``num_extra_cells`` attribute (e.g. built from AESMaskedKeySchedule).
            Its input masks are the 16 state cells followed by the extra cells.
        verbose: if True, print the masks and every constraint.
        filename: forwarded to ``_get_smart_print``.

    Returns:
        A list of three constraints: extra cells, input state, output state.
    """
    initial_constraints = []
    smart_print = _get_smart_print(filename)
    num_rows, num_columns = 4, 4

    num_extra_cells = ch_model.func.num_extra_cells
    # a zero value would make the [:-0] slice below drop every input cell
    assert num_extra_cells > 0

    def to_1b(my_val):
        # return 0b0 if my_val == 0 and 0b1 if my_val != 0 (active)
        assert isinstance(my_val, Term)
        if my_val.width == 1:
            return my_val
        return ~(BvComp(my_val, Constant(0, my_val.width)))

    def has_hw_1(x):
        # 0b1 iff x has at most one bit set (x & (x - 1) clears the lowest set bit)
        return BvComp(x & (x - 1), Constant(0, x.width))

    input_state_list = [m.val for m in ch_model.input_mask]
    output_state_list = [m.val for m in ch_model.output_mask]
    extra_cells_mask = input_state_list[-num_extra_cells:]
    input_state_list = input_state_list[:-num_extra_cells]  # without extra cells

    if verbose:
        smart_print("get_key_schedule_constraints:")

        class AuxAESLikeFunction(aes.AESLikeFunction):
            num_rows, num_columns = 4, 4
            cell_width = 8
            logging_mode = aes.LoggingMode.Debug

        my_matrix = AuxAESLikeFunction.list2matrix(input_state_list)
        fs, ffo = AuxAESLikeFunction.log_activity_matrix_state(my_matrix, return_format=True)
        smart_print("Input masks: \n"+fs.format(*ffo))
        smart_print(f"Extra cell masks: {extra_cells_mask}")
        my_matrix = AuxAESLikeFunction.list2matrix(output_state_list)
        fs, ffo = AuxAESLikeFunction.log_activity_matrix_state(my_matrix, return_format=True)
        smart_print("Output masks:\n"+fs.format(*ffo))

    extra_cells_mask = functools.reduce(BvOr, extra_cells_mask)
    initial_constraints.append(BvComp(extra_cells_mask, Constant(0, extra_cells_mask.width)))
    if verbose:
        smart_print("Extra cell constraint:")
        smart_print(" - extra cells == 0 <==>", initial_constraints[-1])

    for index_state, state_list in enumerate([input_state_list, output_state_list]):
        if verbose:
            if index_state == 0:
                smart_print("Input state constraints: (c11 & c12) | c2")
            else:
                smart_print("Output state constraints: (c11 & c12) | c2")

        # - c1
        # last_column_cell_active[i] == 1 if the i-th cell in the last column is active
        last_column_cell_active = [to_1b(cell) for cell in state_list[-num_rows:]]
        # last_column_cell_pair_active[j] == 1 if cell j or cell j+1 (cyclically) is active
        last_column_cell_pair_active = [
            last_column_cell_active[j] | last_column_cell_active[(j + 1) % num_rows]
            for j in range(num_rows)
        ]
        last_column_cell_pair_active = functools.reduce(Concat, last_column_cell_pair_active)
        num_active = PopCount(last_column_cell_pair_active)
        # Number of active pairs per last-column pattern: no active cell -> 0,
        # one cell -> 2, two adjacent cells -> 3, any other pattern -> 4.
        # Accepting only 2 would allow just a single active cell, which c2
        # already covers; 2 adjacent cells must be accepted as well.
        c11 = (BvComp(num_active, Constant(2, num_active.width)) |
               BvComp(num_active, Constant(3, num_active.width)))
        if verbose:
            smart_print(" - c11: HW( Concat(column_cell_pair_active) = {} ) in {{2, 3}} <==> {}".format(
                last_column_cell_pair_active, c11))

        # OR of all cells outside the last column
        first_columns_cells = functools.reduce(BvOr, state_list[:-num_rows])
        c12 = BvComp(first_columns_cells, Constant(0, first_columns_cells.width))
        if verbose:
            smart_print(" - c12: 0 == BvOr(column0,column1,column2) = {} <==> {}".format(first_columns_cells, c12))

        c1 = c11 & c12

        # - c2
        # row_active[i] == 1 if the i-th row is active
        row_active = []
        for row in range(num_rows):
            my_row = [state_list[row + (column * num_rows)] for column in range(num_columns)]
            row_active.append(to_1b(functools.reduce(Concat, my_row)))
        row_active = functools.reduce(Concat, row_active)
        c2 = has_hw_1(row_active)
        if verbose:
            smart_print(" - c2: HW( Concat(row_active) = {} ) <= 1 <==> {}".format(row_active, c2))

        initial_constraints.append(c1 | c2)

        if verbose:
            smart_print("")

    return initial_constraints
def find_best_key_schedule_characteristic(verbose=False):
    """Find the best trail spanning eight rounds and activating 21 masked S-boxes,
    with total absolute correlation :math:`2^{-63.60}`.

    The search uses an SMT solver (``"btor"``, seed 0) and returns the first
    characteristic found by increasing weight, subject to
    :func:`get_key_schedule_constraints`.

    Side effects: ``AESMaskedKeySchedule.logging_mode`` is set, and
    ``AESMaskedKeySchedule.set_num_rounds(8)`` is called; neither is restored.

    >>> from cascada.primitives.aes_masked import find_best_key_schedule_characteristic
    >>> found_ch = find_best_key_schedule_characteristic(verbose=False)
    >>> print(found_ch.srepr())  # doctest:+NORMALIZE_WHITESPACE
    Ch(w=63.60, id=00 04 00 00 00 53 00 00 00 24 00 00 00 72 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00, od=00 04 00 00 00 53 00 00 00 24 00 00 00 72 00 00)
    """
    num_rounds = 8
    if verbose:
        logging_mode, printing_mode = aes.LoggingMode.Debug, PrintingMode.Debug
    else:
        logging_mode, printing_mode = aes.LoggingMode.Silent, PrintingMode.Silent
    filename = None

    AESMaskedKeySchedule.logging_mode = logging_mode
    AESMaskedKeySchedule.set_num_rounds(num_rounds)

    # 16 master-key cells followed by two extra cells per round
    key_names = [f"k{i}" for i in range(16)]
    extra_names = [f"r{i}" for i in range(2 * num_rounds)]
    ch_model = ChModel(AESMaskedKeySchedule, LinearMask, key_names + extra_names, "x")
    initial_constraints = get_key_schedule_constraints(ch_model, verbose=verbose, filename=filename)

    ch_finder = ChFinder(
        ch_model, ChModelAssertType.ValidityAndWeight, "btor", solver_seed=0,
        initial_constraints=initial_constraints, printing_mode=printing_mode,
        filename=filename,
    )
    ch_found = next(ch_finder.find_next_ch_increasing_weight(initial_weight=1))

    if verbose:
        smart_print = _get_smart_print(filename)
        smart_print("\nFound trail:", ch_found.srepr(), "\n")
        smart_print("\n".join(ch_found.get_formatted_logged_msgs()))

    return ch_found
class AESMaskedEncryption(aes.AESLikeFunction):
    """Encryption of masked AES-128.

    The state has 4 rows and 5 columns. In each row, :meth:`sub_cells` applies
    the S-box to the first 4 cells and XORs each output with a cell of the
    rotated input row ("changing of the guards"). :meth:`shift_rows` only
    rotates the first 4 columns and leaves the 5th column in place.
    """

    num_rounds = 10
    num_rows, num_columns = 4, 5
    cell_width = 8
    mix_columns_bit_matrix = aes.MixColumnsBitMatrix
    # skip the first S-box layer of the first round
    ignore_first_subcell = True
    # skip the second S-box layer of the last round
    ignore_last_subcell = True
    # use SboxZeroWeight for the outermost S-box layer that is still applied
    ignore_first_last_sbox_weights = True
    name_sub_cells = "SubBytes"
    logging_mode = aes.LoggingMode.Silent

    @classmethod
    def sub_cells(cls, my_matrix_state, my_sbox):
        """Apply one masked S-box layer to every row of the state.

        For row ``i`` with cells ``[a, b, c, d, e]``, the XOR operands are the
        row rotated left by ``i + 1`` positions (e.g. ``[b, c, d, e, a]`` for
        row 0). Cells ``a..d`` become ``my_sbox(cell) ^ operand``, and the new
        5th cell is a copy of the 5th operand.

        Raises:
            ValueError: if the state has more than 4 rows.
        """
        new_matrix_state = [row[:] for row in my_matrix_state]
        for row in range(cls.num_rows):
            a, b, c, d, e = my_matrix_state[row]
            if not 0 <= row < 4:
                raise ValueError("invalid row")
            cells = [a, b, c, d, e]
            xor_values = cells[row + 1:] + cells[:row + 1]

            sbox_inputs = [a, b, c, d]
            # BvIdentity creates new copies with different names
            xor_values = [BvIdentity(x) for x in xor_values]
            sbox_outputs = [my_sbox(x) for x in sbox_inputs]
            # changing of the guards
            new_row = [s_out ^ xv for s_out, xv in zip(sbox_outputs, xor_values)]
            new_row.append(BvIdentity(xor_values[4]))  # new copy with a different name
            new_matrix_state[row] = new_row

            if cls.logging_mode == aes.LoggingMode.Debug:
                # The distinct names of sbox_inputs, xor_values and the output
                # cells ensure the logged values are not intermediate branches.
                name_sbox = my_sbox.__name__
                cls.log_msg(f" {getattr(cls, 'name_sub_cells', 'SubCells')} row {row}:")
                for args in zip(sbox_inputs, sbox_outputs, xor_values, new_row):
                    cls.log_msg(f" - {name_sbox}({{}}) -> ({{}} XOR {{}}) -> {{}}", list(args))

        if cls.logging_mode not in [aes.LoggingMode.Silent, aes.LoggingMode.RoundOutputs]:
            cls.log_msg(f"output of {getattr(cls, 'name_sub_cells', 'SubCells')}:")
            cls.log_activity_matrix_state(new_matrix_state)

        return new_matrix_state

    @classmethod
    def shift_rows(cls, my_matrix_state):
        """Rotate row ``r`` of the first 4 columns left by ``r`` positions.

        The 5th column is not shifted.
        """
        new_matrix_state = [row[:] for row in my_matrix_state]
        for r in range(4):
            offset = r  # shift offset (0 for r=0, 1 for r=1, ...)
            for c in range(4):
                new_matrix_state[r][c] = my_matrix_state[r][(c + offset) % 4]

        if cls.logging_mode not in [aes.LoggingMode.Silent, aes.LoggingMode.RoundOutputs]:
            cls.log_msg(f"output of {getattr(cls, 'name_shift_rows', 'ShiftRows')}:")
            cls.log_activity_matrix_state(new_matrix_state)

        return new_matrix_state

    @classmethod
    def eval(cls, *plaintext):
        """Evaluate ``num_rounds`` rounds of the masked encryption.

        - Every round but the last applies two S-box layers (SboxA, then
          SboxB), then ShiftRows and MixColumns.
        - In the first round, the first S-box layer is skipped if
          ``ignore_first_subcell`` is set.
        - The last round applies only the S-box layers; its second layer is
          skipped if ``ignore_last_subcell`` is set.
        - If ``ignore_first_last_sbox_weights`` is set, the outermost layer
          still applied uses SboxZeroWeight.
        - The state after each round is registered with ``add_round_outputs``.

        Returns:
            The final state as a flat tuple of cells.
        """
        matrix_state = cls.list2matrix(plaintext)

        if cls.logging_mode != aes.LoggingMode.Silent:
            cls.log_msg("\nplaintext:")
            cls.log_activity_matrix_state(matrix_state)

        ign = cls.ignore_first_last_sbox_weights

        for r in range(cls.num_rounds):
            if r == 0:
                if cls.ignore_first_subcell:
                    # first S-box layer (SboxA) skipped
                    matrix_state = cls.sub_cells(matrix_state, my_sbox=SboxB if not ign else SboxZeroWeight)
                else:
                    matrix_state = cls.sub_cells(matrix_state, my_sbox=SboxA if not ign else SboxZeroWeight)
                    matrix_state = cls.sub_cells(matrix_state, my_sbox=SboxB)
                matrix_state = cls.shift_rows(matrix_state)
                matrix_state = cls.mix_columns(matrix_state)
            elif r < cls.num_rounds - 1:
                matrix_state = cls.sub_cells(matrix_state, my_sbox=SboxA)
                matrix_state = cls.sub_cells(matrix_state, my_sbox=SboxB)
                matrix_state = cls.shift_rows(matrix_state)
                matrix_state = cls.mix_columns(matrix_state)
            elif r == cls.num_rounds - 1:
                if cls.ignore_last_subcell:
                    matrix_state = cls.sub_cells(matrix_state, my_sbox=SboxA if not ign else SboxZeroWeight)
                    # second S-box layer (SboxB) skipped
                else:
                    matrix_state = cls.sub_cells(matrix_state, my_sbox=SboxA)
                    matrix_state = cls.sub_cells(matrix_state, my_sbox=SboxB if not ign else SboxZeroWeight)

            if cls.logging_mode != aes.LoggingMode.Silent:
                cls.log_msg(f"output of round {r+1}:")  # 1 <= {r+1} <= {cls.num_rounds}
                cls.log_activity_matrix_state(matrix_state)
                cls.log_msg("", [])  # new line
            cls.add_round_outputs(*cls.matrix2list(matrix_state))

        return tuple(cls.matrix2list(matrix_state))
def get_encryption_constraints(ch_model, verbose=False, filename=None):
    """Get the initial constraints on the input and output masks.

    Each constraint holds iff the concatenation of all the cell masks has at
    most one non-zero bit (``x & (x - 1) == 0``). In other words, at most one
    cell is active and that cell has a single active bit. The all-zero mask
    also satisfies the constraint.

    Args:
        ch_model: characteristic model whose ``func`` provides ``list2matrix``
            and ``log_activity_matrix_state`` (e.g. built from AESMaskedEncryption).
        verbose: if True, print the masks and both constraints.
        filename: forwarded to ``_get_smart_print``.

    Returns:
        A list ``[input_constraint, output_constraint]``.
    """
    initial_constraints = []
    smart_print = _get_smart_print(filename)

    if verbose:
        smart_print("get_encryption_constraints:")
        my_matrix = ch_model.func.list2matrix([m.val for m in ch_model.input_mask])
        fs, ffo = ch_model.func.log_activity_matrix_state(my_matrix, return_format=True)
        smart_print("Input masks: \n"+fs.format(*ffo))
        my_matrix = ch_model.func.list2matrix([m.val for m in ch_model.output_mask])
        fs, ffo = ch_model.func.log_activity_matrix_state(my_matrix, return_format=True)
        smart_print("Output masks:\n"+fs.format(*ffo))

    def has_hw_1(x):
        # 0b1 iff x has at most one bit set (x & (x - 1) clears the lowest set bit)
        return BvComp(x & (x - 1), Constant(0, x.width))

    concat_input_mask = functools.reduce(Concat, [m.val for m in ch_model.input_mask])
    initial_constraints.append(has_hw_1(concat_input_mask))

    concat_output_mask = functools.reduce(Concat, [m.val for m in ch_model.output_mask])
    initial_constraints.append(has_hw_1(concat_output_mask))

    if verbose:
        smart_print(" - HW( {} ) <= 1 <==> {}".format(concat_input_mask, initial_constraints[0]))
        smart_print(" - HW( {} ) <= 1 <==> {}".format(concat_output_mask, initial_constraints[1]))

    return initial_constraints
def find_best_encryption_characteristic(verbose=False):
    """Find the best trail spanning 3 rounds and activating 21 masked S-boxes,
    with total absolute correlation :math:`2^{-51.60}`.

    The search uses an SMT solver (``"btor"``, seed 0) and returns the first
    characteristic found by increasing weight, subject to
    :func:`get_encryption_constraints`.

    Side effects: ``AESMaskedEncryption.logging_mode`` is set, and
    ``AESMaskedEncryption.set_num_rounds(3)`` is called; neither is restored.

    >>> from cascada.primitives.aes_masked import find_best_encryption_characteristic
    >>> found_ch = find_best_encryption_characteristic(verbose=False)
    >>> print(found_ch.srepr())  # doctest:+NORMALIZE_WHITESPACE
    Ch(w=52.20, id=00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00, od=00 00 00 00 00 00 00 00 40 00 00 00 00 00 00 00 00 00 00 00)
    """
    # NOTE(review): the doctest above reports w=52.20, while the summary line
    # quotes 2^{-51.60}; confirm which value is intended.
    num_rounds = 3
    filename = None
    if verbose:
        logging_mode, printing_mode = aes.LoggingMode.Debug, PrintingMode.Debug
    else:
        logging_mode, printing_mode = aes.LoggingMode.Silent, PrintingMode.Silent

    AESMaskedEncryption.logging_mode = logging_mode
    AESMaskedEncryption.set_num_rounds(num_rounds)

    # one name per cell of the 4x5 state
    plaintext_names = [f"p{i}" for i in range(4 * 5)]
    ch_model = ChModel(AESMaskedEncryption, LinearMask, plaintext_names, "x")
    initial_constraints = get_encryption_constraints(ch_model, verbose=verbose, filename=filename)

    ch_finder = ChFinder(
        ch_model, ChModelAssertType.ValidityAndWeight, "btor", solver_seed=0,
        initial_constraints=initial_constraints, printing_mode=printing_mode,
        filename=filename,
    )
    ch_found = next(ch_finder.find_next_ch_increasing_weight(initial_weight=1))

    if verbose:
        smart_print = _get_smart_print(filename)
        smart_print("\nFound trail:", ch_found.srepr(), "\n")
        smart_print("\n".join(ch_found.get_formatted_logged_msgs()))

    return ch_found