feat(util): add AES XTS and its unit test

This commit is contained in:
dongheng
2019-09-19 17:29:57 +08:00
parent 653d20dddc
commit ffcac9caca
3 changed files with 321 additions and 3 deletions

View File

@ -528,7 +528,7 @@ static void __esp_aes_encrypt(esp_aes_t *aes, const void *p_src, void *p_dst)
PUT_UINT32_LE(X3, output, 12);
}
static int __esp_aes_decrypt(esp_aes_t *aes, const void *p_src, void *p_dst)
static void __esp_aes_decrypt(esp_aes_t *aes, const void *p_src, void *p_dst)
{
int i;
uint32_t *RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3;
@ -573,8 +573,6 @@ static int __esp_aes_decrypt(esp_aes_t *aes, const void *p_src, void *p_dst)
PUT_UINT32_LE(X1, output, 4);
PUT_UINT32_LE(X2, output, 8);
PUT_UINT32_LE(X3, output, 12);
return 0;
}
int esp_aes_encrypt(esp_aes_t *aes,
@ -864,3 +862,147 @@ int esp_aes_encrypt_ctr(esp_aes_t *aes,
return 0;
}
static void aes_xts_decode_keys(const uint8_t *key,
size_t keybits,
const uint8_t **key1,
size_t *key1bits,
const uint8_t **key2,
size_t *key2bits)
{
const size_t half_keybits = keybits / 2;
const size_t half_keybytes = half_keybits / 8;
*key1bits = half_keybits;
*key2bits = half_keybits;
*key1 = &key[0];
*key2 = &key[half_keybytes];
}
static void aes_gf128mul_x_ble(uint8_t *r, const uint8_t *i)
{
uint64_t x[2], y[2];
memcpy(x, i, 16);
y[0] = (x[0] << 1) ^ 0x0087 >> (8 - ((x[1] >> 63 ) << 3));
y[1] = (x[0] >> 63) | (x[1] << 1);
memcpy(r, y, 16);
}
int esp_aes_xts_set_encrypt_key(esp_aes_xts_t *ctx, const void *p_key, size_t keybits)
{
int ret;
const uint8_t *key1, *key2;
size_t key1bits, key2bits;
const uint8_t *key = (const uint8_t *)p_key;
util_assert(ctx);
util_assert(key);
if (keybits != 256 && keybits != 512)
return -EINVAL;
aes_xts_decode_keys(key, keybits, &key1, &key1bits, &key2, &key2bits);
ret = esp_aes_set_encrypt_key(&ctx->tweak, key2, key2bits);
if (ret)
return ret;
return esp_aes_set_encrypt_key(&ctx->crypt, key1, key1bits);
}
int esp_aes_xts_set_decrypt_key(esp_aes_xts_t *ctx, const void *p_key, size_t keybits)
{
int ret;
const uint8_t *key1, *key2;
size_t key1bits, key2bits;
const uint8_t *key = (const uint8_t *)p_key;
util_assert(ctx);
util_assert(key);
if (keybits != 256 && keybits != 512)
return -EINVAL;
aes_xts_decode_keys(key, keybits, &key1, &key1bits, &key2, &key2bits);
ret = esp_aes_set_encrypt_key(&ctx->tweak, key2, key2bits);
if (ret)
return ret;
return esp_aes_set_decrypt_key(&ctx->crypt, key1, key1bits);
}
int esp_aes_crypt_xts(esp_aes_xts_t *ctx,
int encrypt,
size_t length,
const void *p_data_unit,
const void *p_src,
void *p_dst)
{
size_t blocks = length / 16;
size_t leftover = length % 16;
uint8_t tweak[16];
uint8_t prev_tweak[16];
uint8_t tmp[16];
void (*crypt_func)(esp_aes_t *aes, const void *p_src, void *p_dst);
const uint8_t *data_unit = (const uint8_t *)p_data_unit;
const uint8_t *input = (const uint8_t *)p_src;
uint8_t *output = (uint8_t *)p_dst;
util_assert(ctx);
util_assert(data_unit);
util_assert(input);
util_assert(output);
if (length < 16 || (length > (1 << 20) * 16))
return -EINVAL;
crypt_func = encrypt ? __esp_aes_encrypt : __esp_aes_decrypt;
__esp_aes_encrypt(&ctx->tweak, data_unit, tweak);
while (blocks--) {
if (blocks == 0 && leftover && !encrypt) {
memcpy(prev_tweak, tweak, sizeof(tweak));
aes_gf128mul_x_ble(tweak, tweak);
}
for (int i = 0; i < 16; i++)
tmp[i] = input[i] ^ tweak[i];
crypt_func(&ctx->crypt, tmp, tmp);
for (int i = 0; i < 16; i++)
output[i] = tmp[i] ^ tweak[i];
aes_gf128mul_x_ble(tweak, tweak);
output += 16;
input += 16;
}
if (leftover) {
int i;
uint8_t *t = encrypt ? tweak : prev_tweak;
uint8_t *prev_output = output - 16;
for (i = 0; i < leftover; i++) {
output[i] = prev_output[i];
tmp[i] = input[i] ^ t[i];
}
for (; i < 16; i++ )
tmp[i] = prev_output[i] ^ t[i];
crypt_func(&ctx->crypt, tmp, tmp);
for (i = 0; i < 16; i++)
prev_output[i] = tmp[i] ^ t[i];
}
return 0;
}