crypto: arm64/aes-blk - add support for CTS-CBC mode

Currently, we rely on the generic CTS chaining mode wrapper to
instantiate the cts(cbc(aes)) skcipher. Due to the high performance
of the ARMv8 Crypto Extensions AES instructions (~1 cycles per byte),
any overhead in the chaining mode layers is amplified, and so it pays
off considerably to fold the CTS handling into the SIMD routines.

On Cortex-A53, this results in a ~50% speedup for smaller input sizes.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index 1c69345..26d2b02 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -15,6 +15,7 @@
 #include <crypto/internal/hash.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <linux/module.h>
 #include <linux/cpufeature.h>
 #include <crypto/xts.h>
@@ -31,6 +32,8 @@
 #define aes_ecb_decrypt		ce_aes_ecb_decrypt
 #define aes_cbc_encrypt		ce_aes_cbc_encrypt
 #define aes_cbc_decrypt		ce_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		ce_aes_ctr_encrypt
 #define aes_xts_encrypt		ce_aes_xts_encrypt
 #define aes_xts_decrypt		ce_aes_xts_decrypt
@@ -45,6 +48,8 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #define aes_ecb_decrypt		neon_aes_ecb_decrypt
 #define aes_cbc_encrypt		neon_aes_cbc_encrypt
 #define aes_cbc_decrypt		neon_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		neon_aes_ctr_encrypt
 #define aes_xts_encrypt		neon_aes_xts_encrypt
 #define aes_xts_decrypt		neon_aes_xts_decrypt
@@ -73,6 +78,11 @@ asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int blocks, u8 iv[]);
 
+asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+				int rounds, int bytes, u8 const iv[]);
+asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+				int rounds, int bytes, u8 const iv[]);
+
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int blocks, u8 ctr[]);
 
@@ -87,6 +97,12 @@ asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
 			       int blocks, u8 dg[], int enc_before,
 			       int enc_after);
 
+struct cts_cbc_req_ctx {
+	struct scatterlist sg_src[2];
+	struct scatterlist sg_dst[2];
+	struct skcipher_request subreq;
+};
+
 struct crypto_aes_xts_ctx {
 	struct crypto_aes_ctx key1;
 	struct crypto_aes_ctx __aligned(8) key2;
@@ -209,6 +225,136 @@ static int cbc_decrypt(struct skcipher_request *req)
 	return err;
 }
 
+static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
+{
+	crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
+	return 0;
+}
+
+static int cts_cbc_encrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+
+	skcipher_request_set_tfm(&rctx->subreq, tfm);
+
+	if (req->cryptlen == AES_BLOCK_SIZE)
+		cbc_blocks = 1;
+
+	if (cbc_blocks > 0) {
+		unsigned int blocks;
+
+		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					ctx->key_enc, rounds, blocks, walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
+					     rctx->subreq.cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
+					       rctx->subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&rctx->subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
+static int cts_cbc_decrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+
+	skcipher_request_set_tfm(&rctx->subreq, tfm);
+
+	if (req->cryptlen == AES_BLOCK_SIZE)
+		cbc_blocks = 1;
+
+	if (cbc_blocks > 0) {
+		unsigned int blocks;
+
+		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					ctx->key_dec, rounds, blocks, walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
+					     rctx->subreq.cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
+					       rctx->subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&rctx->subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
 static int ctr_encrypt(struct skcipher_request *req)
 {
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -336,6 +482,25 @@ static struct skcipher_alg aes_algs[] = { {
 	.decrypt	= cbc_decrypt,
 }, {
 	.base = {
+		.cra_name		= "__cts(cbc(aes))",
+		.cra_driver_name	= "__cts-cbc-aes-" MODE,
+		.cra_priority		= PRIO,
+		.cra_flags		= CRYPTO_ALG_INTERNAL,
+		.cra_blocksize		= 1,
+		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
+		.cra_module		= THIS_MODULE,
+	},
+	.min_keysize	= AES_MIN_KEY_SIZE,
+	.max_keysize	= AES_MAX_KEY_SIZE,
+	.ivsize		= AES_BLOCK_SIZE,
+	.chunksize	= AES_BLOCK_SIZE,
+	.walksize	= 2 * AES_BLOCK_SIZE,
+	.setkey		= skcipher_aes_setkey,
+	.encrypt	= cts_cbc_encrypt,
+	.decrypt	= cts_cbc_decrypt,
+	.init		= cts_cbc_init_tfm,
+}, {
+	.base = {
 		.cra_name		= "__ctr(aes)",
 		.cra_driver_name	= "__ctr-aes-" MODE,
 		.cra_priority		= PRIO,
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index 35632d1..9697eda 100644
--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S
@@ -171,6 +171,84 @@
 
 
 	/*
+	 * aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+	 *		       int rounds, int bytes, u8 const iv[])
+	 * aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+	 *		       int rounds, int bytes, u8 const iv[])
+	 */
+
+AES_ENTRY(aes_cbc_cts_encrypt)
+	adr_l		x8, .Lcts_permute_table
+	sub		x4, x4, #16
+	add		x9, x8, #32
+	add		x8, x8, x4
+	sub		x9, x9, x4
+	ld1		{v3.16b}, [x8]
+	ld1		{v4.16b}, [x9]
+
+	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
+	ld1		{v1.16b}, [x1]
+
+	ld1		{v5.16b}, [x5]			/* get iv */
+	enc_prepare	w3, x2, x6
+
+	eor		v0.16b, v0.16b, v5.16b		/* xor with iv */
+	tbl		v1.16b, {v1.16b}, v4.16b
+	encrypt_block	v0, w3, x2, x6, w7
+
+	eor		v1.16b, v1.16b, v0.16b
+	tbl		v0.16b, {v0.16b}, v3.16b
+	encrypt_block	v1, w3, x2, x6, w7
+
+	add		x4, x0, x4
+	st1		{v0.16b}, [x4]			/* overlapping stores */
+	st1		{v1.16b}, [x0]
+	ret
+AES_ENDPROC(aes_cbc_cts_encrypt)
+
+AES_ENTRY(aes_cbc_cts_decrypt)
+	adr_l		x8, .Lcts_permute_table
+	sub		x4, x4, #16
+	add		x9, x8, #32
+	add		x8, x8, x4
+	sub		x9, x9, x4
+	ld1		{v3.16b}, [x8]
+	ld1		{v4.16b}, [x9]
+
+	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
+	ld1		{v1.16b}, [x1]
+
+	ld1		{v5.16b}, [x5]			/* get iv */
+	dec_prepare	w3, x2, x6
+
+	tbl		v2.16b, {v1.16b}, v4.16b
+	decrypt_block	v0, w3, x2, x6, w7
+	eor		v2.16b, v2.16b, v0.16b
+
+	tbx		v0.16b, {v1.16b}, v4.16b
+	tbl		v2.16b, {v2.16b}, v3.16b
+	decrypt_block	v0, w3, x2, x6, w7
+	eor		v0.16b, v0.16b, v5.16b		/* xor with iv */
+
+	add		x4, x0, x4
+	st1		{v2.16b}, [x4]			/* overlapping stores */
+	st1		{v0.16b}, [x0]
+	ret
+AES_ENDPROC(aes_cbc_cts_decrypt)
+
+	.section	".rodata", "a"
+	.align		6
+.Lcts_permute_table:
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
+	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.previous
+
+
+	/*
 	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
 	 *		   int blocks, u8 ctr[])
 	 */
@@ -253,7 +331,6 @@
 	ins		v4.d[0], x7
 	b		.Lctrcarrydone
 AES_ENDPROC(aes_ctr_encrypt)
-	.ltorg
 
 
 	/*