# Source code for cascada.primitives.aes_like

"""AES-like functions."""
import enum

from cascada.bitvector.core import Constant, Term
from cascada.bitvector.operation import BvComp, BvNot
from cascada.bitvector.context import Memoization
from cascada.abstractproperty.property import make_partial_propextract

from cascada.bitvector.ssa import RoundBasedFunction


class LoggingMode(enum.Enum):
    """Represent the options available for the information to log.

    Attributes:
        Silent: nothing is logged
        RoundOutputs: the plaintext and the output of each round is logged
        StepOutputs: the plaintext and the output of each step is logged
        Debug: similar to `StepOutputs`, but also logs debugging information

    """
    # Values are assigned by enum.auto(), i.e. 1, 2, 3, 4 in declaration order.
    Silent = enum.auto()
    RoundOutputs = enum.auto()
    StepOutputs = enum.auto()
    Debug = enum.auto()
class _classproperty:
    """Read-only, class-level property (used for ``input_widths``/``output_widths``).

    Stacking ``@classmethod`` on top of ``@property`` was deprecated in
    Python 3.11 and removed in Python 3.13. In 3.13, ``cls.attr`` evaluates
    to a bound method instead of the computed value. This minimal non-data
    descriptor gives the intended behavior on every Python version:
    ``Cls.attr`` and ``instance.attr`` both return ``fget(Cls)``. Subclasses
    can still shadow the attribute with a plain class attribute.
    """

    def __init__(self, fget):
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, instance, owner=None):
        # Accessed through an instance without an explicit owner: use the
        # instance's class, mirroring ``classmethod`` semantics.
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class AESLikeFunction(RoundBasedFunction):
    """Base class to implement AES-like functions.

    Subclasses need to set the integer attributes ``num_rounds``,
    ``num_rows``, ``num_columns`` and ``cell_width``.
    Subclasses also need to set ``sbox`` (resp. ``mix_columns_bit_matrix``)
    to use ``sub_cells`` (resp. ``mix_columns``).
    See `AESEncryption` for an example.

    Optionally, subclasses can set the attributes ``logging_mode`` and
    ``name*`` to customize the information to log
    (see `LoggingMode`, `BvFunction.get_formatted_logged_msgs` and
    `Characteristic.get_formatted_logged_msgs`).

    By default input/output bit-vector tuples are loaded column-wise.
    """
    # Attributes expected from subclasses:
    # num_rounds
    # num_rows, num_columns
    # cell_width
    # # num_extra_cells  (not supported yet; see the commented-out code below)
    # sbox, mix_columns_bit_matrix
    # name_add_round_constants, name_sub_cells, name_shift_rows, name_mix_columns

    logging_mode = LoggingMode.Silent

    # - RoundBasedFunction attributes/methods

    @_classproperty
    def input_widths(cls):
        """List with one ``cell_width`` entry per state cell (``num_rows*num_columns`` cells)."""
        return [cls.cell_width for _ in range(cls.num_rows*cls.num_columns)]  # cls.num_extra_cells

    @_classproperty
    def output_widths(cls):
        """List with one ``cell_width`` entry per state cell (``num_rows*num_columns`` cells)."""
        return [cls.cell_width for _ in range(cls.num_rows*cls.num_columns)]

    @classmethod
    def set_num_rounds(cls, new_num_rounds):
        """Replace ``num_rounds`` by ``new_num_rounds``.

        The class must already define a non-``None`` ``num_rounds``.
        """
        assert cls.num_rounds is not None
        cls.num_rounds = new_num_rounds

    # - steps of AES-like functions

    @classmethod
    def _log_step_output(cls, name_attribute, default_name, my_matrix_state):
        """Log ``my_matrix_state`` as the output of a step.

        Nothing is logged when ``logging_mode`` is ``Silent`` or
        ``RoundOutputs``. The step is named by the class attribute
        ``name_attribute`` if the subclass sets it, otherwise by
        ``default_name``.
        """
        if cls.logging_mode in (LoggingMode.Silent, LoggingMode.RoundOutputs):
            return
        cls.log_msg(f"output of {getattr(cls, name_attribute, default_name)}:")
        cls.log_activity_matrix_state(my_matrix_state)

    @classmethod
    def add_round_constant(cls, my_matrix_state, my_matrix_constant):
        """Return a new state with ``my_matrix_constant`` XORed cell-wise into it.

        Only the cells covered by ``my_matrix_constant`` (its rows and, per
        row, its columns) are modified, so the constant matrix may be smaller
        than the state. The input state is not mutated.
        """
        new_matrix_state = [row[:] for row in my_matrix_state]
        for row in range(len(my_matrix_constant)):
            for column in range(len(my_matrix_constant[row])):
                new_matrix_state[row][column] ^= my_matrix_constant[row][column]
        cls._log_step_output('name_add_round_constants', 'AddRoundConstant', new_matrix_state)
        return new_matrix_state

    @classmethod
    def sub_cells(cls, my_matrix_state):
        """Return a new state where ``sbox`` is applied to each of the ``num_rows*num_columns`` cells."""
        new_matrix_state = [row[:] for row in my_matrix_state]
        for row in range(cls.num_rows):
            for column in range(cls.num_columns):
                new_matrix_state[row][column] = cls.sbox(my_matrix_state[row][column])
        cls._log_step_output('name_sub_cells', 'SubCells', new_matrix_state)
        return new_matrix_state

    @classmethod
    def shift_rows(cls, my_matrix_state):
        """Return a new state where row ``r`` is rotated left by ``r`` positions.

        Cell ``(r, c)`` of the output is cell ``(r, (c + r) % num_columns)``
        of the input.
        """
        new_matrix_state = [row[:] for row in my_matrix_state]
        for r in range(cls.num_rows):
            offset = r  # shift offset (0 for r=0, 1 for r=1, ...)
            for c in range(cls.num_columns):
                new_matrix_state[r][c] = my_matrix_state[r][(c + offset) % cls.num_columns]
        cls._log_step_output('name_shift_rows', 'ShiftRows', new_matrix_state)
        return new_matrix_state

    @classmethod
    def mix_columns(cls, my_matrix_state):
        """Return a new state where ``mix_columns_bit_matrix`` is applied to each column.

        The bit-matrix function is called with the first
        ``mix_columns_bit_matrix.arity[0]`` cells of each column. Row ``r``
        receives output bits ``(r+1)*cell_width - 1`` down to
        ``r*cell_width`` (via ``make_partial_propextract``). Rows beyond
        ``arity[0]`` are copied unchanged.
        """
        new_matrix_state = [row[:] for row in my_matrix_state]
        num_rows = cls.mix_columns_bit_matrix.arity[0]
        assert cls.num_rows >= num_rows
        cw = cls.cell_width
        # Build the per-row extraction functions once, outside the column loop.
        extract_cell = [make_partial_propextract((r+1)*cw - 1, r*cw) for r in range(num_rows)]
        for index_column in range(cls.num_columns):
            input_column = [my_matrix_state[r][index_column] for r in range(num_rows)]
            output_column = cls.mix_columns_bit_matrix(*input_column)
            for r in range(num_rows):
                new_matrix_state[r][index_column] = extract_cell[r](output_column)
        cls._log_step_output('name_mix_columns', 'MixColumns', new_matrix_state)
        return new_matrix_state

    # - auxiliary functions

    @classmethod
    def list2matrix(cls, my_list, num_rows=None, num_columns=None, column_order=True):  # , ignore_num_extra_cells=False):
        """Convert a flat list of cells into a matrix (a list of rows).

        Args:
            my_list: the cells, each a `Term` of width ``cell_width``; its
                length must be exactly ``num_rows*num_columns``
            num_rows: number of rows (defaults to ``cls.num_rows``)
            num_columns: number of columns (defaults to ``cls.num_columns``)
            column_order: if True, ``my_list`` is read column by column;
                otherwise row by row
        """
        if num_rows is None:
            num_rows = cls.num_rows
        if num_columns is None:
            num_columns = cls.num_columns
        my_matrix = [[None for _ in range(num_columns)] for _ in range(num_rows)]
        position = 0
        if column_order:
            for column in range(num_columns):  # assuming my_list ordered by columns
                for row in range(num_rows):
                    assert isinstance(my_list[position], Term) and my_list[position].width == cls.cell_width
                    my_matrix[row][column] = my_list[position]
                    position += 1
        else:
            for row in range(num_rows):
                for column in range(num_columns):  # assuming my_list ordered by rows
                    assert isinstance(my_list[position], Term) and my_list[position].width == cls.cell_width
                    my_matrix[row][column] = my_list[position]
                    position += 1
        # if cls.num_extra_cells and ignore_num_extra_cells is not True:
        #     my_matrix.append([None for _ in range(cls.num_extra_cells)])
        #     for cell in range(cls.num_extra_cells):
        #         assert isinstance(my_list[position], Term) and my_list[position].width == cls.cell_width
        #         my_matrix[-1][cell] = my_list[position]
        #         position += 1
        # Every element of my_list must have been consumed.
        assert position == len(my_list)
        return my_matrix

    @classmethod
    def matrix2list(cls, my_matrix, num_rows=None, num_columns=None, column_order=True):  # , ignore_num_extra_cells=False):
        """Flatten a matrix (a list of rows) into a list of cells.

        The inverse of `list2matrix` for the same ``num_rows``,
        ``num_columns`` and ``column_order``. Every resulting cell must be a
        `Term` of width ``cell_width``.
        """
        if num_rows is None:
            num_rows = cls.num_rows
        if num_columns is None:
            num_columns = cls.num_columns
        my_list = []
        if column_order:
            for column in range(num_columns):  # sorting my_list by columns
                for row in range(num_rows):
                    my_list.append(my_matrix[row][column])
        else:
            for row in range(num_rows):  # sorting my_list by rows
                for column in range(num_columns):
                    my_list.append(my_matrix[row][column])
        # if cls.num_extra_cells and ignore_num_extra_cells is not True:
        #     my_list.extend(my_matrix[-1])
        assert all(isinstance(x, Term) and x.width == cls.cell_width for x in my_list)
        return my_list

    @classmethod
    def log_activity_matrix_state(cls, my_matrix_state, return_format=False):
        """Log (or return) the cell activity of a state, one line per row.

        The activity of a cell is the 1-bit term ``BvNot(BvComp(cell, 0))``.
        It is presumably 1 exactly when the cell is non-zero (TODO confirm
        against ``BvComp``). In `LoggingMode.Debug`, each line also shows the
        cells themselves, separated from their activities by a tab.

        Args:
            my_matrix_state: the state; the first ``num_rows`` rows are used
            return_format: if True, return ``(format_string,
                format_field_objects)`` instead of logging them
        """
        def activity(my_bv):
            assert isinstance(my_bv, Term) and my_bv.width == cls.cell_width
            # NOTE(review): Memoization(None) presumably keeps these
            # logging-only terms out of any active memoization table — confirm.
            with Memoization(None):
                my_bv = BvNot(BvComp(my_bv, Constant(0, my_bv.width)))
            return my_bv

        # assert len(my_matrix_state) == cls.num_rows + (1 if cls.num_extra_cells else 0)
        format_string = ""
        format_field_objects = []
        for row in range(cls.num_rows):
            if cls.logging_mode == LoggingMode.Debug:
                format_string += "(" + ','.join(["{}"]*cls.num_columns) + ")\t"
                format_field_objects.extend(my_matrix_state[row])
            format_string += "(" + ','.join(["{}"]*cls.num_columns) + ")\n"
            format_field_objects.extend(activity(x) for x in my_matrix_state[row])
        # if cls.num_extra_cells:
        #     assert len(my_matrix_state[-1]) == cls.num_extra_cells
        #     if cls.logging_mode == LoggingMode.Debug:
        #         format_string += "[" + ','.join(["{}"]*cls.num_extra_cells) + "]\t"
        #         format_field_objects.extend(my_matrix_state[-1])
        #     format_string += "[" + ','.join(["{}"]*cls.num_extra_cells) + "]\n"
        #     format_field_objects.extend(activity(x) for x in my_matrix_state[-1])
        format_string = format_string[:-1]  # remove last "\n"
        if return_format:
            return format_string, format_field_objects
        else:
            cls.log_msg(format_string, format_field_objects=format_field_objects)