crypto: x86/chacha20 - Support partial lengths in 8-block AVX2 variant

Add a length argument to the eight-block AVX2 function, so that it
can XOR a partial length of up to eight blocks.

To avoid unnecessary operations, we integrate XORing of the first four
blocks into the final lane interleaving; this also avoids some work in
the partial-length path.

Signed-off-by: Martin Willi <martin@strongswan.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
diff --git a/arch/x86/crypto/chacha20-avx2-x86_64.S b/arch/x86/crypto/chacha20-avx2-x86_64.S
index f3cd26f..7b62d55 100644
--- a/arch/x86/crypto/chacha20-avx2-x86_64.S
+++ b/arch/x86/crypto/chacha20-avx2-x86_64.S
@@ -30,8 +30,9 @@
 
 ENTRY(chacha20_8block_xor_avx2)
 	# %rdi: Input state matrix, s
-	# %rsi: 8 data blocks output, o
-	# %rdx: 8 data blocks input, i
+	# %rsi: up to 8 data blocks output, o
+	# %rdx: up to 8 data blocks input, i
+	# %rcx: input/output length in bytes
 
 	# This function encrypts eight consecutive ChaCha20 blocks by loading
 	# the state matrix in AVX registers eight times. As we need some
@@ -48,6 +49,7 @@
 	lea		8(%rsp),%r10
 	and		$~31, %rsp
 	sub		$0x80, %rsp
+	mov		%rcx,%rax
 
 	# x0..15[0-7] = s[0..15]
 	vpbroadcastd	0x00(%rdi),%ymm0
@@ -375,74 +377,154 @@
 	vpunpckhqdq	%ymm15,%ymm0,%ymm15
 
 	# interleave 128-bit words in state n, n+4
-	vmovdqa		0x00(%rsp),%ymm0
-	vperm2i128	$0x20,%ymm4,%ymm0,%ymm1
-	vperm2i128	$0x31,%ymm4,%ymm0,%ymm4
-	vmovdqa		%ymm1,0x00(%rsp)
-	vmovdqa		0x20(%rsp),%ymm0
-	vperm2i128	$0x20,%ymm5,%ymm0,%ymm1
-	vperm2i128	$0x31,%ymm5,%ymm0,%ymm5
-	vmovdqa		%ymm1,0x20(%rsp)
-	vmovdqa		0x40(%rsp),%ymm0
-	vperm2i128	$0x20,%ymm6,%ymm0,%ymm1
-	vperm2i128	$0x31,%ymm6,%ymm0,%ymm6
-	vmovdqa		%ymm1,0x40(%rsp)
-	vmovdqa		0x60(%rsp),%ymm0
-	vperm2i128	$0x20,%ymm7,%ymm0,%ymm1
-	vperm2i128	$0x31,%ymm7,%ymm0,%ymm7
-	vmovdqa		%ymm1,0x60(%rsp)
-	vperm2i128	$0x20,%ymm12,%ymm8,%ymm0
-	vperm2i128	$0x31,%ymm12,%ymm8,%ymm12
-	vmovdqa		%ymm0,%ymm8
-	vperm2i128	$0x20,%ymm13,%ymm9,%ymm0
-	vperm2i128	$0x31,%ymm13,%ymm9,%ymm13
-	vmovdqa		%ymm0,%ymm9
-	vperm2i128	$0x20,%ymm14,%ymm10,%ymm0
-	vperm2i128	$0x31,%ymm14,%ymm10,%ymm14
-	vmovdqa		%ymm0,%ymm10
-	vperm2i128	$0x20,%ymm15,%ymm11,%ymm0
-	vperm2i128	$0x31,%ymm15,%ymm11,%ymm15
-	vmovdqa		%ymm0,%ymm11
-
-	# xor with corresponding input, write to output
-	vmovdqa		0x00(%rsp),%ymm0
+	# xor/write first four blocks
+	# %rax holds the length in bytes. Each 32-byte keystream chunk is
+	# built in %ymm0 before its length check, so that .Lxorpart8 can
+	# consume a trailing partial chunk from %ymm0. The length is unsigned,
+	# hence the unsigned jb. The upper 128-bit lanes ($0x31) are kept in
+	# %ymm4-%ymm7 and %ymm12-%ymm15 for blocks 4-7.
+	vmovdqa		0x00(%rsp),%ymm1
+	vperm2i128	$0x20,%ymm4,%ymm1,%ymm0
+	cmp		$0x0020,%rax
+	jb		.Lxorpart8
 	vpxor		0x0000(%rdx),%ymm0,%ymm0
 	vmovdqu		%ymm0,0x0000(%rsi)
-	vmovdqa		0x20(%rsp),%ymm0
-	vpxor		0x0080(%rdx),%ymm0,%ymm0
-	vmovdqu		%ymm0,0x0080(%rsi)
-	vmovdqa		0x40(%rsp),%ymm0
+	vperm2i128	$0x31,%ymm4,%ymm1,%ymm4
+
+	vperm2i128	$0x20,%ymm12,%ymm8,%ymm0
+	cmp		$0x0040,%rax
+	jb		.Lxorpart8
+	vpxor		0x0020(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x0020(%rsi)
+	vperm2i128	$0x31,%ymm12,%ymm8,%ymm12
+
+	vmovdqa		0x40(%rsp),%ymm1
+	vperm2i128	$0x20,%ymm6,%ymm1,%ymm0
+	cmp		$0x0060,%rax
+	jb		.Lxorpart8
 	vpxor		0x0040(%rdx),%ymm0,%ymm0
 	vmovdqu		%ymm0,0x0040(%rsi)
-	vmovdqa		0x60(%rsp),%ymm0
+	vperm2i128	$0x31,%ymm6,%ymm1,%ymm6
+
+	vperm2i128	$0x20,%ymm14,%ymm10,%ymm0
+	cmp		$0x0080,%rax
+	jb		.Lxorpart8
+	vpxor		0x0060(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x0060(%rsi)
+	vperm2i128	$0x31,%ymm14,%ymm10,%ymm14
+
+	vmovdqa		0x20(%rsp),%ymm1
+	vperm2i128	$0x20,%ymm5,%ymm1,%ymm0
+	cmp		$0x00a0,%rax
+	jb		.Lxorpart8
+	vpxor		0x0080(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x0080(%rsi)
+	vperm2i128	$0x31,%ymm5,%ymm1,%ymm5
+
+	vperm2i128	$0x20,%ymm13,%ymm9,%ymm0
+	cmp		$0x00c0,%rax
+	jb		.Lxorpart8
+	vpxor		0x00a0(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x00a0(%rsi)
+	vperm2i128	$0x31,%ymm13,%ymm9,%ymm13
+
+	vmovdqa		0x60(%rsp),%ymm1
+	vperm2i128	$0x20,%ymm7,%ymm1,%ymm0
+	cmp		$0x00e0,%rax
+	jb		.Lxorpart8
 	vpxor		0x00c0(%rdx),%ymm0,%ymm0
 	vmovdqu		%ymm0,0x00c0(%rsi)
-	vpxor		0x0100(%rdx),%ymm4,%ymm4
-	vmovdqu		%ymm4,0x0100(%rsi)
-	vpxor		0x0180(%rdx),%ymm5,%ymm5
-	vmovdqu		%ymm5,0x00180(%rsi)
-	vpxor		0x0140(%rdx),%ymm6,%ymm6
-	vmovdqu		%ymm6,0x0140(%rsi)
-	vpxor		0x01c0(%rdx),%ymm7,%ymm7
-	vmovdqu		%ymm7,0x01c0(%rsi)
-	vpxor		0x0020(%rdx),%ymm8,%ymm8
-	vmovdqu		%ymm8,0x0020(%rsi)
-	vpxor		0x00a0(%rdx),%ymm9,%ymm9
-	vmovdqu		%ymm9,0x00a0(%rsi)
-	vpxor		0x0060(%rdx),%ymm10,%ymm10
-	vmovdqu		%ymm10,0x0060(%rsi)
-	vpxor		0x00e0(%rdx),%ymm11,%ymm11
-	vmovdqu		%ymm11,0x00e0(%rsi)
-	vpxor		0x0120(%rdx),%ymm12,%ymm12
-	vmovdqu		%ymm12,0x0120(%rsi)
-	vpxor		0x01a0(%rdx),%ymm13,%ymm13
-	vmovdqu		%ymm13,0x01a0(%rsi)
-	vpxor		0x0160(%rdx),%ymm14,%ymm14
-	vmovdqu		%ymm14,0x0160(%rsi)
-	vpxor		0x01e0(%rdx),%ymm15,%ymm15
-	vmovdqu		%ymm15,0x01e0(%rsi)
+	vperm2i128	$0x31,%ymm7,%ymm1,%ymm7
 
+	vperm2i128	$0x20,%ymm15,%ymm11,%ymm0
+	cmp		$0x0100,%rax
+	jb		.Lxorpart8
+	vpxor		0x00e0(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x00e0(%rsi)
+	vperm2i128	$0x31,%ymm15,%ymm11,%ymm15
+
+	# xor remaining blocks, write to output
+	# (keystream for blocks 4-7 is in the $0x31 lanes prepared above)
+	vmovdqa		%ymm4,%ymm0
+	cmp		$0x0120,%rax
+	jb		.Lxorpart8
+	vpxor		0x0100(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x0100(%rsi)
+
+	vmovdqa		%ymm12,%ymm0
+	cmp		$0x0140,%rax
+	jb		.Lxorpart8
+	vpxor		0x0120(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x0120(%rsi)
+
+	vmovdqa		%ymm6,%ymm0
+	cmp		$0x0160,%rax
+	jb		.Lxorpart8
+	vpxor		0x0140(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x0140(%rsi)
+
+	vmovdqa		%ymm14,%ymm0
+	cmp		$0x0180,%rax
+	jb		.Lxorpart8
+	vpxor		0x0160(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x0160(%rsi)
+
+	vmovdqa		%ymm5,%ymm0
+	cmp		$0x01a0,%rax
+	jb		.Lxorpart8
+	vpxor		0x0180(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x0180(%rsi)
+
+	vmovdqa		%ymm13,%ymm0
+	cmp		$0x01c0,%rax
+	jb		.Lxorpart8
+	vpxor		0x01a0(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x01a0(%rsi)
+
+	vmovdqa		%ymm7,%ymm0
+	cmp		$0x01e0,%rax
+	jb		.Lxorpart8
+	vpxor		0x01c0(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x01c0(%rsi)
+
+	vmovdqa		%ymm15,%ymm0
+	cmp		$0x0200,%rax
+	jb		.Lxorpart8
+	vpxor		0x01e0(%rdx),%ymm0,%ymm0
+	vmovdqu		%ymm0,0x01e0(%rsi)
+
+.Ldone8:
 	vzeroupper
 	lea		-8(%r10),%rsp
 	ret
+
+.Lxorpart8:
+	# xor remaining bytes from partial register into output
+	# On entry %ymm0 holds the keystream for the 32-byte chunk at offset
+	# (%rax & ~0x1f) and %rax & 0x1f bytes of it are still to be done.
+	# The tail is bounced through the 32-byte aligned scratch slot at
+	# 0x00(%rsp), whose contents were already consumed above. rep movsb
+	# clobbers %rsi/%rdi/%rcx, so the output pointer is kept in %r11.
+	mov		%rax,%r9
+	and		$0x1f,%r9
+	jz		.Ldone8
+	and		$~0x1f,%rax
+
+	mov		%rsi,%r11
+
+	lea		(%rdx,%rax),%rsi
+	mov		%rsp,%rdi
+	mov		%r9,%rcx
+	rep movsb
+
+	vpxor		0x00(%rsp),%ymm0,%ymm0
+	vmovdqa		%ymm0,0x00(%rsp)
+
+	mov		%rsp,%rsi
+	lea		(%r11,%rax),%rdi
+	mov		%r9,%rcx
+	rep movsb
+
+	jmp		.Ldone8
+
 ENDPROC(chacha20_8block_xor_avx2)