net/tls: Combined memory allocation for decryption request

Preparing a decryption request requires several memory chunks
(aead_req, sgin, sgout, iv, aad). When the request is submitted to an
accelerator, the buffers read by the accelerator must be DMA-able and
must not come from the stack. The aad and iv buffers could each be
kmalloc'ed separately, but that is inefficient. This patch makes a
single combined allocation for the decryption request and then segments
it into aead_req || sgin || sgout || aad || iv.

Signed-off-by: Vakul Garg <vakul.garg@nxp.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 83d67df..52fbe72 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -48,19 +48,13 @@ static int tls_do_decryption(struct sock *sk,
 			     struct scatterlist *sgout,
 			     char *iv_recv,
 			     size_t data_len,
-			     struct sk_buff *skb,
-			     gfp_t flags)
+			     struct aead_request *aead_req)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	struct aead_request *aead_req;
-
 	int ret;
 
-	aead_req = aead_request_alloc(ctx->aead_recv, flags);
-	if (!aead_req)
-		return -ENOMEM;
-
+	aead_request_set_tfm(aead_req, ctx->aead_recv);
 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
 	aead_request_set_crypt(aead_req, sgin, sgout,
 			       data_len + tls_ctx->rx.tag_size,
@@ -69,8 +63,6 @@ static int tls_do_decryption(struct sock *sk,
 				  crypto_req_done, &ctx->async_wait);
 
 	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
-
-	aead_request_free(aead_req);
 	return ret;
 }
 
@@ -657,8 +649,132 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
 	return skb;
 }
 
+/* Decrypt the record in skb. With *zc set and out_iov or out_sg given,
+ * the plaintext is written there (zero-copy); otherwise it is decrypted
+ * in place in the skb's own buffers and *zc is cleared. Once decryption
+ * is submitted, *chunk holds the bytes mapped from out_iov, or 0 in place
+ * (it is never written for out_sg). If mapping out_iov fails, any pinned
+ * pages are released and out_iov is rewound before decrypting in place.
+ */
+
+static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
+			    struct iov_iter *out_iov,
+			    struct scatterlist *out_sg,
+			    int *chunk, bool *zc)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct strp_msg *rxm = strp_msg(skb);
+	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
+	struct aead_request *aead_req;
+	struct scatterlist *sgin, *sgout;
+	struct sk_buff *unused;
+	u8 *aad, *iv, *mem;
+	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
+
+	if (*zc && (out_iov || out_sg)) {
+		/* Cap the page count to bound the allocation; if a record
+		 * needs more slots, zero-copy fails and we decrypt in place.
+		 */
+		if (out_iov)
+			n_sgout = iov_iter_npages(out_iov, MAX_SKB_FRAGS) + 1;
+		else
+			n_sgout = sg_nents(out_sg);
+	} else {
+		n_sgout = 0;
+		*zc = false;
+	}
+
+	n_sgin = skb_cow_data(skb, 0, &unused);
+	if (n_sgin < 1)
+		return -EBADMSG;
+
+	/* Increment to accommodate AAD */
+	n_sgin = n_sgin + 1;
+
+	nsg = n_sgin + n_sgout;
+
+	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
+	mem_size = aead_size + (nsg * sizeof(struct scatterlist)) +
+		   TLS_AAD_SPACE_SIZE + crypto_aead_ivsize(ctx->aead_recv);
+
+	/* Allocate a single block of memory which contains
+	 * aead_req || sgin[] || sgout[] || aad || iv.
+	 * This order achieves correct alignment for aead_req, sgin, sgout.
+	 */
+	mem = kmalloc(mem_size, sk->sk_allocation);
+	if (!mem)
+		return -ENOMEM;
+
+	/* Segment the allocated memory */
+	aead_req = (struct aead_request *)mem;
+	sgin = (struct scatterlist *)(mem + aead_size);
+	sgout = sgin + n_sgin;
+	aad = (u8 *)(sgout + n_sgout);
+	iv = aad + TLS_AAD_SPACE_SIZE;
+
+	/* Prepare IV */
+	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
+			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+			    tls_ctx->rx.iv_size);
+	if (err < 0)
+		goto out_free;
+	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+
+	/* Prepare AAD */
+	tls_make_aad(aad, data_len, tls_ctx->rx.rec_seq,
+		     tls_ctx->rx.rec_seq_size, ctx->control);
+
+	/* Prepare sgin */
+	sg_init_table(sgin, n_sgin);
+	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
+	err = skb_to_sgvec(skb, &sgin[1],
+			   rxm->offset + tls_ctx->rx.prepend_size,
+			   rxm->full_len - tls_ctx->rx.prepend_size);
+	if (err < 0)
+		goto out_free;
+
+	if (n_sgout) {
+		if (out_iov) {
+			sg_init_table(sgout, n_sgout);
+			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
+
+			*chunk = 0;
+			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
+						 chunk, &sgout[1],
+						 (n_sgout - 1), false);
+			if (err < 0) {
+				/* Unpin pages, rewind iov, then fall back */
+				for (; pages > 0; pages--)
+					put_page(sg_page(&sgout[pages]));
+				iov_iter_revert(out_iov, *chunk);
+				goto fallback_to_reg_recv;
+			}
+		} else {
+			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
+		}
+	} else {
+fallback_to_reg_recv:
+		sgout = sgin;
+		pages = 0;
+		*chunk = 0;
+		*zc = false;
+	}
+
+	/* Prepare and submit AEAD request */
+	err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
+
+	/* Release the pages in case iov was mapped to pages */
+	for (; pages > 0; pages--)
+		put_page(sg_page(&sgout[pages]));
+
+out_free:
+	kfree(mem);
+	return err;
+}
+
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct scatterlist *sgout, bool *zc)
+			      struct iov_iter *dest, int *chunk, bool *zc)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -671,7 +787,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 		return err;
 #endif
 	if (!ctx->decrypted) {
-		err = decrypt_skb(sk, skb, sgout);
+		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
 		if (err < 0)
 			return err;
 	} else {
@@ -690,54 +806,10 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
 		struct scatterlist *sgout)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
-	struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
-	struct scatterlist *sgin = &sgin_arr[0];
-	struct strp_msg *rxm = strp_msg(skb);
-	int ret, nsg = ARRAY_SIZE(sgin_arr);
-	struct sk_buff *unused;
+	bool zc = true;	/* use sgout as the destination when given */
+	int chunk;	/* out-param of decrypt_internal(), unused here */
 
-	ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
-			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			    tls_ctx->rx.iv_size);
-	if (ret < 0)
-		return ret;
-
-	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-	if (!sgout) {
-		nsg = skb_cow_data(skb, 0, &unused) + 1;
-		sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
-		sgout = sgin;
-	}
-
-	sg_init_table(sgin, nsg);
-	sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE);
-
-	nsg = skb_to_sgvec(skb, &sgin[1],
-			   rxm->offset + tls_ctx->rx.prepend_size,
-			   rxm->full_len - tls_ctx->rx.prepend_size);
-	if (nsg < 0) {
-		ret = nsg;
-		goto out;
-	}
-
-	tls_make_aad(ctx->rx_aad_ciphertext,
-		     rxm->full_len - tls_ctx->rx.overhead_size,
-		     tls_ctx->rx.rec_seq,
-		     tls_ctx->rx.rec_seq_size,
-		     ctx->control);
-
-	ret = tls_do_decryption(sk, sgin, sgout, iv,
-				rxm->full_len - tls_ctx->rx.overhead_size,
-				skb, sk->sk_allocation);
-
-out:
-	if (sgin != &sgin_arr[0])
-		kfree(sgin);
-
-	return ret;
+	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -816,43 +888,17 @@ int tls_sw_recvmsg(struct sock *sk,
 		}
 
 		if (!ctx->decrypted) {
-			int page_count;
-			int to_copy;
+			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
 
-			page_count = iov_iter_npages(&msg->msg_iter,
-						     MAX_SKB_FRAGS);
-			to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
-			if (!is_kvec && to_copy <= len && page_count < MAX_SKB_FRAGS &&
-			    likely(!(flags & MSG_PEEK)))  {
-				struct scatterlist sgin[MAX_SKB_FRAGS + 1];
-				int pages = 0;
-
+			if (!is_kvec && to_copy <= len &&
+			    likely(!(flags & MSG_PEEK)))
 				zc = true;
-				sg_init_table(sgin, MAX_SKB_FRAGS + 1);
-				sg_set_buf(&sgin[0], ctx->rx_aad_plaintext,
-					   TLS_AAD_SPACE_SIZE);
 
-				err = zerocopy_from_iter(sk, &msg->msg_iter,
-							 to_copy, &pages,
-							 &chunk, &sgin[1],
-							 MAX_SKB_FRAGS,	false);
-				if (err < 0)
-					goto fallback_to_reg_recv;
-
-				err = decrypt_skb_update(sk, skb, sgin, &zc);
-				for (; pages > 0; pages--)
-					put_page(sg_page(&sgin[pages]));
-				if (err < 0) {
-					tls_err_abort(sk, EBADMSG);
-					goto recv_end;
-				}
-			} else {
-fallback_to_reg_recv:
-				err = decrypt_skb_update(sk, skb, NULL, &zc);
-				if (err < 0) {
-					tls_err_abort(sk, EBADMSG);
-					goto recv_end;
-				}
+			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+						 &chunk, &zc);
+			if (err < 0) {
+				tls_err_abort(sk, EBADMSG);
+				goto recv_end;
 			}
 			ctx->decrypted = true;
 		}
@@ -903,7 +949,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 	int err = 0;
 	long timeo;
 	int chunk;
-	bool zc;
+	bool zc = false;
 
 	lock_sock(sk);
 
@@ -920,7 +966,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 	}
 
 	if (!ctx->decrypted) {
-		err = decrypt_skb_update(sk, skb, NULL, &zc);
+		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
 
 		if (err < 0) {
 			tls_err_abort(sk, EBADMSG);