"""Xtea cipher."""
from arxpy.bitvector.core import Constant
from arxpy.primitives.primitives import KeySchedule, Encryption, Cipher
class XteaKeySchedule(KeySchedule):
    """Key schedule of Xtea.

    Takes four 32-bit master-key words and produces one 32-bit round key
    per round (64 rounds by default).
    """

    rounds = 64
    input_widths = [32] * 4
    output_widths = [32] * 64

    @classmethod
    def set_rounds(cls, new_rounds):
        """Change the number of rounds and resize ``output_widths`` to match."""
        cls.rounds = new_rounds
        cls.output_widths = [32 for _round in range(new_rounds)]

    @classmethod
    def eval(cls, *master_key):
        """Return the list of ``cls.rounds`` round keys for the given master key.

        Each round key is the running sum plus one master-key word selected
        by two bits of that sum: the lowest two bits on even rounds, bits
        11-12 on odd rounds. The running sum starts at zero and grows by
        ``delta`` after every even round.

        If the class has a ``skip_rounds`` attribute, every round index it
        contains gets the first master-key word as a placeholder round key.
        """
        key_words = list(master_key)
        delta = Constant(0x9E3779B9, 32)
        two_bit_mask = Constant(3, 32)
        running_sum = Constant(0, 32)

        round_keys = []
        for round_index in range(cls.rounds):
            is_even_round = round_index % 2 == 0
            if hasattr(cls, "skip_rounds") and round_index in cls.skip_rounds:
                # Constant outputs are not supported, so a skipped round
                # still emits a key word (the first one) as a placeholder.
                round_keys.append(key_words[0])
            elif is_even_round:
                selector = int(running_sum & two_bit_mask)
                round_keys.append(running_sum + key_words[selector])
            else:
                selector = int((running_sum >> Constant(11, 32)) & two_bit_mask)
                round_keys.append(running_sum + key_words[selector])

            # The sum advances between the two half-rounds of a cycle,
            # whether or not the round itself was skipped.
            if is_even_round:
                running_sum += delta
        return round_keys
class XteaEncryption(Encryption):
    """Encryption function of Xtea.

    Operates on two 32-bit words and reads its round keys from the
    ``round_keys`` class attribute (``None`` until assigned).
    """

    rounds = 64
    input_widths = [32, 32]
    output_widths = [32, 32]
    round_keys = None

    @classmethod
    def set_rounds(cls, new_rounds):
        """Change the number of rounds."""
        cls.rounds = new_rounds

    @classmethod
    def eval(cls, x, y):
        """Encrypt the word pair ``(x, y)`` and return the resulting pair.

        As a side effect, ``cls.round_inputs`` is reset and filled with the
        ``[v0, v1]`` state seen at the start of every round, followed by the
        final state. Round indices listed in an optional ``skip_rounds``
        class attribute leave the state untouched (their input is still
        recorded).
        """
        left, right = x, y
        keys = cls.round_keys
        trace = []
        cls.round_inputs = trace

        shift_up = Constant(4, 32)
        shift_down = Constant(5, 32)

        for round_index in range(cls.rounds):
            trace.append([left, right])
            skipped = hasattr(cls, "skip_rounds") and round_index in cls.skip_rounds
            if not skipped:
                mixed = ((right << shift_up) ^ (right >> shift_down)) + right
                left, right = right, left + (mixed ^ keys[round_index])

        # Record the output state as the last entry.
        trace.append([left, right])
        return left, right
class XteaCipher(Cipher):
    """Xtea cipher built from ``XteaKeySchedule`` and ``XteaEncryption``."""

    key_schedule = XteaKeySchedule
    encryption = XteaEncryption
    rounds = 64

    @classmethod
    def set_rounds(cls, new_rounds):
        """Set the number of rounds on the cipher and both of its components.

        No lower bound on ``new_rounds`` is enforced here.
        """
        cls.rounds = new_rounds
        for component in (cls.encryption, cls.key_schedule):
            component.set_rounds(new_rounds)

    @classmethod
    def set_skip_rounds(cls, skip_rounds):
        """Store ``skip_rounds`` (a list or tuple) on both components."""
        assert isinstance(skip_rounds, (list, tuple))
        for component in (cls.encryption, cls.key_schedule):
            component.skip_rounds = skip_rounds

    @classmethod
    def test(cls):
        """Test Xtea with official test vectors."""
        # https://go.googlesource.com/crypto/+/master/xtea/xtea_test.go
        plaintext = (0x41424344, 0x45464748)
        # (key, expected ciphertext) pairs, all for the same plaintext.
        vectors = (
            (
                (0, 0, 0, 0),
                (0xa0390589, 0xf8b8efa5),
            ),
            (
                (0x00010203, 0x04050607, 0x08090A0B, 0x0C0D0E0F),
                (0x497df3d0, 0x72612cb5),
            ),
        )
        for key, expected in vectors:
            assert cls(plaintext, key) == expected