xprtrdma: Refactor rpcrdma_reply_handler()

Refactor the reply handler's transport header decoding logic to make
it easier to understand and update.

Convert some of the handler to use xdr_streams, which will enable
stricter validation of input data and enable the eventual addition
of support for new combinations of chunks, such as "Write + Reply"
or "PZRC + normal Read".

Signed-off-by: Chuck Lever <chuck.lever@oracle.com>
Signed-off-by: Anna Schumaker <Anna.Schumaker@Netapp.com>
diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c
index 9b5ab59..56f2277 100644
--- a/net/sunrpc/xprtrdma/rpc_rdma.c
+++ b/net/sunrpc/xprtrdma/rpc_rdma.c
@@ -1004,6 +1004,124 @@ rpcrdma_is_bcall(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep,
 }
 #endif	/* CONFIG_SUNRPC_BACKCHANNEL */
 
+static int
+rpcrdma_decode_msg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep,
+		   struct rpc_rqst *rqst)
+{
+	struct xdr_stream *xdr = &rep->rr_stream;
+	int rdmalen, status;
+	__be32 *p;
+
+	p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+	if (unlikely(!p))
+		return -EIO;
+
+	/* never expect read list */
+	if (unlikely(*p++ != xdr_zero))
+		return -EIO;
+
+	/* Write list */
+	if (*p != xdr_zero) {
+		char *base = rep->rr_hdrbuf.head[0].iov_base;
+
+		p++;
+		rdmalen = rpcrdma_count_chunks(rep, 1, &p);
+		if (rdmalen < 0 || *p++ != xdr_zero)
+			return -EIO;
+
+		rep->rr_len -= (char *)p - base;
+		status = rep->rr_len + rdmalen;
+		r_xprt->rx_stats.total_rdma_reply += rdmalen;
+
+		/* special case - last segment may omit padding */
+		rdmalen &= 3;
+		if (rdmalen) {
+			rdmalen = 4 - rdmalen;
+			status += rdmalen;
+		}
+	} else {
+		p = xdr_inline_decode(xdr, sizeof(*p));
+		if (!p)
+			return -EIO;
+
+		/* never expect reply chunk */
+		if (*p++ != xdr_zero)
+			return -EIO;
+		rdmalen = 0;
+		rep->rr_len -= RPCRDMA_HDRLEN_MIN;
+		status = rep->rr_len;
+	}
+
+	r_xprt->rx_stats.fixup_copy_count +=
+		rpcrdma_inline_fixup(rqst, (char *)p, rep->rr_len,
+				     rdmalen);
+
+	return status;
+}
+
+static noinline int
+rpcrdma_decode_nomsg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep)
+{
+	struct xdr_stream *xdr = &rep->rr_stream;
+	int rdmalen;
+	__be32 *p;
+
+	p = xdr_inline_decode(xdr, 3 * sizeof(*p));
+	if (unlikely(!p))
+		return -EIO;
+
+	/* never expect Read chunks */
+	if (unlikely(*p++ != xdr_zero))
+		return -EIO;
+	/* never expect Write chunks */
+	if (unlikely(*p++ != xdr_zero))
+		return -EIO;
+	/* always expect a Reply chunk */
+	if (unlikely(*p++ == xdr_zero))
+		return -EIO;
+
+	rdmalen = rpcrdma_count_chunks(rep, 0, &p);
+	if (rdmalen < 0)
+		return -EIO;
+	r_xprt->rx_stats.total_rdma_reply += rdmalen;
+
+	/* Reply chunk buffer already is the reply vector - no fixup. */
+	return rdmalen;
+}
+
+static noinline int
+rpcrdma_decode_error(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep,
+		     struct rpc_rqst *rqst)
+{
+	struct xdr_stream *xdr = &rep->rr_stream;
+	__be32 *p;
+
+	p = xdr_inline_decode(xdr, sizeof(*p));
+	if (unlikely(!p))
+		return -EIO;
+
+	switch (*p) {
+	case err_vers:
+		p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+		if (!p)
+			break;
+		dprintk("RPC: %5u: %s: server reports version error (%u-%u)\n",
+			rqst->rq_task->tk_pid, __func__,
+			be32_to_cpup(p), be32_to_cpu(*(p + 1)));
+		break;
+	case err_chunk:
+		dprintk("RPC: %5u: %s: server reports header decoding error\n",
+			rqst->rq_task->tk_pid, __func__);
+		break;
+	default:
+		dprintk("RPC: %5u: %s: server reports unrecognized error %d\n",
+			rqst->rq_task->tk_pid, __func__, be32_to_cpup(p));
+	}
+
+	r_xprt->rx_stats.bad_reply_count++;
+	return -EREMOTEIO;
+}
+
 /* Process received RPC/RDMA messages.
  *
  * Errors must result in the RPC task either being awakened, or
@@ -1018,13 +1136,12 @@ rpcrdma_reply_handler(struct work_struct *work)
 	struct rpcrdma_buffer *buf = &r_xprt->rx_buf;
 	struct rpc_xprt *xprt = &r_xprt->rx_xprt;
 	struct xdr_stream *xdr = &rep->rr_stream;
-	struct rpcrdma_msg *headerp;
 	struct rpcrdma_req *req;
 	struct rpc_rqst *rqst;
-	__be32 *iptr, *p, xid, vers, proc;
-	int rdmalen, status, rmerr;
+	__be32 *p, xid, vers, proc;
 	unsigned long cwnd;
 	struct list_head mws;
+	int status;
 
 	dprintk("RPC:       %s: incoming rep %p\n", __func__, rep);
 
@@ -1043,7 +1160,6 @@ rpcrdma_reply_handler(struct work_struct *work)
 	p++;	/* credits */
 	proc = *p++;
 
-	headerp = rdmab_to_msg(rep->rr_rdmabuf);
 	if (rpcrdma_is_bcall(r_xprt, rep, xid, proc))
 		return;
 
@@ -1091,75 +1207,21 @@ rpcrdma_reply_handler(struct work_struct *work)
 	if (vers != rpcrdma_version)
 		goto out_badversion;
 
-	/* check for expected message types */
-	/* The order of some of these tests is important. */
 	switch (proc) {
 	case rdma_msg:
-		/* never expect read chunks */
-		/* never expect reply chunks (two ways to check) */
-		if (headerp->rm_body.rm_chunks[0] != xdr_zero ||
-		    (headerp->rm_body.rm_chunks[1] == xdr_zero &&
-		     headerp->rm_body.rm_chunks[2] != xdr_zero))
-			goto badheader;
-		if (headerp->rm_body.rm_chunks[1] != xdr_zero) {
-			/* count any expected write chunks in read reply */
-			/* start at write chunk array count */
-			iptr = &headerp->rm_body.rm_chunks[2];
-			rdmalen = rpcrdma_count_chunks(rep, 1, &iptr);
-			/* check for validity, and no reply chunk after */
-			if (rdmalen < 0 || *iptr++ != xdr_zero)
-				goto badheader;
-			rep->rr_len -=
-			    ((unsigned char *)iptr - (unsigned char *)headerp);
-			status = rep->rr_len + rdmalen;
-			r_xprt->rx_stats.total_rdma_reply += rdmalen;
-			/* special case - last chunk may omit padding */
-			if (rdmalen &= 3) {
-				rdmalen = 4 - rdmalen;
-				status += rdmalen;
-			}
-		} else {
-			/* else ordinary inline */
-			rdmalen = 0;
-			iptr = (__be32 *)((unsigned char *)headerp +
-							RPCRDMA_HDRLEN_MIN);
-			rep->rr_len -= RPCRDMA_HDRLEN_MIN;
-			status = rep->rr_len;
-		}
-
-		r_xprt->rx_stats.fixup_copy_count +=
-			rpcrdma_inline_fixup(rqst, (char *)iptr, rep->rr_len,
-					     rdmalen);
+		status = rpcrdma_decode_msg(r_xprt, rep, rqst);
 		break;
-
 	case rdma_nomsg:
-		/* never expect read or write chunks, always reply chunks */
-		if (headerp->rm_body.rm_chunks[0] != xdr_zero ||
-		    headerp->rm_body.rm_chunks[1] != xdr_zero ||
-		    headerp->rm_body.rm_chunks[2] != xdr_one)
-			goto badheader;
-		iptr = (__be32 *)((unsigned char *)headerp +
-							RPCRDMA_HDRLEN_MIN);
-		rdmalen = rpcrdma_count_chunks(rep, 0, &iptr);
-		if (rdmalen < 0)
-			goto badheader;
-		r_xprt->rx_stats.total_rdma_reply += rdmalen;
-		/* Reply chunk buffer already is the reply vector - no fixup. */
-		status = rdmalen;
+		status = rpcrdma_decode_nomsg(r_xprt, rep);
 		break;
-
 	case rdma_error:
-		goto out_rdmaerr;
-
-badheader:
-	default:
-		dprintk("RPC: %5u %s: invalid rpcrdma reply (type %u)\n",
-			rqst->rq_task->tk_pid, __func__,
-			be32_to_cpu(proc));
-		status = -EIO;
-		r_xprt->rx_stats.bad_reply_count++;
+		status = rpcrdma_decode_error(r_xprt, rep, rqst);
 		break;
+	default:
+		status = -EIO;
 	}
+	if (status < 0)
+		goto out_badheader;
 
 out:
 	cwnd = xprt->cwnd;
@@ -1192,25 +1254,11 @@ rpcrdma_reply_handler(struct work_struct *work)
 	r_xprt->rx_stats.bad_reply_count++;
 	goto out;
 
-out_rdmaerr:
-	rmerr = be32_to_cpu(headerp->rm_body.rm_error.rm_err);
-	switch (rmerr) {
-	case ERR_VERS:
-		pr_err("%s: server reports header version error (%u-%u)\n",
-		       __func__,
-		       be32_to_cpu(headerp->rm_body.rm_error.rm_vers_low),
-		       be32_to_cpu(headerp->rm_body.rm_error.rm_vers_high));
-		break;
-	case ERR_CHUNK:
-		pr_err("%s: server reports header decoding error\n",
-		       __func__);
-		break;
-	default:
-		pr_err("%s: server reports unknown error %d\n",
-		       __func__, rmerr);
-	}
-	status = -EREMOTEIO;
+out_badheader:
+	dprintk("RPC: %5u %s: invalid rpcrdma reply (type %u)\n",
+		rqst->rq_task->tk_pid, __func__, be32_to_cpu(proc));
 	r_xprt->rx_stats.bad_reply_count++;
+	status = -EIO;
 	goto out;
 
 /* The req was still available, but by the time the transport_lock