"""Verify impossible differentials found by the SMT solver."""
import collections
import itertools
import math
import multiprocessing
import random
from arxpy.bitvector import core
from arxpy.smt.verification_differential import (
ssa2ccode, relatedssa2ccode, compile_run_empirical_weight
)
# Sampling parameters shared by the verification routines in this module.
MAX_WEIGHT = 20 # number of pairs tested = 10 * 2**MAX_WEIGHT; with MAX_WEIGHT = 20 that is about 10**7
KEY_SAMPLES = 256 # number of random master keys sampled; total complexity about 2^{30} (author's estimate)
def fast_empirical_weight(id_found, verbose_lvl=0, debug=False, filename=None):
    """Computes the empirical weight of the model using C code.

    The SSA form of the characteristic underlying ``id_found`` is
    translated to C code, compiled and run with ``MAX_WEIGHT`` as the
    weight bound, using the input and output differences of ``id_found``.
    For an impossible differential the expected result is ``math.inf``.

    If ``filename`` is not ``None``, the output will be printed
    to the given file rather than to stdout.

    The argument ``verbose_lvl`` can take an integer between
    ``0`` (no verbose) and ``3`` (full verbose). If ``debug`` is true,
    the symbolic characteristic, the differential found and the
    generated C code are printed as well.

        >>> from arxpy.differential.difference import XorDiff, RXDiff
        >>> from arxpy.differential.characteristic import BvCharacteristic
        >>> from arxpy.primitives.chaskey import ChaskeyPi
        >>> from arxpy.smt.search_impossible import SearchID
        >>> from arxpy.smt.verification_impossible import fast_empirical_weight
        >>> ChaskeyPi.set_rounds(2)
        >>> ch = BvCharacteristic(ChaskeyPi, XorDiff, ["dv0", "dv1", "dv2", "dv3"])
        >>> search_problem = SearchID(ch)
        >>> id_found = search_problem.solve(2)
        >>> fast_empirical_weight(id_found)
        inf
        >>> ch = BvCharacteristic(ChaskeyPi, RXDiff, ["dv0", "dv1", "dv2", "dv3"])
        >>> search_problem = SearchID(ch)
        >>> id_found = search_problem.solve(2)
        >>> fast_empirical_weight(id_found)
        inf

    """
    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(id_found.ch)
        smart_print("ID found:")
        smart_print(id_found)
        smart_print()

    assert len(id_found.ch.nonlinear_diffs.items()) > 0

    # Work on a copy whose lists can be modified without touching the
    # SSA stored in the characteristic.
    ssa = id_found.ch.ssa.copy()
    ssa["assignments"] = list(ssa["assignments"])
    ssa["output_vars"] = list(ssa["output_vars"])

    # Map every input/output difference variable to its concrete value.
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(id_found.input_diff, id_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val

    # fixing duplicate var problem: an output variable that is also an
    # input variable, or that appears more than once in the output, is
    # replaced by a fresh alias assigned from the original variable.
    for j in range(len(ssa["output_vars"])):
        var_j = ssa["output_vars"][j]
        index_out = 0
        if var_j in ssa["input_vars"]:
            new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
            index_out += 1
            ssa["assignments"].append([new_var, var_j])
            ssa["output_vars"][j] = new_var
            var2diffval[new_var] = var2diffval[var_j]
        for k in range(j + 1, len(ssa["output_vars"])):
            if var_j == ssa["output_vars"][k]:
                new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                index_out += 1
                ssa["assignments"].append([new_var, var_j])
                ssa["output_vars"][k] = new_var
                var2diffval[new_var] = var2diffval[var_j]

    ccode = ssa2ccode(ssa, id_found.ch.diff_type)

    if verbose_lvl >= 3:
        smart_print(" - ssa:", ssa)

    if debug:
        smart_print(ccode[0])
        smart_print(ccode[1])
        smart_print()

    input_diff_c = [v.xreplace(var2diffval) for v in ssa["input_vars"]]
    output_diff_c = [v.xreplace(var2diffval) for v in ssa["output_vars"]]

    if verbose_lvl >= 2:
        smart_print(" - checking {} -> {} pairs 2**{}".format(
            '|'.join([str(d) for d in input_diff_c]), '|'.join([str(d) for d in output_diff_c]),
            MAX_WEIGHT))

    # Check that every variable was replaced by a concrete value *before*
    # converting to int; after the conversion the check would be vacuous.
    assert all(isinstance(d, (int, core.Constant)) for d in input_diff_c), "{}".format(input_diff_c)
    assert all(isinstance(d, (int, core.Constant)) for d in output_diff_c), "{}".format(output_diff_c)

    input_diff_c = [int(d.val) for d in input_diff_c]
    output_diff_c = [int(d.val) for d in output_diff_c]

    current_empirical_weight = compile_run_empirical_weight(
        ccode,
        "_libver" + id_found.ch.func.__name__,
        input_diff_c,
        output_diff_c,
        MAX_WEIGHT,
        verbose=verbose_lvl >= 4)

    if verbose_lvl >= 2:
        smart_print(" - empirical weight: {}".format(current_empirical_weight))

    # Any value comparing equal to infinity is returned as ``math.inf``.
    if current_empirical_weight == math.inf:
        return math.inf
    return current_empirical_weight
def _rename_clashing_output_vars(ssa, var2diffval):
    """Give a fresh alias to output variables that clash with other variables.

    An output variable of ``ssa`` that is also an input variable, or that
    appears again later in the output list, is replaced by a new variable
    (suffix ``_o<index>``) assigned from the original one. ``ssa`` and
    ``var2diffval`` are modified in place; each alias inherits the
    difference value of the variable it replaces.
    """
    output_vars = ssa["output_vars"]
    for j in range(len(output_vars)):
        var_j = output_vars[j]
        index_out = 0
        if var_j in ssa["input_vars"]:
            new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
            index_out += 1
            ssa["assignments"].append([new_var, var_j])
            output_vars[j] = new_var
            var2diffval[new_var] = var2diffval[var_j]
        for k in range(j + 1, len(output_vars)):
            if var_j == output_vars[k]:
                new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                index_out += 1
                ssa["assignments"].append([new_var, var_j])
                output_vars[k] = new_var
                var2diffval[new_var] = var2diffval[var_j]


def _fast_empirical_weight_distribution(ch_found, cipher, rk_dict_diffs=None,
                                        verbose_lvl=0, debug=False, filename=None, precision=0):
    """Compute the distribution of empirical weights over random keys.

    For each of ``KEY_SAMPLES`` random master keys, the round keys given
    by ``cipher.key_schedule`` are substituted into the SSA of the
    characteristic underlying ``ch_found``; the result is translated to
    C code, compiled and run (in a process pool) with ``MAX_WEIGHT`` as
    the weight bound. The return value is a ``collections.Counter``
    mapping each empirical weight to the number of keys that produced it.

    ``rk_dict_diffs`` must be ``None``; otherwise ``ValueError`` is raised.

    If ``precision`` is ``0`` the finite weights are truncated to ``int``
    (``math.inf`` is kept); otherwise they are rounded to ``precision``
    digits. ``verbose_lvl``, ``debug`` and ``filename`` control the
    amount and the destination of the printed information; per-key
    details are only printed for the first two keys.

        >>> from arxpy.differential.difference import XorDiff
        >>> from arxpy.differential.characteristic import SingleKeyCh
        >>> from arxpy.smt.search_impossible import SearchSkID
        >>> from arxpy.primitives import speck
        >>> from arxpy.smt.verification_impossible import _fast_empirical_weight_distribution
        >>> Speck32 = speck.get_Speck_instance(speck.SpeckInstance.speck_32_64)
        >>> Speck32.set_rounds(1)
        >>> ch = SingleKeyCh(Speck32, XorDiff)
        >>> search_problem = SearchSkID(ch)
        >>> id_found = search_problem.solve(2)
        >>> _fast_empirical_weight_distribution(id_found, Speck32)
        Counter({inf: 256})

    """
    if rk_dict_diffs is not None:
        raise ValueError("rk_dict_diffs must be None")

    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("ID found:")
        smart_print(ch_found)
        smart_print()

    # Symbolic round-key variables k0, k1, ... to be replaced by values.
    rk_var = [
        core.Variable("k" + str(i), width)
        for i, width in enumerate(cipher.key_schedule.output_widths)
    ]

    # Map every input/output difference variable to its concrete value.
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(ch_found.input_diff, ch_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val

    def replace_roundkeys(var2val):
        # Copy of the characteristic's SSA with the round-key variables
        # substituted in every assignment.
        original_ssa = ch_found.ch.ssa
        new_ssa = original_ssa.copy()
        new_ssa["assignments"] = [
            (var, expr.xreplace(var2val)) for var, expr in original_ssa["assignments"]
        ]
        new_ssa["output_vars"] = list(new_ssa["output_vars"])
        return new_ssa

    # For each sampled key we store a pair of SSAs. Round-key differences
    # are not supported (rk_dict_diffs must be None), so both elements of
    # a pair are built from the same round-key values.
    rkey2pair_ssa = []
    for _ in range(KEY_SAMPLES):
        master_key = [
            core.Constant(random.randrange(2 ** width), width)
            for width in cipher.key_schedule.input_widths
        ]
        rk_val = cipher.key_schedule(*master_key)
        rk_other_val = rk_val

        assert len(rk_var) == len(rk_other_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_other_val)

        pair_ssa = []
        for current_rk_val in (rk_val, rk_other_val):
            ssa = replace_roundkeys(dict(zip(rk_var, current_rk_val)))
            _rename_clashing_output_vars(ssa, var2diffval)
            pair_ssa.append(ssa)
        rkey2pair_ssa.append(pair_ssa)

    lib_name = "_libver" + ch_found.ch.func.__name__

    # start multiprocessing
    with multiprocessing.Pool() as pool:
        pending = []
        for key_index, (ssa1, ssa2) in enumerate(rkey2pair_ssa):
            report = key_index <= 1  # per-key output only for the first two keys
            same_ssa = ssa1 == ssa2

            if report and verbose_lvl >= 3:
                smart_print(" - related-key pair index", key_index)
                smart_print(" - ssa1:", ssa1)
                if same_ssa:
                    smart_print(" - ssa2: (same as ssa1)")
                else:
                    smart_print(" - ssa2:", ssa2)

            if same_ssa:
                ccode = ssa2ccode(ssa1, ch_found.ch.diff_type)
            else:
                ccode = relatedssa2ccode(ssa1, ssa2, ch_found.ch.diff_type)

            if report and debug:
                smart_print(ccode[0])
                smart_print(ccode[1])
                smart_print()

            input_diff_c = [v.xreplace(var2diffval) for v in ssa1["input_vars"]]
            output_diff_c = [v.xreplace(var2diffval) for v in ssa1["output_vars"]]

            if report and verbose_lvl >= 2:
                smart_print(" - rk{} | checking {} -> {} with pairs 2**{}".format(
                    key_index,
                    '|'.join([str(d) for d in input_diff_c]), '|'.join([str(d) for d in output_diff_c]),
                    MAX_WEIGHT))

            assert all(isinstance(d, (int, core.Constant)) for d in input_diff_c), "{}".format(input_diff_c)
            assert all(isinstance(d, (int, core.Constant)) for d in output_diff_c), "{}".format(output_diff_c)

            input_diff_c = [int(d) for d in input_diff_c]
            output_diff_c = [int(d) for d in output_diff_c]

            pending.append(pool.apply_async(
                compile_run_empirical_weight,
                (
                    ccode,
                    lib_name,
                    input_diff_c,
                    output_diff_c,
                    MAX_WEIGHT,
                    False
                )
            ))

        # Wait until all jobs have been compiled and run. This must happen
        # inside the ``with`` block: leaving it terminates the pool.
        rkey2subch_ew = []
        for key_index, async_result in enumerate(pending):
            rkey2subch_ew.append(async_result.get())
            if key_index <= 1 and verbose_lvl >= 2:
                smart_print(" - rk{} | empirical weight: {}".format(
                    key_index, rkey2subch_ew[key_index]))
    # end multiprocessing

    empirical_weight_distribution = collections.Counter()
    for rkey_weight in rkey2subch_ew:
        if precision == 0:
            weight = int(rkey_weight) if rkey_weight != math.inf else math.inf
        else:
            weight = round(rkey_weight, precision)
        empirical_weight_distribution[weight] += 1

    if verbose_lvl >= 2:
        smart_print("- distribution empirical weights: {}".format(empirical_weight_distribution))

    if verbose_lvl >= 3:
        smart_print("- list empirical weights:", [round(x, 8) for x in rkey2subch_ew if x != math.inf])

    return empirical_weight_distribution