Source code for arxpy.differential.difference

"""Manipulate differences."""
import collections

from arxpy.bitvector import core
from arxpy.bitvector import operation
from arxpy.bitvector import extraop


def _tuplify(seq):
    """Return ``seq`` as a tuple.

    A `collections.abc.Sequence` (list, tuple, ...) is converted element-wise
    with ``tuple(seq)``; any other object is wrapped in a 1-tuple.
    """
    # ``import collections`` alone does not guarantee that the ``abc``
    # submodule has been loaded, so import it explicitly here.
    from collections.abc import Sequence

    if isinstance(seq, Sequence):
        return tuple(seq)
    return (seq,)


class Difference(object):
    """Represent differences.

    The *difference* between two `Term` :math:`x` and :math:`y`
    is defined as :math:`\\alpha = y - x`,
    where the *difference operation* :math:`-` is a bit-vector `Operation`.
    In other words, the pair :math:`(x, x + \\alpha)` has difference
    :math:`\\alpha`, where :math:`+` is the inverse of the difference
    operation.

    The most common difference used in differential cryptanalysis
    is the XOR difference `XorDiff` (where the difference operation
    is `BvXor`). Other examples are the additive difference
    (where the difference operation is `BvSub`) or the rotational-XOR
    difference `RXDiff`.

    Note that arithmetic with differences is not supported.
    For example, two `Difference` objects ``d1`` and ``d2``
    cannot be XORed, i.e., ``d1 ^ d2``.
    This can be done instead by performing the arithmetic with
    the difference values and converting the resulting `Term`
    to a difference, that is, ``Difference(d1.val ^ d2.val)``.

    This class is not meant to be instantiated but to provide
    a base class for the different types of differences.

    Attributes:
        val: a `Term` representing the value of the difference.
        diff_op: the difference `Operation`.
        inv_diff_op: the inverse of the difference operation.

    """

    # Subclasses set these to the concrete (inverse) difference operations.
    diff_op = None
    inv_diff_op = None

    def __init__(self, value):
        assert isinstance(value, core.Term)
        self.val = value

    def __str__(self):
        """Return the non-verbose string representation."""
        return "{}({})".format(type(self).__name__, str(self.val))

    __repr__ = __str__

    def __hash__(self):
        # Consistent with __eq__: equal differences wrap equal values.
        return hash(self.val)

    def __eq__(self, other):
        # ``other`` must be an instance of this object's class (or of a
        # subclass of it) and wrap an equal value.
        if isinstance(other, type(self)):
            return self.val == other.val
        else:
            return False

    def xreplace(self, rule):
        """Replace occurrences of differences within the expression.

        The argument ``rule`` is a dict-like object representing
        the replacement rule.

        This method is similar to SymPy `xreplace
        <https://docs.sympy.org/latest/modules/core.html?highlight=xreplace#sympy.core.basic.Basic.xreplace>`_
        but with the restriction that only differences objects are allowed
        in ``rule``.
        """
        for d in rule:
            assert isinstance(d, type(self)) and isinstance(rule[d], type(self))

        # Replace on the underlying terms, then re-wrap the result.
        rule = {d.val: rule[d].val for d in rule}
        return type(self)(self.val.xreplace(rule))

    def vrepr(self):
        """Return a verbose string representation."""
        return "{}({})".format(type(self).__name__, self.val.vrepr())

    @classmethod
    def from_pair(cls, x, y):
        """Return the `Difference` :math:`\\alpha = y - x` given two `Term`."""
        assert isinstance(x, core.Term)
        assert isinstance(y, core.Term)
        return cls(cls.diff_op(x, y))  # The order of the operands is important

    def get_pair_element(self, x):
        """Return the `Term` :math:`y` such that :math:`y = \\alpha + x`."""
        assert isinstance(x, core.Term)
        return self.inv_diff_op(x, self.val)

    @classmethod
    def derivative(cls, op, input_diff):
        """Return the derivative of ``op`` at the point ``input_diff``.

        The derivative of an `Operation` :math:`f` at the point
        :math:`\\alpha` (also called the input difference) is defined as
        :math:`f_{\\alpha} (x) = f(x + \\alpha) - f(x)`.
        Note that :math:`f_{\\alpha} (x)` is the difference of
        :math:`(f(x), f(x + \\alpha))`.

        If :math:`f` has multiple operands, :math:`\\alpha` is a list
        containing the `Difference` of each operand and the computation
        :math:`x + \\alpha` is defined component-wise, that is,
        :math:`x = (x_1, \\dots, x_n)`,
        :math:`\\alpha = (\\alpha_1, \\dots, \\alpha_n)`, and
        :math:`x + \\alpha = (x_1 + \\alpha_1, \\dots, x_n + \\alpha_n)`.

        For some operations, there is a unique output difference
        :math:`\\beta` for every input difference :math:`\\alpha`, that is,
        :math:`f_{\\alpha}(x) = \\beta` is a constant function.
        In this case, this method returns the `Difference` :math:`\\beta`.
        Otherwise, it returns a `Derivative` object representing
        :math:`f_{\\alpha}`.

        Operations with scalar operands are not supported, but these
        operands can be removed with `make_partial_operation` and the
        derivative of the resulting operator can then be computed.

        Args:
            op: a bit-vector operator
            input_diff: a list containing the difference of each operand

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("subclasses need to override this method")
class XorDiff(Difference):
    """Represent XOR differences.

    The XOR difference of two `Term` is given by the XOR
    of the terms. In other words, the *difference operation*
    of `XorDiff` is the `BvXor` (see `Difference`).

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.differential.difference import XorDiff
        >>> x, y = Constant(0b000, 3), Constant(0b000, 3)
        >>> alpha = XorDiff.from_pair(x, y)
        >>> alpha
        XorDiff(0b000)
        >>> alpha.get_pair_element(x)
        0b000
        >>> x, y = Constant(0b010, 3), Constant(0b101, 3)
        >>> alpha = XorDiff.from_pair(x, y)
        >>> alpha
        XorDiff(0b111)
        >>> alpha.get_pair_element(x)
        0b101
        >>> k = Variable("k", 8)
        >>> alpha = XorDiff.from_pair(k, k)
        >>> alpha
        XorDiff(0x00)
        >>> alpha.get_pair_element(k)
        k

    """

    # XOR is its own inverse, so both directions use BvXor.
    diff_op = operation.BvXor
    inv_diff_op = operation.BvXor

    @classmethod
    def derivative(cls, op, input_diff):
        """Return the derivative of ``op`` at the point ``input_diff``.

        See `Difference.derivative` for more information.

            >>> from arxpy.bitvector.core import Variable, Constant
            >>> from arxpy.bitvector.operation import BvAdd, BvXor, RotateLeft, BvSub
            >>> from arxpy.bitvector.extraop import make_partial_operation
            >>> from arxpy.differential.difference import XorDiff
            >>> d1, d2 = XorDiff(Variable("d1", 8)), XorDiff(Variable("d2", 8))
            >>> XorDiff.derivative(BvXor, [d1, d2])
            XorDiff(d1 ^ d2)
            >>> Xor1 = make_partial_operation(BvXor, tuple([None, Constant(1, 8)]))
            >>> XorDiff.derivative(Xor1, d1)
            XorDiff(d1)
            >>> Rotate1 = make_partial_operation(RotateLeft, tuple([None, 1]))
            >>> XorDiff.derivative(Rotate1, d1)
            XorDiff(d1 <<< 1)
            >>> XorDiff.derivative(BvAdd, [d1, d2])
            XDA(XorDiff(d1), XorDiff(d2))
            >>> XorDiff.derivative(BvSub, [d1, d2])
            XDS(XorDiff(d1), XorDiff(d2))
            >>> CteAdd1 = make_partial_operation(BvAdd, tuple([None, Constant(1, 8)]))
            >>> XorDiff.derivative(CteAdd1, d1)
            XDCA_0x01(XorDiff(d1))

        A single `XorDiff` is returned for ``BvNot``, ``BvXor``, ``Concat``
        and for partial operations of ``BvXor``, ``BvAnd`` (constant operand
        only), rotations, shifts, ``Extract`` and ``Concat``.
        ``BvAdd``, ``BvSub`` and partial ``BvAdd`` are delegated to
        ``XDA``, ``XDS`` and ``XDCA`` from ``arxpy.differential.derivative``.
        Other operations are delegated to their ``xor_derivative`` method
        when they have one.

        Raises:
            ValueError: if ``input_diff`` contains objects that are not
                instances of this class, or ``op`` is not supported.
        """
        diffs = _tuplify(input_diff)
        assert len(diffs) == sum(op.arity)

        msg = "invalid arguments: op={}, input_diff={}".format(
            op.__name__,
            [d.vrepr() if isinstance(d, core.Term) else d for d in diffs])

        if any(not isinstance(d, cls) for d in diffs):
            raise ValueError(msg)

        if op == operation.BvNot:
            return diffs[0]

        if op in (operation.BvXor, operation.Concat):
            return cls(op(*[d.val for d in diffs]))

        if op in (operation.BvAdd, operation.BvSub):
            from arxpy.differential import derivative
            model = derivative.XDA if op == operation.BvAdd else derivative.XDS
            return model(diffs)

        if issubclass(op, extraop.PartialOperation):
            base, fixed = op.base_op, op.fixed_args

            if base == operation.BvXor:
                assert len(diffs) == 1
                const = fixed[1] if fixed[0] is None else fixed[0]
                # A fixed operand has a zero XOR difference with itself.
                return cls(base(diffs[0].val, cls.from_pair(const, const).val))

            if base == operation.BvAnd:
                assert len(diffs) == 1
                mask = fixed[1] if fixed[0] is None else fixed[0]
                if not isinstance(mask, core.Constant):
                    raise ValueError(msg)
                return cls(base(diffs[0].val, mask))

            if base in (operation.RotateLeft, operation.RotateRight,
                        operation.BvShl, operation.BvLshr):
                # Only "variable operand first, fixed amount second" is handled.
                if fixed[0] is not None or fixed[1] is None:
                    raise ValueError(msg)
                assert len(diffs) == 1
                return cls(base(diffs[0].val, fixed[1]))

            if base == operation.Extract:
                if fixed[0] is not None or fixed[1] is None or fixed[2] is None:
                    raise ValueError(msg)
                assert len(diffs) == 1
                return cls(base(diffs[0].val, fixed[1], fixed[2]))

            if base == operation.Concat:
                assert len(diffs) == 1
                if fixed[0] is None:
                    const = fixed[1]
                    operands = [diffs[0], cls.from_pair(const, const)]
                else:
                    const = fixed[0]
                    operands = [cls.from_pair(const, const), diffs[0]]
                return cls(base(*[d.val for d in operands]))

            if base == operation.BvAdd:
                assert len(diffs) == 1
                cte = fixed[1] if fixed[0] is None else fixed[0]
                from arxpy.differential import derivative
                return derivative.XDCA(diffs[0], cte)

            raise ValueError(msg)

        if hasattr(op, "xor_derivative"):
            return op.xor_derivative(diffs)

        raise ValueError(msg)
class RXOp(operation.Operation):
    """The difference operation of `RXDiff`.

    Evaluates the pair ``(x, y)`` to ``(x <<< 1) ^ y``.
    """

    # NOTE(review): presumably [bit-vector operands, scalar operands];
    # callers sum it to get the number of input differences.
    arity = [2, 0]
    is_symmetric = False
    is_simple = True

    @classmethod
    def condition(cls, x, y):
        # Both operands must have the same bit-width.
        same_width = x.width == y.width
        return same_width

    @classmethod
    def output_width(cls, x, y):
        # The result keeps the width of the operands.
        width = x.width
        return width

    @classmethod
    def eval(cls, x, y):
        rotated = operation.RotateLeft(x, 1)
        return rotated ^ y
class RXInvOp(operation.Operation):
    """The inverse of the difference operation of `RXDiff`.

    Given an element ``x`` and an RX difference ``d``, evaluates to the
    paired element ``(x <<< 1) ^ d``.
    """

    # NOTE(review): presumably [bit-vector operands, scalar operands];
    # callers sum it to get the number of operands.
    arity = [2, 0]
    is_symmetric = False
    is_simple = True

    @classmethod
    def condition(cls, x, d):
        # The element and the difference must have the same bit-width.
        same_width = x.width == d.width
        return same_width

    @classmethod
    def output_width(cls, x, d):
        # The paired element has the width of ``x``.
        width = x.width
        return width

    @classmethod
    def eval(cls, x, d):
        rotated = operation.RotateLeft(x, 1)
        return rotated ^ d
class RXDiff(Difference):
    """Represent rotational-XOR (RX) differences.

    The pair ``(x, (x <<< 1) ^ d)`` has RX difference ``d``.
    In other words, the RX difference of two `Term` ``x`` and ``y``
    is defined as ``(x <<< 1) ^ y``.

    See `Difference` for more information.

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.differential.difference import RXDiff
        >>> x, y = Constant(0b000, 3), Constant(0b000, 3)
        >>> alpha = RXDiff.from_pair(x, y)
        >>> alpha
        RXDiff(0b000)
        >>> alpha.get_pair_element(x)
        0b000
        >>> x, y = Constant(0b000, 3), Constant(0b001, 3)
        >>> alpha = RXDiff.from_pair(x, y)
        >>> alpha
        RXDiff(0b001)
        >>> alpha.get_pair_element(x)
        0b001
        >>> k = Variable("k", 8)
        >>> alpha = RXDiff.from_pair(k, k)
        >>> alpha
        RXDiff(k ^ (k <<< 1))
        >>> alpha.get_pair_element(k)
        k

    """

    diff_op = RXOp
    inv_diff_op = RXInvOp

    @classmethod
    def derivative(cls, op, input_diff):
        """Return the derivative of ``op`` at the point ``input_diff``.

        See `Difference.derivative` for more information.

            >>> from arxpy.bitvector.core import Variable, Constant
            >>> from arxpy.bitvector.operation import BvAdd, BvXor, RotateLeft
            >>> from arxpy.bitvector.extraop import make_partial_operation
            >>> from arxpy.differential.difference import RXDiff
            >>> d1, d2 = RXDiff(Variable("d1", 8)), RXDiff(Variable("d2", 8))
            >>> RXDiff.derivative(BvXor, [d1, d2])
            RXDiff(d1 ^ d2)
            >>> Xor1 = make_partial_operation(BvXor, tuple([None, Constant(1, 8)]))
            >>> RXDiff.derivative(Xor1, d1)
            RXDiff(0x03 ^ d1)
            >>> Rotate1 = make_partial_operation(RotateLeft, tuple([None, 1]))
            >>> RXDiff.derivative(Rotate1, d1)
            RXDiff(d1 <<< 1)
            >>> RXDiff.derivative(BvAdd, [d1, d2])
            RXDA(RXDiff(d1), RXDiff(d2))

        A single `RXDiff` is returned for ``BvNot``, ``BvXor`` and for
        partial operations of ``BvXor`` and rotations. ``BvAdd`` is delegated
        to ``RXDA`` from ``arxpy.differential.derivative``. Any other
        operation, including unhandled partial operations, is delegated to its
        ``rx_derivative`` method when it has one.

        Raises:
            ValueError: if ``input_diff`` contains objects that are not
                instances of this class, or ``op`` is not supported.
        """
        diffs = _tuplify(input_diff)
        assert len(diffs) == sum(op.arity)

        msg = "invalid arguments: op={}, input_diff={}".format(
            op.__name__,
            [d.vrepr() if isinstance(d, core.Term) else d for d in diffs])

        if any(not isinstance(d, cls) for d in diffs):
            raise ValueError(msg)

        if op == operation.BvNot:
            return diffs[0]

        if op == operation.BvXor:
            return cls(op(*[d.val for d in diffs]))

        if op == operation.BvAdd:
            from arxpy.differential import derivative
            return derivative.RXDA(diffs)

        # Concat and BvSub are not supported and end in the final ValueError.

        if issubclass(op, extraop.PartialOperation):
            base, fixed = op.base_op, op.fixed_args

            if base == operation.BvXor:
                assert len(diffs) == 1
                const = fixed[1] if fixed[0] is None else fixed[0]
                # The fixed operand contributes its own RX difference
                # (const <<< 1) ^ const.
                return cls(base(diffs[0].val, cls.from_pair(const, const).val))

            if base in (operation.RotateLeft, operation.RotateRight):
                # Only "variable operand first, fixed amount second" is handled.
                if fixed[0] is not None or fixed[1] is None:
                    raise ValueError(msg)
                assert len(diffs) == 1
                return cls(base(diffs[0].val, fixed[1]))

            # RX-model of BvAddCte not implemented (approximation with BvAdd too weak)
            # RX-model of BvShl and BvLshr not implemented (non-linear w.r.t RX-diffs)

        if hasattr(op, "rx_derivative"):
            return op.rx_derivative(diffs)

        raise ValueError(msg)