crypto: arm64/aes-blk - add support for CTS-CBC mode
Currently, we rely on the generic CTS chaining mode wrapper to
instantiate the cts(cbc(aes)) skcipher. Due to the high performance
of the ARMv8 Crypto Extensions AES instructions (~1 cycles per byte),
any overhead in the chaining mode layers is amplified, and so it pays
off considerably to fold the CTS handling into the SIMD routines.
On Cortex-A53, this results in a ~50% speedup for smaller input sizes.
Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index 1c69345..26d2b02 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -15,6 +15,7 @@
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>
@@ -31,6 +32,8 @@
#define aes_ecb_decrypt ce_aes_ecb_decrypt
#define aes_cbc_encrypt ce_aes_cbc_encrypt
#define aes_cbc_decrypt ce_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt ce_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt ce_aes_cbc_cts_decrypt
#define aes_ctr_encrypt ce_aes_ctr_encrypt
#define aes_xts_encrypt ce_aes_xts_encrypt
#define aes_xts_decrypt ce_aes_xts_decrypt
@@ -45,6 +48,8 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#define aes_ecb_decrypt neon_aes_ecb_decrypt
#define aes_cbc_encrypt neon_aes_cbc_encrypt
#define aes_cbc_decrypt neon_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt neon_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt neon_aes_cbc_cts_decrypt
#define aes_ctr_encrypt neon_aes_ctr_encrypt
#define aes_xts_encrypt neon_aes_xts_encrypt
#define aes_xts_decrypt neon_aes_xts_decrypt
@@ -73,6 +78,11 @@ asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
int rounds, int blocks, u8 iv[]);
+asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+ int rounds, int bytes, u8 const iv[]);
+asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+ int rounds, int bytes, u8 const iv[]);
+
asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
int rounds, int blocks, u8 ctr[]);
@@ -87,6 +97,12 @@ asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
int blocks, u8 dg[], int enc_before,
int enc_after);
+struct cts_cbc_req_ctx {
+ struct scatterlist sg_src[2];
+ struct scatterlist sg_dst[2];
+ struct skcipher_request subreq;
+};
+
struct crypto_aes_xts_ctx {
struct crypto_aes_ctx key1;
struct crypto_aes_ctx __aligned(8) key2;
@@ -209,6 +225,136 @@ static int cbc_decrypt(struct skcipher_request *req)
return err;
}
+static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
+{
+ crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
+ return 0;
+}
+
+static int cts_cbc_encrypt(struct skcipher_request *req)
+{
+ struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+ struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+ struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
+ int err, rounds = 6 + ctx->key_length / 4;
+ int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+ struct scatterlist *src = req->src, *dst = req->dst;
+ struct skcipher_walk walk;
+
+ skcipher_request_set_tfm(&rctx->subreq, tfm);
+
+ if (req->cryptlen == AES_BLOCK_SIZE)
+ cbc_blocks = 1;
+
+ if (cbc_blocks > 0) {
+ unsigned int blocks;
+
+ skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
+ cbc_blocks * AES_BLOCK_SIZE,
+ req->iv);
+
+ err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+
+ while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+ kernel_neon_begin();
+ aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+ ctx->key_enc, rounds, blocks, walk.iv);
+ kernel_neon_end();
+ err = skcipher_walk_done(&walk,
+ walk.nbytes % AES_BLOCK_SIZE);
+ }
+ if (err)
+ return err;
+
+ if (req->cryptlen == AES_BLOCK_SIZE)
+ return 0;
+
+ dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
+ rctx->subreq.cryptlen);
+ if (req->dst != req->src)
+ dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
+ rctx->subreq.cryptlen);
+ }
+
+ /* handle ciphertext stealing */
+ skcipher_request_set_crypt(&rctx->subreq, src, dst,
+ req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+ req->iv);
+
+ err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+ if (err)
+ return err;
+
+ kernel_neon_begin();
+ aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+ ctx->key_enc, rounds, walk.nbytes, walk.iv);
+ kernel_neon_end();
+
+ return skcipher_walk_done(&walk, 0);
+}
+
+static int cts_cbc_decrypt(struct skcipher_request *req)
+{
+ struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+ struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+ struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
+ int err, rounds = 6 + ctx->key_length / 4;
+ int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+ struct scatterlist *src = req->src, *dst = req->dst;
+ struct skcipher_walk walk;
+
+ skcipher_request_set_tfm(&rctx->subreq, tfm);
+
+ if (req->cryptlen == AES_BLOCK_SIZE)
+ cbc_blocks = 1;
+
+ if (cbc_blocks > 0) {
+ unsigned int blocks;
+
+ skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
+ cbc_blocks * AES_BLOCK_SIZE,
+ req->iv);
+
+ err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+
+ while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+ kernel_neon_begin();
+ aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+ ctx->key_dec, rounds, blocks, walk.iv);
+ kernel_neon_end();
+ err = skcipher_walk_done(&walk,
+ walk.nbytes % AES_BLOCK_SIZE);
+ }
+ if (err)
+ return err;
+
+ if (req->cryptlen == AES_BLOCK_SIZE)
+ return 0;
+
+ dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
+ rctx->subreq.cryptlen);
+ if (req->dst != req->src)
+ dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
+ rctx->subreq.cryptlen);
+ }
+
+ /* handle ciphertext stealing */
+ skcipher_request_set_crypt(&rctx->subreq, src, dst,
+ req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+ req->iv);
+
+ err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+ if (err)
+ return err;
+
+ kernel_neon_begin();
+ aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+ ctx->key_dec, rounds, walk.nbytes, walk.iv);
+ kernel_neon_end();
+
+ return skcipher_walk_done(&walk, 0);
+}
+
static int ctr_encrypt(struct skcipher_request *req)
{
struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -336,6 +482,25 @@ static struct skcipher_alg aes_algs[] = { {
.decrypt = cbc_decrypt,
}, {
.base = {
+ .cra_name = "__cts(cbc(aes))",
+ .cra_driver_name = "__cts-cbc-aes-" MODE,
+ .cra_priority = PRIO,
+ .cra_flags = CRYPTO_ALG_INTERNAL,
+ .cra_blocksize = 1,
+ .cra_ctxsize = sizeof(struct crypto_aes_ctx),
+ .cra_module = THIS_MODULE,
+ },
+ .min_keysize = AES_MIN_KEY_SIZE,
+ .max_keysize = AES_MAX_KEY_SIZE,
+ .ivsize = AES_BLOCK_SIZE,
+ .chunksize = AES_BLOCK_SIZE,
+ .walksize = 2 * AES_BLOCK_SIZE,
+ .setkey = skcipher_aes_setkey,
+ .encrypt = cts_cbc_encrypt,
+ .decrypt = cts_cbc_decrypt,
+ .init = cts_cbc_init_tfm,
+}, {
+ .base = {
.cra_name = "__ctr(aes)",
.cra_driver_name = "__ctr-aes-" MODE,
.cra_priority = PRIO,
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index 35632d1..9697eda 100644
--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S
@@ -171,6 +171,84 @@
/*
+ * aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+ * int rounds, int bytes, u8 const iv[])
+ * aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+ * int rounds, int bytes, u8 const iv[])
+ */
+
+AES_ENTRY(aes_cbc_cts_encrypt)
+ adr_l x8, .Lcts_permute_table
+ sub x4, x4, #16
+ add x9, x8, #32
+ add x8, x8, x4
+ sub x9, x9, x4
+ ld1 {v3.16b}, [x8]
+ ld1 {v4.16b}, [x9]
+
+ ld1 {v0.16b}, [x1], x4 /* overlapping loads */
+ ld1 {v1.16b}, [x1]
+
+ ld1 {v5.16b}, [x5] /* get iv */
+ enc_prepare w3, x2, x6
+
+ eor v0.16b, v0.16b, v5.16b /* xor with iv */
+ tbl v1.16b, {v1.16b}, v4.16b
+ encrypt_block v0, w3, x2, x6, w7
+
+ eor v1.16b, v1.16b, v0.16b
+ tbl v0.16b, {v0.16b}, v3.16b
+ encrypt_block v1, w3, x2, x6, w7
+
+ add x4, x0, x4
+ st1 {v0.16b}, [x4] /* overlapping stores */
+ st1 {v1.16b}, [x0]
+ ret
+AES_ENDPROC(aes_cbc_cts_encrypt)
+
+AES_ENTRY(aes_cbc_cts_decrypt)
+ adr_l x8, .Lcts_permute_table
+ sub x4, x4, #16
+ add x9, x8, #32
+ add x8, x8, x4
+ sub x9, x9, x4
+ ld1 {v3.16b}, [x8]
+ ld1 {v4.16b}, [x9]
+
+ ld1 {v0.16b}, [x1], x4 /* overlapping loads */
+ ld1 {v1.16b}, [x1]
+
+ ld1 {v5.16b}, [x5] /* get iv */
+ dec_prepare w3, x2, x6
+
+ tbl v2.16b, {v1.16b}, v4.16b
+ decrypt_block v0, w3, x2, x6, w7
+ eor v2.16b, v2.16b, v0.16b
+
+ tbx v0.16b, {v1.16b}, v4.16b
+ tbl v2.16b, {v2.16b}, v3.16b
+ decrypt_block v0, w3, x2, x6, w7
+ eor v0.16b, v0.16b, v5.16b /* xor with iv */
+
+ add x4, x0, x4
+ st1 {v2.16b}, [x4] /* overlapping stores */
+ st1 {v0.16b}, [x0]
+ ret
+AES_ENDPROC(aes_cbc_cts_decrypt)
+
+ .section ".rodata", "a"
+ .align 6
+.Lcts_permute_table:
+ .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+ .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+ .byte 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7
+ .byte 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf
+ .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+ .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+ .previous
+
+
+ /*
* aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
* int blocks, u8 ctr[])
*/
@@ -253,7 +331,6 @@
ins v4.d[0], x7
b .Lctrcarrydone
AES_ENDPROC(aes_ctr_encrypt)
- .ltorg
/*