tls: Fix recvmsg() to be able to peek across multiple records

This fixes recvmsg() to be able to peek across multiple tls records.
Without this patch, the tls's selftests test case
'recv_peek_large_buf_mult_recs' fails. Each tls receive context now
maintains a 'rx_list' to retain incoming skb carrying tls records. If a
tls record needs to be retained e.g. for peek case or for the case when
the buffer passed to recvmsg() has a length smaller than decrypted
record length, then it is added to 'rx_list'. Additionally, records are
added in 'rx_list' if the crypto operation runs in async mode. The
records are dequeued from 'rx_list' after the decrypted data is consumed
by copying into the buffer passed to recvmsg(). In case, the MSG_PEEK
flag is used in recvmsg(), then records are not consumed or removed
from the 'rx_list'.

Signed-off-by: Vakul Garg <vakul.garg@nxp.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/tls.h b/include/net/tls.h
index 2a6ac8d..90bf52d 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -145,12 +145,13 @@ struct tls_sw_context_tx {
 struct tls_sw_context_rx {
 	struct crypto_aead *aead_recv;
 	struct crypto_wait async_wait;
-
 	struct strparser strp;
+	struct sk_buff_head rx_list;	/* list of decrypted 'data' records */
 	void (*saved_data_ready)(struct sock *sk);
 
 	struct sk_buff *recv_pkt;
 	u8 control;
+	int async_capable;
 	bool decrypted;
 	atomic_t decrypt_pending;
 	bool async_notify;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index b8e50e2..86b9527 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -124,6 +124,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 {
 	struct aead_request *aead_req = (struct aead_request *)req;
 	struct scatterlist *sgout = aead_req->dst;
+	struct scatterlist *sgin = aead_req->src;
 	struct tls_sw_context_rx *ctx;
 	struct tls_context *tls_ctx;
 	struct scatterlist *sg;
@@ -134,12 +135,16 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	skb = (struct sk_buff *)req->data;
 	tls_ctx = tls_get_ctx(skb->sk);
 	ctx = tls_sw_ctx_rx(tls_ctx);
-	pending = atomic_dec_return(&ctx->decrypt_pending);
 
 	/* Propagate if there was an err */
 	if (err) {
 		ctx->async_wait.err = err;
 		tls_err_abort(skb->sk, err);
+	} else {
+		struct strp_msg *rxm = strp_msg(skb);
+
+		rxm->offset += tls_ctx->rx.prepend_size;
+		rxm->full_len -= tls_ctx->rx.overhead_size;
 	}
 
 	/* After using skb->sk to propagate sk through crypto async callback
@@ -147,18 +152,21 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	 */
 	skb->sk = NULL;
 
-	/* Release the skb, pages and memory allocated for crypto req */
-	kfree_skb(skb);
 
-	/* Skip the first S/G entry as it points to AAD */
-	for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
-		if (!sg)
-			break;
-		put_page(sg_page(sg));
+	/* Free the destination pages if skb was not decrypted inplace */
+	if (sgout != sgin) {
+		/* Skip the first S/G entry as it points to AAD */
+		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+			if (!sg)
+				break;
+			put_page(sg_page(sg));
+		}
 	}
 
 	kfree(aead_req);
 
+	pending = atomic_dec_return(&ctx->decrypt_pending);
+
 	if (!pending && READ_ONCE(ctx->async_notify))
 		complete(&ctx->async_wait.completion);
 }
@@ -1271,7 +1279,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 			    struct iov_iter *out_iov,
 			    struct scatterlist *out_sg,
-			    int *chunk, bool *zc)
+			    int *chunk, bool *zc, bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1371,13 +1379,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 fallback_to_reg_recv:
 		sgout = sgin;
 		pages = 0;
-		*chunk = 0;
+		*chunk = data_len;
 		*zc = false;
 	}
 
 	/* Prepare and submit AEAD request */
 	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
-				data_len, aead_req, *zc);
+				data_len, aead_req, async);
 	if (err == -EINPROGRESS)
 		return err;
 
@@ -1390,7 +1398,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct iov_iter *dest, int *chunk, bool *zc)
+			      struct iov_iter *dest, int *chunk, bool *zc,
+			      bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1403,7 +1412,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 		return err;
 #endif
 	if (!ctx->decrypted) {
-		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
+		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
 		if (err < 0) {
 			if (err == -EINPROGRESS)
 				tls_advance_record_sn(sk, &tls_ctx->rx);
@@ -1429,7 +1438,7 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb,
 	bool zc = true;
 	int chunk;
 
-	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
+	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -1456,6 +1465,77 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
 	return true;
 }
 
+/* This function traverses the rx_list in tls receive context to copies the
+ * decrypted data records into the buffer provided by caller zero copy is not
+ * true. Further, the records are removed from the rx_list if it is not a peek
+ * case and the record has been consumed completely.
+ */
+static int process_rx_list(struct tls_sw_context_rx *ctx,
+			   struct msghdr *msg,
+			   size_t skip,
+			   size_t len,
+			   bool zc,
+			   bool is_peek)
+{
+	struct sk_buff *skb = skb_peek(&ctx->rx_list);
+	ssize_t copied = 0;
+
+	while (skip && skb) {
+		struct strp_msg *rxm = strp_msg(skb);
+
+		if (skip < rxm->full_len)
+			break;
+
+		skip = skip - rxm->full_len;
+		skb = skb_peek_next(skb, &ctx->rx_list);
+	}
+
+	while (len && skb) {
+		struct sk_buff *next_skb;
+		struct strp_msg *rxm = strp_msg(skb);
+		int chunk = min_t(unsigned int, rxm->full_len - skip, len);
+
+		if (!zc || (rxm->full_len - skip) > len) {
+			int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+						    msg, chunk);
+			if (err < 0)
+				return err;
+		}
+
+		len = len - chunk;
+		copied = copied + chunk;
+
+		/* Consume the data from record if it is non-peek case*/
+		if (!is_peek) {
+			rxm->offset = rxm->offset + chunk;
+			rxm->full_len = rxm->full_len - chunk;
+
+			/* Return if there is unconsumed data in the record */
+			if (rxm->full_len - skip)
+				break;
+		}
+
+		/* The remaining skip-bytes must lie in 1st record in rx_list.
+		 * So from the 2nd record, 'skip' should be 0.
+		 */
+		skip = 0;
+
+		if (msg)
+			msg->msg_flags |= MSG_EOR;
+
+		next_skb = skb_peek_next(skb, &ctx->rx_list);
+
+		if (!is_peek) {
+			skb_unlink(skb, &ctx->rx_list);
+			kfree_skb(skb);
+		}
+
+		skb = next_skb;
+	}
+
+	return copied;
+}
+
 int tls_sw_recvmsg(struct sock *sk,
 		   struct msghdr *msg,
 		   size_t len,
@@ -1466,7 +1546,8 @@ int tls_sw_recvmsg(struct sock *sk,
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct sk_psock *psock;
-	unsigned char control;
+	unsigned char control = 0;
+	ssize_t decrypted = 0;
 	struct strp_msg *rxm;
 	struct sk_buff *skb;
 	ssize_t copied = 0;
@@ -1474,6 +1555,7 @@ int tls_sw_recvmsg(struct sock *sk,
 	int target, err = 0;
 	long timeo;
 	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+	bool is_peek = flags & MSG_PEEK;
 	int num_async = 0;
 
 	flags |= nonblock;
@@ -1484,11 +1566,28 @@ int tls_sw_recvmsg(struct sock *sk,
 	psock = sk_psock_get(sk);
 	lock_sock(sk);
 
-	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
-	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	/* Process pending decrypted records. It must be non-zero-copy */
+	err = process_rx_list(ctx, msg, 0, len, false, is_peek);
+	if (err < 0) {
+		tls_err_abort(sk, err);
+		goto end;
+	} else {
+		copied = err;
+	}
+
+	len = len - copied;
+	if (len) {
+		target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
+		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	} else {
+		goto recv_end;
+	}
+
 	do {
-		bool zc = false;
+		bool retain_skb = false;
 		bool async = false;
+		bool zc = false;
+		int to_decrypt;
 		int chunk = 0;
 
 		skb = tls_wait_data(sk, psock, flags, timeo, &err);
@@ -1498,7 +1597,7 @@ int tls_sw_recvmsg(struct sock *sk,
 							    msg, len, flags);
 
 				if (ret > 0) {
-					copied += ret;
+					decrypted += ret;
 					len -= ret;
 					continue;
 				}
@@ -1525,70 +1624,70 @@ int tls_sw_recvmsg(struct sock *sk,
 			goto recv_end;
 		}
 
-		if (!ctx->decrypted) {
-			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
+		to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size;
 
-			if (!is_kvec && to_copy <= len &&
-			    likely(!(flags & MSG_PEEK)))
-				zc = true;
+		if (to_decrypt <= len && !is_kvec && !is_peek)
+			zc = true;
 
-			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-						 &chunk, &zc);
-			if (err < 0 && err != -EINPROGRESS) {
-				tls_err_abort(sk, EBADMSG);
-				goto recv_end;
-			}
-
-			if (err == -EINPROGRESS) {
-				async = true;
-				num_async++;
-				goto pick_next_record;
-			}
-
-			ctx->decrypted = true;
+		err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+					 &chunk, &zc, ctx->async_capable);
+		if (err < 0 && err != -EINPROGRESS) {
+			tls_err_abort(sk, EBADMSG);
+			goto recv_end;
 		}
 
-		if (!zc) {
-			chunk = min_t(unsigned int, rxm->full_len, len);
+		if (err == -EINPROGRESS) {
+			async = true;
+			num_async++;
+			goto pick_next_record;
+		} else {
+			if (!zc) {
+				if (rxm->full_len > len) {
+					retain_skb = true;
+					chunk = len;
+				} else {
+					chunk = rxm->full_len;
+				}
 
-			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
-						    chunk);
-			if (err < 0)
-				goto recv_end;
+				err = skb_copy_datagram_msg(skb, rxm->offset,
+							    msg, chunk);
+				if (err < 0)
+					goto recv_end;
+
+				if (!is_peek) {
+					rxm->offset = rxm->offset + chunk;
+					rxm->full_len = rxm->full_len - chunk;
+				}
+			}
 		}
 
 pick_next_record:
-		copied += chunk;
+		if (chunk > len)
+			chunk = len;
+
+		decrypted += chunk;
 		len -= chunk;
-		if (likely(!(flags & MSG_PEEK))) {
-			u8 control = ctx->control;
 
-			/* For async, drop current skb reference */
-			if (async)
-				skb = NULL;
+		/* For async or peek case, queue the current skb */
+		if (async || is_peek || retain_skb) {
+			skb_queue_tail(&ctx->rx_list, skb);
+			skb = NULL;
+		}
 
-			if (tls_sw_advance_skb(sk, skb, chunk)) {
-				/* Return full control message to
-				 * userspace before trying to parse
-				 * another message type
-				 */
-				msg->msg_flags |= MSG_EOR;
-				if (control != TLS_RECORD_TYPE_DATA)
-					goto recv_end;
-			} else {
-				break;
-			}
-		} else {
-			/* MSG_PEEK right now cannot look beyond current skb
-			 * from strparser, meaning we cannot advance skb here
-			 * and thus unpause strparser since we'd loose original
-			 * one.
+		if (tls_sw_advance_skb(sk, skb, chunk)) {
+			/* Return full control message to
+			 * userspace before trying to parse
+			 * another message type
 			 */
+			msg->msg_flags |= MSG_EOR;
+			if (ctx->control != TLS_RECORD_TYPE_DATA)
+				goto recv_end;
+		} else {
 			break;
 		}
 
 		/* If we have a new message from strparser, continue now. */
-		if (copied >= target && !ctx->recv_pkt)
+		if (decrypted >= target && !ctx->recv_pkt)
 			break;
 	} while (len);
 
@@ -1602,13 +1701,33 @@ int tls_sw_recvmsg(struct sock *sk,
 				/* one of async decrypt failed */
 				tls_err_abort(sk, err);
 				copied = 0;
+				decrypted = 0;
+				goto end;
 			}
 		} else {
 			reinit_completion(&ctx->async_wait.completion);
 		}
 		WRITE_ONCE(ctx->async_notify, false);
+
+		/* Drain records from the rx_list & copy if required */
+		if (is_peek || is_kvec)
+			err = process_rx_list(ctx, msg, copied,
+					      decrypted, false, is_peek);
+		else
+			err = process_rx_list(ctx, msg, 0,
+					      decrypted, true, is_peek);
+		if (err < 0) {
+			tls_err_abort(sk, err);
+			copied = 0;
+			goto end;
+		}
+
+		WARN_ON(decrypted != err);
 	}
 
+	copied += decrypted;
+
+end:
 	release_sock(sk);
 	if (psock)
 		sk_psock_put(sk, psock);
@@ -1645,7 +1764,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 	}
 
 	if (!ctx->decrypted) {
-		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
+		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
 
 		if (err < 0) {
 			tls_err_abort(sk, EBADMSG);
@@ -1832,6 +1951,7 @@ void tls_sw_release_resources_rx(struct sock *sk)
 	if (ctx->aead_recv) {
 		kfree_skb(ctx->recv_pkt);
 		ctx->recv_pkt = NULL;
+		skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
 		strp_stop(&ctx->strp);
 		write_lock_bh(&sk->sk_callback_lock);
@@ -1881,6 +2001,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	struct crypto_aead **aead;
 	struct strp_callbacks cb;
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
+	struct crypto_tfm *tfm;
 	char *iv, *rec_seq;
 	int rc = 0;
 
@@ -1927,6 +2048,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		crypto_init_wait(&sw_ctx_rx->async_wait);
 		crypto_info = &ctx->crypto_recv.info;
 		cctx = &ctx->rx;
+		skb_queue_head_init(&sw_ctx_rx->rx_list);
 		aead = &sw_ctx_rx->aead_recv;
 	}
 
@@ -1994,6 +2116,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		goto free_aead;
 
 	if (sw_ctx_rx) {
+		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
+		sw_ctx_rx->async_capable =
+			tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
+
 		/* Set up strparser */
 		memset(&cb, 0, sizeof(cb));
 		cb.rcv_msg = tls_queue;