"""Verify characteristics found by the SMT solver.
To this end, a characteristic is split into sub-characteristics
where each sub-ch has up to `MAX_WEIGHT` weight.
"""
import collections
import importlib.util
import itertools
import math
import multiprocessing
import random
import tempfile
from arxpy.bitvector import operation, extraop, core
from arxpy.differential import difference
import cffi
# Target maximum exact weight of each sub-characteristic when a characteristic
# is split. A sub-characteristic of weight w is checked with about
# 10 * 2**w sampled pairs, so MAX_WEIGHT = 20 means roughly 10**7 pairs.
MAX_WEIGHT = 20
# Number of random master keys sampled when computing a weight distribution
# (total complexity about 2^{30}).
KEY_SAMPLES = 256
def bv2ccode(bv):
    """Convert a bit-vector type to C code.

    Only a variable, a constant or a single (non-nested) operation whose
    operands are variables, constants or plain integers can be converted.

    Args:
        bv: the bit-vector `Term` to convert

    Returns:
        A string with a C expression computing ``bv``.

    Raises:
        NotImplementedError: if ``bv`` is a `PartialOperation`.
        ValueError: if ``bv`` has nested operations as operands or if
            its operation is not supported.

    ::

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.bitvector.operation import RotateLeft
        >>> from arxpy.smt.verification_differential import bv2ccode
        >>> x, y = Variable("x", 8), Variable("y", 8)
        >>> bv2ccode(x | y)
        'x | y'
        >>> bv2ccode(x + y)
        '(x + y) & 255'
        >>> bv2ccode(RotateLeft(x, 1))
        '((x << 1) | (x >> 7)) & 255'
        >>> bv2ccode((~x) + y)
        Traceback (most recent call last):
         ...
        ValueError: nested bit-vector operations are not supported

    """
    # Note that every operation involving a shift or
    # an arithmetic operation (add) requires the result to be masked.
    if issubclass(type(bv), extraop.PartialOperation):
        raise NotImplementedError("PartialOperation is not supported")

    # only variables or constants are allowed as operands
    if not all(isinstance(arg, (int, core.Constant, core.Variable)) for arg in bv.args):
        raise ValueError("nested bit-vector operations are not supported")

    if isinstance(bv, (core.Constant, core.Variable)):
        return str(bv)

    # bitwise operations: the string form of the term is used as the C code
    if type(bv) in (operation.BvNot, operation.BvAnd, operation.BvOr, operation.BvXor):
        return str(bv)
    elif type(bv) in [operation.BvShl, operation.BvLshr]:
        return "({}) & {}".format(bv, 2**bv.width - 1)
    elif type(bv) == operation.RotateLeft:
        # #define ROTL1(a) ((a << 1) | (a >> (width - 1))) & (2**width - 1)
        x, r = bv.args
        return "(({0} << {1}) | ({0} >> {2})) & {3}".format(x, r, x.width - r, 2**x.width - 1)
    elif type(bv) == operation.RotateRight:
        # #define ROTR1(a) ((a >> 1) | (a << (width - 1))) & (2**width - 1)
        x, r = bv.args
        return "(({0} >> {1}) | ({0} << {2})) & {3}".format(x, r, x.width - r, 2**x.width - 1)
    elif type(bv) == operation.BvAdd:
        x, y = bv.args
        return "({} + {}) & {}".format(x, y, 2**x.width - 1)
    elif type(bv) == operation.BvSub:
        x, y = bv.args
        m = 2**x.width - 1
        # (x - y) mod 2**n: when x < y the result is 2**n - y + x, which is
        # written as (2**n - 1) - y + x + 1 so that no intermediate value
        # exceeds 2**n - 1. No need to mask (result is < 2**n).
        return f"({x} >= {y}) ? {x} - {y} : {m} - {y} + {x} + 1"

    # imported here rather than at module level
    from arxpy.primitives.simon import SimonRF
    from arxpy.primitives.shacal1 import BvIf, BvMaj
    from arxpy.primitives.multi2 import BvOr as Multi2BvOr

    if isinstance(bv, SimonRF):
        # ((x <<< a) & (x <<< b)) ^ (x <<< c)
        return "( (({}) & ({})) ^ ({}) ) & {}".format(
            bv2ccode(operation.RotateLeft(bv.args[0], SimonRF.a)),
            bv2ccode(operation.RotateLeft(bv.args[0], SimonRF.b)),
            bv2ccode(operation.RotateLeft(bv.args[0], SimonRF.c)),
            2 ** bv.args[0].width -1
        )
    elif isinstance(bv, BvIf):
        # (x & y) | ((~x) & z)
        x, y, z = bv.args
        return "({0} & {1}) | ((~{0}) & {2})".format(x, y, z)
    elif isinstance(bv, BvMaj):
        # (x & y) | (x & z) | (y & z)
        x, y, z = bv.args
        return "({0} & {1}) | ({0} & {2}) | ({1} & {2})".format(x, y, z)
    elif isinstance(bv, Multi2BvOr):
        x, y = bv.args
        return "{} | {}".format(x, y)

    from arxpy.linear.correlation import BvIdentity

    if isinstance(bv, BvIdentity):
        return "{}".format(bv.args[0])
    else:
        raise ValueError("invalid operation: {}".format(type(bv)))
# elif type(bv) == operation.Ite:
# elif type(bv) == operation.BvComp:
# elif type(bv) == operation.BvUlt:
# elif type(bv) == operation.BvUle:
# elif type(bv) == operation.BvUgt:
# elif type(bv) == operation.BvUge:
# elif type(bv) == operation.Extract:
# elif type(bv) == operation.Concat:
# elif type(bv) == operation.ZeroExtend:
# elif type(bv) == operation.Repeat:
# elif type(bv) == operation.BvNeg:
# elif type(bv) == operation.BvSub:
# elif type(bv) == operation.BvMul:
# elif type(bv) == operation.BvUdiv:
# elif type(bv) == operation.BvUrem:
# C templates for the pair-counting routine. The placeholders in single braces
# are filled in by ssa2ccode with str.format; literal C braces are doubled.
#
# The generated function seeds the C PRNG (with the current time if seed == 0,
# otherwise with the given seed) and, for each of the `pair_samples`
# iterations, draws input1 with rand(), derives input2 from input1 and
# input_diff, evaluates both inputs, and counts the pair as valid when every
# output word matches output_diff.
#
# NOTE(review): rand() returns at most RAND_MAX, which may be smaller than the
# range of the 32/64-bit word types - confirm the sampling is adequate.
#
# no need to include <stdint.h>
#
# Debugging aids that can be pasted into the C template:
# input testing:
# printf("input %d | %u %u\\n", j, input1[j], input2[j]);
# output testing (XorDiff):
# printf("%d | out1=%u out2=%u diff_out=%u th_diff_out=%u\\n", j, output1[j], output2[j], output1[j] ^ output2[j], output_diff[j]);
get_num_valid_pairs_header = "static unsigned long get_num_valid_pairs({ctype_input} input_diff[], {ctype_output} output_diff[], unsigned long pair_samples, unsigned int seed);"
get_num_valid_pairs_source = """
static unsigned long get_num_valid_pairs({ctype_input} input_diff[], {ctype_output} output_diff[], unsigned long pair_samples, unsigned int seed){{
if (seed == 0) srand((unsigned int) time (NULL));
else srand(seed);
unsigned long num_valid_pairs = 0;
unsigned long i = 0;
unsigned int j = 0;
{ctype_input} input1[{len_input}];
{ctype_input} input2[{len_input}];
{ctype_output} output1[{len_output}];
{ctype_output} output2[{len_output}];
for (; i < pair_samples; ++i) {{
for (j = 0; j < {len_input}; ++j) {{
input1[j] = rand();
input2[j] = {input_diff_ccode};
}}
{eval1}({input1}, {output1});
{eval2}({input2}, {output2});
for (j = 0; j < {len_output}; ++j) {{
if ( ({output_diff_ccode}) != output_diff[j] ) break;
if (j == {len_output} - 1) num_valid_pairs += 1;
}}
}}
return num_valid_pairs;
}}"""
def ssa2ccode(ssa, diff_type):
    """Return the C code to compute a differential probability over a function in ssa form.

    The result is a pair ``(header, source)``: ``header`` declares the
    pair-counting function and ``source`` contains the C function ``eval``
    (built from the ssa assignments) followed by the pair-counting function.

    Args:
        ssa: a dictionary with the keys ``"input_vars"``, ``"output_vars"``
            and ``"assignments"`` (a sequence of ``(var, expr)`` pairs).
        diff_type: either `XorDiff` or `RXDiff`.

    Raises:
        ValueError: if ``diff_type`` is neither `XorDiff` nor `RXDiff`.

    ::

        >>> from arxpy.differential.difference import XorDiff, RXDiff
        >>> from arxpy.primitives.chaskey import ChaskeyPi
        >>> from arxpy.smt.verification_differential import ssa2ccode
        >>> ChaskeyPi.set_rounds(1)
        >>> ssa = ChaskeyPi.ssa(["v0", "v1", "v2", "v3"], "x")  # doctest: +NORMALIZE_WHITESPACE
        >>> header, source = ssa2ccode(ssa, XorDiff)
        >>> print(header)
        static unsigned long get_num_valid_pairs(uint32_t input_diff[], uint32_t output_diff[], unsigned long pair_samples, unsigned int seed);
        >>> print(source)  # doctest:+NORMALIZE_WHITESPACE
        void eval(uint32_t v0, uint32_t v1, uint32_t v2, uint32_t v3, uint32_t* x7, uint32_t* x12, uint32_t* x13, uint32_t* x9){
        uint32_t x0 = (v0 + v1) & 4294967295;
        uint32_t x1 = ((v1 << 5) | (v1 >> 27)) & 4294967295;
        uint32_t x2 = x0 ^ x1;
        uint32_t x3 = ((x0 << 16) | (x0 >> 16)) & 4294967295;
        uint32_t x4 = (v2 + v3) & 4294967295;
        uint32_t x5 = ((v3 << 8) | (v3 >> 24)) & 4294967295;
        uint32_t x6 = x4 ^ x5;
        *x7 = (x3 + x6) & 4294967295;
        uint32_t x8 = ((x6 << 13) | (x6 >> 19)) & 4294967295;
        *x9 = *x7 ^ x8;
        uint32_t x10 = (x2 + x4) & 4294967295;
        uint32_t x11 = ((x2 << 7) | (x2 >> 25)) & 4294967295;
        *x12 = x10 ^ x11;
        *x13 = ((x10 << 16) | (x10 >> 16)) & 4294967295;
        };
        static unsigned long get_num_valid_pairs(uint32_t input_diff[], uint32_t output_diff[], unsigned long pair_samples, unsigned int seed){
        if (seed == 0) srand((unsigned int) time (NULL));
        else srand(seed);
        unsigned long num_valid_pairs = 0;
        unsigned long i = 0;
        unsigned int j = 0;
        uint32_t input1[4];
        uint32_t input2[4];
        uint32_t output1[4];
        uint32_t output2[4];
        for (; i < pair_samples; ++i) {
        for (j = 0; j < 4; ++j) {
        input1[j] = rand();
        input2[j] = input1[j] ^ input_diff[j];
        }
        eval(input1[0], input1[1], input1[2], input1[3], &output1[0], &output1[1], &output1[2], &output1[3]);
        eval(input2[0], input2[1], input2[2], input2[3], &output2[0], &output2[1], &output2[2], &output2[3]);
        for (j = 0; j < 4; ++j) {
        if ( (output1[j] ^ output2[j]) != output_diff[j] ) break;
        if (j == 4 - 1) num_valid_pairs += 1;
        }
        }
        return num_valid_pairs;
        }
        >>> header, source = ssa2ccode(ssa, RXDiff)
        >>> print(header)
        static unsigned long get_num_valid_pairs(uint32_t input_diff[], uint32_t output_diff[], unsigned long pair_samples, unsigned int seed);
        >>> print(source)  # doctest:+NORMALIZE_WHITESPACE
        void eval(uint32_t v0, uint32_t v1, uint32_t v2, uint32_t v3, uint32_t* x7, uint32_t* x12, uint32_t* x13, uint32_t* x9){
        uint32_t x0 = (v0 + v1) & 4294967295;
        uint32_t x1 = ((v1 << 5) | (v1 >> 27)) & 4294967295;
        uint32_t x2 = x0 ^ x1;
        uint32_t x3 = ((x0 << 16) | (x0 >> 16)) & 4294967295;
        uint32_t x4 = (v2 + v3) & 4294967295;
        uint32_t x5 = ((v3 << 8) | (v3 >> 24)) & 4294967295;
        uint32_t x6 = x4 ^ x5;
        *x7 = (x3 + x6) & 4294967295;
        uint32_t x8 = ((x6 << 13) | (x6 >> 19)) & 4294967295;
        *x9 = *x7 ^ x8;
        uint32_t x10 = (x2 + x4) & 4294967295;
        uint32_t x11 = ((x2 << 7) | (x2 >> 25)) & 4294967295;
        *x12 = x10 ^ x11;
        *x13 = ((x10 << 16) | (x10 >> 16)) & 4294967295;
        };
        static unsigned long get_num_valid_pairs(uint32_t input_diff[], uint32_t output_diff[], unsigned long pair_samples, unsigned int seed){
        if (seed == 0) srand((unsigned int) time (NULL));
        else srand(seed);
        unsigned long num_valid_pairs = 0;
        unsigned long i = 0;
        unsigned int j = 0;
        uint32_t input1[4];
        uint32_t input2[4];
        uint32_t output1[4];
        uint32_t output2[4];
        for (; i < pair_samples; ++i) {
        for (j = 0; j < 4; ++j) {
        input1[j] = rand();
        input2[j] = ( ((input1[j] << 1) | (input1[j] >> 31)) ^ input_diff[j] ) & 4294967295;
        }
        eval(input1[0], input1[1], input1[2], input1[3], &output1[0], &output1[1], &output1[2], &output1[3]);
        eval(input2[0], input2[1], input2[2], input2[3], &output2[0], &output2[1], &output2[2], &output2[3]);
        for (j = 0; j < 4; ++j) {
        if ( (( ((output1[j] << 1) | (output1[j] >> 31)) ^ output2[j] ) & 4294967295) != output_diff[j] ) break;
        if (j == 4 - 1) num_valid_pairs += 1;
        }
        }
        return num_valid_pairs;
        }

    """
    c_func = "eval"
    ctype_by_width = {
        8: "uint8_t",
        16: "uint16_t",
        32: "uint32_t",
        64: "uint64_t"
    }

    def indexed(template, count):
        # comma-separated list such as "input1[0], input1[1], ..."
        return ', '.join([template.format(pos) for pos in range(count)])

    in_vars = ssa["input_vars"]
    out_vars = ssa["output_vars"]

    in_params = ["{} {}".format(ctype_by_width[v.width], v.name) for v in in_vars]
    out_params = ["{}* {}".format(ctype_by_width[v.width], v.name) for v in out_vars]

    # outputs are passed by pointer, so every use of an output variable
    # inside an expression is replaced by its dereferenced form
    deref = {v: type(v)("*" + v.name, v.width) for v in out_vars}

    pieces = ["void {}({}, {}){{\n".format(c_func, ', '.join(in_params), ', '.join(out_params))]
    for target, rhs in ssa["assignments"]:
        rhs = rhs.xreplace(deref)
        if target in out_vars:
            pieces.append("\t*{} = {};\n".format(target, bv2ccode(rhs)))
        else:
            pieces.append("\t{} {} = {};\n".format(ctype_by_width[target.width], target, bv2ccode(rhs)))
    pieces.append("};")
    eval_ccode = "".join(pieces)

    # a single C type is used for all the inputs and another one for all the outputs
    assert all(in_vars[0].width == v.width for v in in_vars)
    assert all(out_vars[0].width == v.width for v in out_vars)

    if diff_type == difference.XorDiff:
        input_diff_ccode = "input1[j] ^ input_diff[j]"
        output_diff_ccode = "output1[j] ^ output2[j]"
    elif diff_type == difference.RXDiff:
        # the RX difference of (a, b) is RotateLeft(a, 1) ^ b, where
        # ROTL1(a) = ((a << 1) | (a >> (width - 1))) & (2**width - 1)
        in_width = in_vars[0].width
        out_width = out_vars[0].width
        input_diff_ccode = "( ((input1[j] << 1) | (input1[j] >> {})) ^ input_diff[j] ) & {}".format(
            in_width - 1, 2**in_width - 1)
        output_diff_ccode = "( ((output1[j] << 1) | (output1[j] >> {})) ^ output2[j] ) & {}".format(
            out_width - 1, 2**out_width - 1)
    else:
        raise ValueError("invalid diff_type")

    ctypes = {
        "ctype_input": ctype_by_width[in_vars[0].width],
        "ctype_output": ctype_by_width[out_vars[0].width],
    }
    num_in = len(in_vars)
    num_out = len(out_vars)

    counting_ccode = get_num_valid_pairs_source.format(
        len_input=num_in,
        len_output=num_out,
        eval1=c_func,
        eval2=c_func,
        input1=indexed("input1[{}]", num_in),
        output1=indexed("&output1[{}]", num_out),
        input2=indexed("input2[{}]", num_in),
        output2=indexed("&output2[{}]", num_out),
        input_diff_ccode=input_diff_ccode,
        output_diff_ccode=output_diff_ccode,
        **ctypes
    )
    header_ccode = get_num_valid_pairs_header.format(**ctypes)

    return header_ccode, eval_ccode + counting_ccode
def compile_run_empirical_weight(ccode, module_name, input_diff, output_diff, target_weight, verbose=False):
    """Compile and execute the C code to compute the empirical weight

    Args:
        ccode: a pair ``(header, source)`` of C code, such as the one
            returned by `ssa2ccode`.
        module_name: the name of the extension module to build.
        input_diff: a list of integers with the input difference.
        output_diff: a list of integers with the output difference.
        target_weight: the expected weight; about ``10 * 2**target_weight``
            pairs (and at least 100) are sampled.
        verbose: passed to the ``compile`` call of the cffi builder.

    Returns:
        The empirical weight ``-log2(valid_pairs / sampled_pairs)``,
        or ``math.inf`` if no valid pair was found.

    ::

        >>> from arxpy.differential.difference import XorDiff, RXDiff
        >>> from arxpy.primitives.chaskey import ChaskeyPi
        >>> from arxpy.smt.verification_differential import ssa2ccode, compile_run_empirical_weight
        >>> ChaskeyPi.set_rounds(1)
        >>> ssa = ChaskeyPi.ssa(["v0", "v1", "v2", "v3"], "x")  # doctest: +NORMALIZE_WHITESPACE
        >>> header, source = ssa2ccode(ssa, XorDiff)
        >>> ind, outd, tw = [0, 0, 0, 0], [0, 0, 0, 0], 2
        >>> compile_run_empirical_weight([header, source], "_libverChaskeyPi", ind, outd, tw)
        0.0
        >>> ind, outd, tw = [0, 0, 0, 0], [0, 0, 0, 1], 2
        >>> compile_run_empirical_weight([header, source], "_libverChaskeyPi", ind, outd, tw)
        inf
        >>> header, source = ssa2ccode(ssa, RXDiff)
        >>> ind, outd, tw = [0, 0, 0, 0], [0, 0, 0, 0], 8
        >>> 4 <= compile_run_empirical_weight([header, source], "_libverChaskeyPi", ind, outd, tw) <= 8
        True
        >>> ind, outd, tw = [0, 0, 0, 0], [1, 1, 1, 1], 8
        >>> compile_run_empirical_weight([header, source], "_libverChaskeyPi", ind, outd, tw)
        inf

    """
    ffibuilder = cffi.FFI()
    ffibuilder.cdef(ccode[0])
    ffibuilder.set_source(module_name, ccode[1])

    # the extension is built in a temporary directory and loaded from there
    with tempfile.TemporaryDirectory() as tmpdirname:
        libver_path = ffibuilder.compile(tmpdir=tmpdirname, verbose=verbose)

        spec = importlib.util.spec_from_file_location(module_name, libver_path)
        libver = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(libver)

    pair_samples = max(5 * (2**(int(target_weight) + 1)), 100)

    # The seed is derived from the arguments so that equal calls sample the
    # same pairs. The C code interprets a zero seed as "seed with the current
    # time", so a derived seed of 0 is remapped to keep the run reproducible.
    seed = hash(tuple(input_diff + output_diff + [pair_samples])) % (2**32)
    if seed == 0:
        seed = 1

    num_pairs = libver.lib.get_num_valid_pairs(input_diff, output_diff, pair_samples, seed)

    if num_pairs == 0:
        return math.inf
    # abs() turns the -0.0 obtained when all the pairs are valid into 0.0
    return abs(-math.log2(num_pairs * 1.0 / pair_samples))
def fast_empirical_weight(ch_found, verbose_lvl=0, debug=False, filename=None):
    """Computes the empirical weight of the model using C code.

    The characteristic is split into consecutive sub-characteristics,
    the empirical weight of each one is computed by compiling and
    running C code, and the sum of these weights is returned
    (``math.inf`` as soon as a sub-characteristic has no valid pair).

    If ``filename`` is not ``None``, the output will be printed
    to the given file rather than to stdout.

    The argument ``verbose_lvl`` can take an integer between
    ``0`` (no verbose) and ``3`` (full verbose). With ``4`` or more,
    the compilation of the C code is also run in verbose mode.

        >>> from arxpy.differential.difference import XorDiff, RXDiff
        >>> from arxpy.differential.characteristic import BvCharacteristic
        >>> from arxpy.primitives.chaskey import ChaskeyPi
        >>> from arxpy.smt.search_differential import SearchCh
        >>> from arxpy.smt.verification_differential import fast_empirical_weight
        >>> ChaskeyPi.set_rounds(2)
        >>> ch = BvCharacteristic(ChaskeyPi, XorDiff, ["dv0", "dv1", "dv2", "dv3"])
        >>> search_problem = SearchCh(ch)
        >>> ch_found = search_problem.solve(0)
        >>> ch_found.ch_weight
        0x04
        >>> 3 <= fast_empirical_weight(ch_found) <= 5
        True
        >>> ChaskeyPi.set_rounds(1)
        >>> ch = BvCharacteristic(ChaskeyPi, RXDiff, ["dv0", "dv1", "dv2", "dv3"])
        >>> ic = [operation.BvComp(0, d.val) for d in ch.input_diff]
        >>> ic += [operation.BvComp(0, d[1].val) for d in ch.output_diff]
        >>> ch_found = SearchCh(ch, allow_zero_input_diff=True, initial_constraints=ic).solve(5)
        >>> ch_found.ch_weight
        0x05
        >>> 4 - 1 <= fast_empirical_weight(ch_found) <= 8
        True

    """
    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    exact_weight = ch_found.get_exact_weight()

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("Characteristic found:")
        smart_print(ch_found)
        smart_print()

    # exact weight of each non-linear differential (derivative) of the characteristic
    der_weights = []
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        actual_diff = ch_found.nonlinear_diffs[i][1]
        new_input_diff = [(d.xreplace(ch_found._diff_model)) for d in der.input_diff]
        der_weights.append(der._replace_input_diff(new_input_diff).exact_weight(actual_diff))

    # the second branch reduces to MAX_WEIGHT (up to rounding)
    max_subch_weight = exact_weight if exact_weight < MAX_WEIGHT else exact_weight / (exact_weight / MAX_WEIGHT)
    # a sub-characteristic contains at least one derivative, so the bound
    # cannot be smaller than the biggest derivative weight
    max_subch_weight = max(1, max_subch_weight, *der_weights)

    if debug:
        smart_print("max_subch_weight:", max_subch_weight)
        smart_print()

    # greedy split: derivatives are added to the current sub-characteristic
    # until the next one would exceed max_subch_weight
    subch_listdiffder = [[]]  # for each subch, a list of [diff, der] pairs
    subch_index = 0
    current_subch_weight = 0  # exact_weight
    subch_weight = []  # the weight of each subch
    assert len(ch_found.ch.nonlinear_diffs.items()) > 0
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        der_weight = der_weights[i]
        if current_subch_weight + der_weight > max_subch_weight:
            subch_weight.append(current_subch_weight)
            current_subch_weight = 0
            subch_index += 1
            subch_listdiffder.append([])
        current_subch_weight += der_weight
        subch_listdiffder[subch_index].append([diff, der])

    subch_weight.append(current_subch_weight)
    assert len(subch_weight) == len(subch_listdiffder)
    num_subch = len(subch_listdiffder)

    if verbose_lvl >= 3:
        smart_print("- characteristic decomposed into {} subcharacteristics with exact weights {}".format(
            num_subch, subch_weight
        ))

    def subch_listdiffder2subch_ssa(listdiffder, first_var_next_subch, first_subch=False):
        # Collect the assignments of the whole ssa from the first derivative
        # variable of this sub-characteristic (or from the beginning, for the
        # first sub-characteristic) up to, and excluding, the first derivative
        # variable of the next sub-characteristic.
        first_var = listdiffder[0][0].val

        input_vars = []
        inter_vars = set()
        assignments = []

        add_assignment = first_subch
        for var, expr in ch_found.ch.ssa["assignments"]:
            if var == first_var:
                add_assignment = True
            elif var == first_var_next_subch:
                break

            if add_assignment:
                input_vars.extend([atom for atom in expr.atoms(core.Variable) if atom not in input_vars])
                inter_vars.add(var)
                assignments.append([var, expr])

        subch_ssa = {}
        # variables that are used but not assigned within the sub-characteristic
        subch_ssa["input_vars"] = [var for var in input_vars if var not in inter_vars]
        subch_ssa["output_vars"] = []
        subch_ssa["inter_vars"] = inter_vars
        subch_ssa["assignments"] = assignments
        return subch_ssa

    # built backwards since the outputs of a sub-characteristic
    # are the inputs of the next one
    subch_ssa = [None for _ in range(num_subch)]
    for i in reversed(range(num_subch)):
        if i == num_subch - 1:
            subch_ssa[i] = subch_listdiffder2subch_ssa(subch_listdiffder[i], None, i == 0)
            subch_ssa[i]["output_vars"] = list(ch_found.ch.ssa["output_vars"])
        else:
            first_var_next_ssa = subch_listdiffder[i + 1][0][0].val
            subch_ssa[i] = subch_listdiffder2subch_ssa(subch_listdiffder[i], first_var_next_ssa, i == 0)
            subch_ssa[i]["output_vars"] = subch_ssa[i + 1]["input_vars"][:]

        # an output that is not computed in this sub-characteristic
        # must be received as an input
        for diff_var in subch_ssa[i]["output_vars"]:
            if diff_var not in subch_ssa[i]["inter_vars"] and diff_var not in subch_ssa[i]["input_vars"]:
                subch_ssa[i]["input_vars"].append(diff_var)

        del subch_ssa[i]["inter_vars"]
        subch_ssa[i]["weight"] = subch_weight[i]

    # map each variable to the value of its difference
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(ch_found.input_diff, ch_found.nonlinear_diffs, ch_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val
    for var, diff in ch_found.ch._var2diff.items():
        if var not in var2diffval:
            var2diffval[var] = diff.val.xreplace(var2diffval)

    # fixing duplicate var problem: an output variable that is also an input,
    # or that is repeated in the outputs, is copied to a fresh variable
    for ssa in subch_ssa:
        for j in range(len(ssa["output_vars"])):
            var_j = ssa["output_vars"][j]
            index_out = 0
            if var_j in ssa["input_vars"]:
                new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                index_out += 1
                ssa["assignments"].append([new_var, var_j])
                ssa["output_vars"][j] = new_var
                var2diffval[new_var] = var2diffval[var_j]

            for k in range(j + 1, len(ssa["output_vars"])):
                if var_j == ssa["output_vars"][k]:
                    new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                    index_out += 1
                    ssa["assignments"].append([new_var, var_j])
                    ssa["output_vars"][k] = new_var
                    var2diffval[new_var] = var2diffval[var_j]

    total_empirical_weight = 0

    for i, ssa in enumerate(subch_ssa):
        ccode = ssa2ccode(ssa, ch_found.ch.diff_type)

        if verbose_lvl >= 2:
            smart_print("- sub-characteristic {}".format(i))
        if verbose_lvl >= 3:
            smart_print(" - ssa:", subch_ssa[i])  # pprint.pformat(list_ssa[i], width=100))
            smart_print(" - listdiffder:", subch_listdiffder[i])  # pprint.pformat(ssa_ders[i], width=100))
        if debug:
            smart_print(ccode[0])
            smart_print(ccode[1])
            smart_print()

        input_diff_c = [v.xreplace(var2diffval) for v in ssa["input_vars"]]
        output_diff_c = [v.xreplace(var2diffval) for v in ssa["output_vars"]]

        if verbose_lvl >= 2:
            smart_print(" - checking {} -> {} with weight {}".format(
                '|'.join([str(d) for d in input_diff_c]), '|'.join([str(d) for d in output_diff_c]),
                ssa["weight"]))

        # every difference must have been resolved to a constant; this is
        # checked before the conversion to int (afterwards it would always hold)
        assert all(isinstance(d, (int, core.Constant)) for d in input_diff_c), "{}".format(input_diff_c)
        assert all(isinstance(d, (int, core.Constant)) for d in output_diff_c), "{}".format(output_diff_c)

        input_diff_c = [int(d.val) for d in input_diff_c]
        output_diff_c = [int(d.val) for d in output_diff_c]

        current_empirical_weight = compile_run_empirical_weight(
            ccode,
            "_libver" + ch_found.ch.func.__name__ + str(i),
            input_diff_c,
            output_diff_c,
            ssa["weight"],
            verbose=verbose_lvl >= 4)

        if verbose_lvl >= 2:
            smart_print(" - exact/empirical weight: {}, {}".format(ssa["weight"], current_empirical_weight))

        # an invalid sub-characteristic makes the whole characteristic invalid
        if current_empirical_weight == math.inf:
            return math.inf

        total_empirical_weight += current_empirical_weight

    return total_empirical_weight
def _fast_empirical_weight_distribution(ch_found, cipher, rk_dict_diffs=None,
                                        verbose_lvl=0, debug=False, filename=None, precision=0):
    """Return the distribution of empirical weights over random keys.

    ``KEY_SAMPLES`` random master keys are sampled. For each one, the round
    keys given by ``cipher.key_schedule`` are substituted into the ssa of the
    characteristic, the characteristic is split into sub-characteristics and
    the empirical weight of each one is computed with C code (in parallel
    with a `multiprocessing.Pool`).

    Args:
        ch_found: the characteristic found to verify.
        cipher: the cipher whose ``key_schedule`` is used to derive round keys.
        rk_dict_diffs: if not ``None``, a mapping with the keys
            ``"nonlinear_diffs"`` and ``"output_diff"`` (each an iterable of
            ``(var, diff)`` pairs) with the round-key differences; in that
            case the second element of each pair is evaluated under the
            related round keys.
        verbose_lvl: the verbosity level (messages are printed from ``2``).
        debug: if true, additional information (including the C code) is printed.
        filename: passed to ``_get_smart_print`` to select where to print.
        precision: if ``0``, each weight is truncated to an integer;
            otherwise it is rounded to ``precision`` digits.

    Returns:
        A `collections.Counter` mapping each weight to the number of keys
        for which that empirical weight was obtained.

        >>> from arxpy.differential.difference import XorDiff
        >>> from arxpy.differential.characteristic import SingleKeyCh
        >>> from arxpy.smt.search_differential import SearchSkCh
        >>> from arxpy.primitives import speck
        >>> from arxpy.smt.verification_differential import _fast_empirical_weight_distribution
        >>> Speck32 = speck.get_Speck_instance(speck.SpeckInstance.speck_32_64)
        >>> Speck32.set_rounds(1)
        >>> ch = SingleKeyCh(Speck32, XorDiff)
        >>> search_problem = SearchSkCh(ch)
        >>> ch_found = search_problem.solve(0)
        >>> _fast_empirical_weight_distribution(ch_found, Speck32)
        Counter({0: 256})

    """
    # similar to _empirical_distribution_weight of characteristic module

    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    exact_weight = ch_found.get_exact_weight()

    if rk_dict_diffs is not None:
        assert "nonlinear_diffs" in rk_dict_diffs and "output_diff" in rk_dict_diffs

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("Characteristic found:")
        smart_print(ch_found)
        if rk_dict_diffs is not None:
            smart_print("rk_dict_diffs:", rk_dict_diffs)
        smart_print()

    # exact weight of each non-linear differential (derivative) of the characteristic
    der_weights = []
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        actual_diff = ch_found.nonlinear_diffs[i][1]
        new_input_diff = [(d.xreplace(ch_found._diff_model)) for d in der.input_diff]
        der_weights.append(der._replace_input_diff(new_input_diff).exact_weight(actual_diff))

    # the second branch reduces to MAX_WEIGHT (up to rounding)
    max_subch_weight = exact_weight if exact_weight < MAX_WEIGHT else exact_weight / (exact_weight / MAX_WEIGHT)
    # the bound cannot be smaller than the biggest derivative weight
    max_subch_weight = max(1, max_subch_weight, *der_weights)

    if debug:
        smart_print("max_subch_weight:", max_subch_weight)
        smart_print()

    # greedy split: derivatives are added to the current sub-characteristic
    # until the next one would exceed max_subch_weight
    subch_listdiffder = [[]]  # for each subch, a list of [diff, der] pairs
    subch_index = 0
    current_subch_weight = 0  # exact_weight
    subch_weight = []  # the weight of each subch
    assert len(ch_found.ch.nonlinear_diffs.items()) > 0
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        der_weight = der_weights[i]
        if current_subch_weight + der_weight > max_subch_weight:
            subch_weight.append(current_subch_weight)
            current_subch_weight = 0
            subch_index += 1
            subch_listdiffder.append([])
        current_subch_weight += der_weight
        subch_listdiffder[subch_index].append([diff, der])

    subch_weight.append(current_subch_weight)
    assert len(subch_weight) == len(subch_listdiffder)
    num_subch = len(subch_listdiffder)

    if verbose_lvl >= 3:
        smart_print("- characteristic decomposed into {} subcharacteristics with exact weights {}".format(
            num_subch, subch_weight
        ))

    # symbolic round-key variables, later replaced by concrete round-key values
    if rk_dict_diffs is not None:
        rk_var = [var.val for var, _ in rk_dict_diffs["output_diff"]]
    else:
        rk_var = []
        for i, width in enumerate(cipher.key_schedule.output_widths):
            rk_var.append(core.Variable("k" + str(i), width))

    # map each variable to the value of its difference
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(ch_found.input_diff, ch_found.nonlinear_diffs, ch_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val
    if rk_dict_diffs is not None:
        for var, diff in rk_dict_diffs["output_diff"]:
            var2diffval[var.val] = diff.val
        for var, diff in rk_dict_diffs["nonlinear_diffs"]:
            var2diffval[var.val] = diff.val
    for var, diff in ch_found.ch._var2diff.items():
        if var not in var2diffval:
            if isinstance(diff, core.Term):
                # e.g., symbolic computations with the key
                var2diffval[var] = diff.xreplace(var2diffval)
            else:
                var2diffval[var] = diff.val.xreplace(var2diffval)

    # for each related-key pair, we associated a pair of subch_ssa
    rkey2pair_subchssa = [None for _ in range(KEY_SAMPLES)]
    for key_index in range(KEY_SAMPLES):
        # sample a random master key and derive its round keys
        master_key = []
        for width in cipher.key_schedule.input_widths:
            master_key.append(core.Constant(random.randrange(2 ** width), width))
        rk_val = cipher.key_schedule(*master_key)
        # round keys of the second element of the pair
        # (the same round keys if no round-key differences are given)
        if rk_dict_diffs is not None:
            rk_other_val = tuple([d.get_pair_element(r) for r, (_, d) in zip(rk_val, rk_dict_diffs["output_diff"])])
        else:
            rk_other_val = rk_val
        assert len(rk_var) == len(rk_other_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_other_val)

        def subch_listdiffder2subch_ssa(listdiffder, first_der_var_next_subch, var2val, first_subch=False):
            # Collect the assignments of the whole ssa from the first derivative
            # variable of this sub-characteristic (or from the beginning, for the
            # first sub-characteristic) up to, and excluding, the first derivative
            # variable of the next one, replacing the variables in var2val.
            first_der_var = listdiffder[0][0].val

            input_vars = []
            inter_vars = set()
            assignments = []

            add_assignment = first_subch
            for var, expr in ch_found.ch.ssa["assignments"]:
                if var == first_der_var:
                    add_assignment = True
                elif var == first_der_var_next_subch:
                    break

                expr = expr.xreplace(var2val)

                if add_assignment:
                    input_vars.extend([atom for atom in expr.atoms(core.Variable) if atom not in input_vars])
                    inter_vars.add(var)
                    assignments.append([var, expr])

            subch_ssa = {}
            # variables that are used but not assigned within the sub-characteristic
            subch_ssa["input_vars"] = [var for var in input_vars if var not in inter_vars]
            subch_ssa["output_vars"] = []
            subch_ssa["inter_vars"] = inter_vars
            subch_ssa["assignments"] = assignments
            return subch_ssa

        # one list of sub-characteristic ssa for each element of the pair
        pair_subchssa = []
        for index_pair in range(2):
            current_rk_val = rk_val if index_pair == 0 else rk_other_val
            rkvar2rkval = {var: val for var, val in zip(rk_var, current_rk_val)}

            # built backwards since the outputs of a sub-characteristic
            # are the inputs of the next one
            subch_ssa = [None for _ in range(num_subch)]
            for i in reversed(range(num_subch)):
                if i == num_subch - 1:
                    subch_ssa[i] = subch_listdiffder2subch_ssa(subch_listdiffder[i], None, rkvar2rkval, i == 0)
                    subch_ssa[i]["output_vars"] = list(ch_found.ch.ssa["output_vars"])
                else:
                    first_var_next_ssa = subch_listdiffder[i + 1][0][0].val
                    subch_ssa[i] = subch_listdiffder2subch_ssa(subch_listdiffder[i], first_var_next_ssa, rkvar2rkval, i == 0)
                    subch_ssa[i]["output_vars"] = subch_ssa[i + 1]["input_vars"][:]

                # an output that is not computed in this sub-characteristic
                # must be received as an input
                for var in subch_ssa[i]["output_vars"]:
                    if var not in subch_ssa[i]["inter_vars"] and var not in subch_ssa[i]["input_vars"]:
                        subch_ssa[i]["input_vars"].append(var)

                del subch_ssa[i]["inter_vars"]
                subch_ssa[i]["weight"] = subch_weight[i]

            # an output variable that is also an input, or that is repeated
            # in the outputs, is copied to a fresh variable
            for _, ssa in enumerate(subch_ssa):
                for j in range(len(ssa["output_vars"])):
                    var_j = ssa["output_vars"][j]
                    index_out = 0
                    if var_j in ssa["input_vars"]:
                        new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                        index_out += 1
                        ssa["assignments"].append([new_var, var_j])
                        ssa["output_vars"][j] = new_var
                        var2diffval[new_var] = var2diffval[var_j]

                    for k in range(j + 1, len(ssa["output_vars"])):
                        if var_j == ssa["output_vars"][k]:
                            new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                            index_out += 1
                            ssa["assignments"].append([new_var, var_j])
                            ssa["output_vars"][k] = new_var
                            var2diffval[new_var] = var2diffval[var_j]

            pair_subchssa.append(subch_ssa)

        assert len(pair_subchssa[0]) == len(pair_subchssa[1]) == num_subch
        rkey2pair_subchssa[key_index] = pair_subchssa

    # for each related-key pair, we associated the list of the weight of each subch
    rkey2subch_ew = [[0 for _ in range(num_subch)] for _ in range(KEY_SAMPLES)]

    # start multiprocessing
    with multiprocessing.Pool() as pool:
        for i in range(num_subch):
            for key_index in range(KEY_SAMPLES):
                ssa1 = rkey2pair_subchssa[key_index][0][i]
                ssa2 = rkey2pair_subchssa[key_index][1][i]

                # only the first two keys are reported in verbose mode
                if key_index <= 1:
                    if verbose_lvl >= 2 and key_index == 0:
                        smart_print("- sub-characteristic {}".format(i))
                    if verbose_lvl >= 3 and key_index == 0:
                        smart_print(" - listdiffder:", subch_listdiffder[i])
                    if verbose_lvl >= 3:
                        smart_print(" - related-key pair index", key_index)
                        smart_print(" - ssa1:", ssa1)
                        if ssa1 == ssa2:
                            smart_print(" - ssa2: (same as ssa1)")
                        else:
                            smart_print(" - ssa2:", ssa2)

                # if the previous sub-characteristic was invalid for this key,
                # the remaining ones are not run
                if i > 0 and rkey2subch_ew[key_index][i-1] == math.inf:
                    rkey2subch_ew[key_index][i] = math.inf
                    if key_index <= 1 and verbose_lvl >= 2:
                        smart_print(" - rk{} | skipping since invalid sub-ch[{}]".format(key_index, i-1))
                    continue

                # the same eval function is used for both elements of the pair
                # when their ssa are equal
                if ssa1 == ssa2:
                    ccode = ssa2ccode(ssa1, ch_found.ch.diff_type)
                else:
                    ccode = relatedssa2ccode(ssa1, ssa2, ch_found.ch.diff_type)

                if key_index <= 1 and debug:
                    smart_print(ccode[0])
                    smart_print(ccode[1])
                    smart_print()

                input_diff_c = [v.xreplace(var2diffval) for v in ssa1["input_vars"]]
                output_diff_c = [v.xreplace(var2diffval) for v in ssa1["output_vars"]]

                if key_index <= 1 and verbose_lvl >= 2:
                    smart_print(" - rk{} | checking {} -> {} with weight {}".format(
                        key_index,
                        '|'.join([str(d) for d in input_diff_c]), '|'.join([str(d) for d in output_diff_c]),
                        ssa1["weight"]))

                assert ssa1["weight"] == ssa2["weight"]

                # every difference must have been resolved to a constant
                assert all(isinstance(d, (int, core.Constant)) for d in input_diff_c), "{}".format(input_diff_c)
                assert all(isinstance(d, (int, core.Constant)) for d in output_diff_c), "{}".format(output_diff_c)

                input_diff_c = [int(d) for d in input_diff_c]
                output_diff_c = [int(d) for d in output_diff_c]

                # the AsyncResult is stored and replaced by its value below
                rkey2subch_ew[key_index][i] = pool.apply_async(
                    compile_run_empirical_weight,
                    (
                        ccode,
                        "_libver" + ch_found.ch.func.__name__ + str(i),
                        input_diff_c,
                        output_diff_c,
                        ssa1["weight"],
                        False
                    )
                )

            # wait until all i-subch have been compiled and run
            # and replace the Async object by the result
            for key_index in range(KEY_SAMPLES):
                if isinstance(rkey2subch_ew[key_index][i], multiprocessing.pool.AsyncResult):
                    rkey2subch_ew[key_index][i] = rkey2subch_ew[key_index][i].get()

                    if key_index <= 1 and verbose_lvl >= 2:
                        smart_print(" - rk{} | exact/empirical weight: {}, {}".format(
                            key_index, subch_weight[i], rkey2subch_ew[key_index][i]))
    # end multiprocessing

    # the weight for a key is the sum of the weights of its sub-characteristics
    empirical_weight_distribution = collections.Counter()
    all_rkey_weights = []
    for key_index in range(KEY_SAMPLES):
        rkey_weight = sum(rkey2subch_ew[key_index])

        if precision == 0:
            # int() cannot be applied to an infinite weight
            weight = int(rkey_weight) if rkey_weight != math.inf else math.inf
        else:
            weight = round(rkey_weight, precision)

        all_rkey_weights.append(rkey_weight)
        empirical_weight_distribution[weight] += 1

    if verbose_lvl >= 2:
        smart_print("- distribution empirical weights: {}".format(empirical_weight_distribution))

    if verbose_lvl >= 3:
        smart_print("- list empirical weights:", [round(x, 8) for x in all_rkey_weights if x != math.inf])

    return empirical_weight_distribution