SUNRPC: Use RCU to dereference the rpc_clnt.cl_xprt field

A migration event will replace the rpc_xprt used by an rpc_clnt.  To
ensure this can be done safely, all references to cl_xprt must now use
a form of rcu_dereference().

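Special care is taken with rpc_peeraddr2str(), which returns a pointer
to memory whose lifetime is the same as the rpc_xprt.  Callers must
therefore hold the RCU read lock for as long as they use the returned
string.  Note also that rpc_peeraddr() now copies at most xprt->addrlen
bytes and returns the number of bytes actually copied, rather than
always returning xprt->addrlen.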

Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
[ cel: fix lockdep splats and layering violations ]
[ cel: forward ported to 3.4 ]
[ cel: remove rpc_max_reqs(), add rpc_net_ns() ]
Signed-off-by: Chuck Lever <chuck.lever@oracle.com>
Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c
index cb2e564..d3ad81f 100644
--- a/net/sunrpc/auth_gss/auth_gss.c
+++ b/net/sunrpc/auth_gss/auth_gss.c
@@ -799,7 +799,7 @@
 static void gss_pipes_dentries_destroy_net(struct rpc_clnt *clnt,
 					   struct rpc_auth *auth)
 {
-	struct net *net = clnt->cl_xprt->xprt_net;
+	struct net *net = rpc_net_ns(clnt);
 	struct super_block *sb;
 
 	sb = rpc_get_sb_net(net);
@@ -813,7 +813,7 @@
 static int gss_pipes_dentries_create_net(struct rpc_clnt *clnt,
 					 struct rpc_auth *auth)
 {
-	struct net *net = clnt->cl_xprt->xprt_net;
+	struct net *net = rpc_net_ns(clnt);
 	struct super_block *sb;
 	int err = 0;
 
diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c
index 25c3da5..7783fc0 100644
--- a/net/sunrpc/clnt.c
+++ b/net/sunrpc/clnt.c
@@ -31,6 +31,7 @@
 #include <linux/in.h>
 #include <linux/in6.h>
 #include <linux/un.h>
+#include <linux/rcupdate.h>
 
 #include <linux/sunrpc/clnt.h>
 #include <linux/sunrpc/rpc_pipe_fs.h>
@@ -81,7 +82,8 @@
 
 static void rpc_register_client(struct rpc_clnt *clnt)
 {
-	struct sunrpc_net *sn = net_generic(clnt->cl_xprt->xprt_net, sunrpc_net_id);
+	struct net *net = rpc_net_ns(clnt);
+	struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
 
 	spin_lock(&sn->rpc_client_lock);
 	list_add(&clnt->cl_clients, &sn->all_clients);
@@ -90,7 +92,8 @@
 
 static void rpc_unregister_client(struct rpc_clnt *clnt)
 {
-	struct sunrpc_net *sn = net_generic(clnt->cl_xprt->xprt_net, sunrpc_net_id);
+	struct net *net = rpc_net_ns(clnt);
+	struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
 
 	spin_lock(&sn->rpc_client_lock);
 	list_del(&clnt->cl_clients);
@@ -109,12 +112,13 @@
 
 static void rpc_clnt_remove_pipedir(struct rpc_clnt *clnt)
 {
+	struct net *net = rpc_net_ns(clnt);
 	struct super_block *pipefs_sb;
 
-	pipefs_sb = rpc_get_sb_net(clnt->cl_xprt->xprt_net);
+	pipefs_sb = rpc_get_sb_net(net);
 	if (pipefs_sb) {
 		__rpc_clnt_remove_pipedir(clnt);
-		rpc_put_sb_net(clnt->cl_xprt->xprt_net);
+		rpc_put_sb_net(net);
 	}
 }
 
@@ -155,17 +159,18 @@
 static int
 rpc_setup_pipedir(struct rpc_clnt *clnt, const char *dir_name)
 {
+	struct net *net = rpc_net_ns(clnt);
 	struct super_block *pipefs_sb;
 	struct dentry *dentry;
 
 	clnt->cl_dentry = NULL;
 	if (dir_name == NULL)
 		return 0;
-	pipefs_sb = rpc_get_sb_net(clnt->cl_xprt->xprt_net);
+	pipefs_sb = rpc_get_sb_net(net);
 	if (!pipefs_sb)
 		return 0;
 	dentry = rpc_setup_pipedir_sb(pipefs_sb, clnt, dir_name);
-	rpc_put_sb_net(clnt->cl_xprt->xprt_net);
+	rpc_put_sb_net(net);
 	if (IS_ERR(dentry))
 		return PTR_ERR(dentry);
 	clnt->cl_dentry = dentry;
@@ -295,7 +300,7 @@
 	if (clnt->cl_server == NULL)
 		goto out_no_server;
 
-	clnt->cl_xprt     = xprt;
+	rcu_assign_pointer(clnt->cl_xprt, xprt);
 	clnt->cl_procinfo = version->procs;
 	clnt->cl_maxproc  = version->nrprocs;
 	clnt->cl_protname = program->name;
@@ -310,7 +315,7 @@
 	INIT_LIST_HEAD(&clnt->cl_tasks);
 	spin_lock_init(&clnt->cl_lock);
 
-	if (!xprt_bound(clnt->cl_xprt))
+	if (!xprt_bound(xprt))
 		clnt->cl_autobind = 1;
 
 	clnt->cl_timeout = xprt->timeout;
@@ -477,6 +482,7 @@
 rpc_clone_client(struct rpc_clnt *clnt)
 {
 	struct rpc_clnt *new;
+	struct rpc_xprt *xprt;
 	int err = -ENOMEM;
 
 	new = kmemdup(clnt, sizeof(*new), GFP_KERNEL);
@@ -499,18 +505,25 @@
 		if (new->cl_principal == NULL)
 			goto out_no_principal;
 	}
+	rcu_read_lock();
+	xprt = xprt_get(rcu_dereference(clnt->cl_xprt));
+	rcu_read_unlock();
+	if (xprt == NULL)
+		goto out_no_transport;
+	rcu_assign_pointer(new->cl_xprt, xprt);
 	atomic_set(&new->cl_count, 1);
 	err = rpc_setup_pipedir(new, clnt->cl_program->pipe_dir_name);
 	if (err != 0)
 		goto out_no_path;
 	if (new->cl_auth)
 		atomic_inc(&new->cl_auth->au_count);
-	xprt_get(clnt->cl_xprt);
 	atomic_inc(&clnt->cl_count);
 	rpc_register_client(new);
 	rpciod_up();
 	return new;
 out_no_path:
+	xprt_put(xprt);
+out_no_transport:
 	kfree(new->cl_principal);
 out_no_principal:
 	rpc_free_iostats(new->cl_metrics);
@@ -590,7 +603,7 @@
 	rpc_free_iostats(clnt->cl_metrics);
 	kfree(clnt->cl_principal);
 	clnt->cl_metrics = NULL;
-	xprt_put(clnt->cl_xprt);
+	xprt_put(rcu_dereference_raw(clnt->cl_xprt));
 	rpciod_down();
 	kfree(clnt);
 }
@@ -879,13 +892,18 @@
 size_t rpc_peeraddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t bufsize)
 {
 	size_t bytes;
-	struct rpc_xprt *xprt = clnt->cl_xprt;
+	struct rpc_xprt *xprt;
 
-	bytes = sizeof(xprt->addr);
+	rcu_read_lock();
+	xprt = rcu_dereference(clnt->cl_xprt);
+
+	bytes = xprt->addrlen;
 	if (bytes > bufsize)
 		bytes = bufsize;
-	memcpy(buf, &clnt->cl_xprt->addr, bytes);
-	return xprt->addrlen;
+	memcpy(buf, &xprt->addr, bytes);
+	rcu_read_unlock();
+
+	return bytes;
 }
 EXPORT_SYMBOL_GPL(rpc_peeraddr);
 
@@ -894,11 +912,16 @@
  * @clnt: RPC client structure
  * @format: address format
  *
+ * NB: the lifetime of the memory referenced by the returned pointer is
+ * the same as the rpc_xprt itself.  As long as the caller uses this
+ * pointer, it must hold the RCU read lock.
  */
 const char *rpc_peeraddr2str(struct rpc_clnt *clnt,
 			     enum rpc_display_format_t format)
 {
-	struct rpc_xprt *xprt = clnt->cl_xprt;
+	struct rpc_xprt *xprt;
+
+	xprt = rcu_dereference(clnt->cl_xprt);
 
 	if (xprt->address_strings[format] != NULL)
 		return xprt->address_strings[format];
@@ -910,14 +933,51 @@
 void
 rpc_setbufsize(struct rpc_clnt *clnt, unsigned int sndsize, unsigned int rcvsize)
 {
-	struct rpc_xprt *xprt = clnt->cl_xprt;
+	struct rpc_xprt *xprt;
+
+	rcu_read_lock();
+	xprt = rcu_dereference(clnt->cl_xprt);
 	if (xprt->ops->set_buffer_size)
 		xprt->ops->set_buffer_size(xprt, sndsize, rcvsize);
+	rcu_read_unlock();
 }
 EXPORT_SYMBOL_GPL(rpc_setbufsize);
 
-/*
- * Return size of largest payload RPC client can support, in bytes
+/**
+ * rpc_protocol - Get transport protocol number for an RPC client
+ * @clnt: RPC client to query
+ * Return: the prot field of the client's current transport
+ */
+int rpc_protocol(struct rpc_clnt *clnt)
+{
+	int protocol;
+
+	rcu_read_lock();
+	protocol = rcu_dereference(clnt->cl_xprt)->prot;
+	rcu_read_unlock();
+	return protocol;
+}
+EXPORT_SYMBOL_GPL(rpc_protocol);
+
+/**
+ * rpc_net_ns - Get the network namespace for this RPC client
+ * @clnt: RPC client to query
+ * Return: the transport's xprt_net pointer; no reference is taken on it
+ */
+struct net *rpc_net_ns(struct rpc_clnt *clnt)
+{
+	struct net *ret;
+
+	rcu_read_lock();
+	ret = rcu_dereference(clnt->cl_xprt)->xprt_net;
+	rcu_read_unlock();
+	return ret;
+}
+EXPORT_SYMBOL_GPL(rpc_net_ns);
+
+/**
+ * rpc_max_payload - Get maximum payload size for a transport, in bytes
+ * @clnt: RPC client to query
  *
  * For stream transports, this is one RPC record fragment (see RFC
  * 1831), as we don't support multi-record requests yet.  For datagram
@@ -926,7 +986,12 @@
  */
 size_t rpc_max_payload(struct rpc_clnt *clnt)
 {
-	return clnt->cl_xprt->max_payload;
+	size_t ret;
+
+	rcu_read_lock();
+	ret = rcu_dereference(clnt->cl_xprt)->max_payload;
+	rcu_read_unlock();
+	return ret;
 }
 EXPORT_SYMBOL_GPL(rpc_max_payload);
 
@@ -937,8 +1002,11 @@
  */
 void rpc_force_rebind(struct rpc_clnt *clnt)
 {
-	if (clnt->cl_autobind)
-		xprt_clear_bound(clnt->cl_xprt);
+	if (clnt->cl_autobind) {
+		rcu_read_lock();
+		xprt_clear_bound(rcu_dereference(clnt->cl_xprt));
+		rcu_read_unlock();
+	}
 }
 EXPORT_SYMBOL_GPL(rpc_force_rebind);
 
diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c
index ac9ee15..3d30943 100644
--- a/net/sunrpc/rpc_pipe.c
+++ b/net/sunrpc/rpc_pipe.c
@@ -16,6 +16,7 @@
 #include <linux/namei.h>
 #include <linux/fsnotify.h>
 #include <linux/kernel.h>
+#include <linux/rcupdate.h>
 
 #include <asm/ioctls.h>
 #include <linux/poll.h>
@@ -402,12 +403,14 @@
 {
 	struct rpc_clnt *clnt = m->private;
 
+	rcu_read_lock();
 	seq_printf(m, "RPC server: %s\n", clnt->cl_server);
 	seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_protname,
 			clnt->cl_prog, clnt->cl_vers);
 	seq_printf(m, "address: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR));
 	seq_printf(m, "protocol: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PROTO));
 	seq_printf(m, "port: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PORT));
+	rcu_read_unlock();
 	return 0;
 }
 
diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c
index b1f08bd..4f8af63 100644
--- a/net/sunrpc/rpcb_clnt.c
+++ b/net/sunrpc/rpcb_clnt.c
@@ -620,9 +620,10 @@
 static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt)
 {
 	struct rpc_clnt *parent = clnt->cl_parent;
+	struct rpc_xprt *xprt = rcu_dereference(clnt->cl_xprt);
 
 	while (parent != clnt) {
-		if (parent->cl_xprt != clnt->cl_xprt)
+		if (rcu_dereference(parent->cl_xprt) != xprt)
 			break;
 		if (clnt->cl_autobind)
 			break;
@@ -653,8 +654,12 @@
 	size_t salen;
 	int status;
 
-	clnt = rpcb_find_transport_owner(task->tk_client);
-	xprt = clnt->cl_xprt;
+	rcu_read_lock();
+	do {
+		clnt = rpcb_find_transport_owner(task->tk_client);
+		xprt = xprt_get(rcu_dereference(clnt->cl_xprt));
+	} while (xprt == NULL);
+	rcu_read_unlock();
 
 	dprintk("RPC: %5u %s(%s, %u, %u, %d)\n",
 		task->tk_pid, __func__,
@@ -667,6 +672,7 @@
 	if (xprt_test_and_set_binding(xprt)) {
 		dprintk("RPC: %5u %s: waiting for another binder\n",
 			task->tk_pid, __func__);
+		xprt_put(xprt);
 		return;
 	}
 
@@ -734,7 +740,7 @@
 	switch (bind_version) {
 	case RPCBVERS_4:
 	case RPCBVERS_3:
-		map->r_netid = rpc_peeraddr2str(clnt, RPC_DISPLAY_NETID);
+		map->r_netid = xprt->address_strings[RPC_DISPLAY_NETID];
 		map->r_addr = rpc_sockaddr2uaddr(sap, GFP_ATOMIC);
 		map->r_owner = "";
 		break;
@@ -763,6 +769,7 @@
 bailout_nofree:
 	rpcb_wake_rpcbind_waiters(xprt, status);
 	task->tk_status = status;
+	xprt_put(xprt);
 }
 EXPORT_SYMBOL_GPL(rpcb_getport_async);
 
diff --git a/net/sunrpc/stats.c b/net/sunrpc/stats.c
index 1eb3304..bc2068e 100644
--- a/net/sunrpc/stats.c
+++ b/net/sunrpc/stats.c
@@ -22,6 +22,7 @@
 #include <linux/sunrpc/clnt.h>
 #include <linux/sunrpc/svcsock.h>
 #include <linux/sunrpc/metrics.h>
+#include <linux/rcupdate.h>
 
 #include "netns.h"
 
@@ -179,7 +180,7 @@
 void rpc_print_iostats(struct seq_file *seq, struct rpc_clnt *clnt)
 {
 	struct rpc_iostats *stats = clnt->cl_metrics;
-	struct rpc_xprt *xprt = clnt->cl_xprt;
+	struct rpc_xprt *xprt;
 	unsigned int op, maxproc = clnt->cl_maxproc;
 
 	if (!stats)
@@ -189,8 +190,11 @@
 	seq_printf(seq, "p/v: %u/%u (%s)\n",
 			clnt->cl_prog, clnt->cl_vers, clnt->cl_protname);
 
+	rcu_read_lock();
+	xprt = rcu_dereference(clnt->cl_xprt);
 	if (xprt)
 		xprt->ops->print_stats(xprt, seq);
+	rcu_read_unlock();
 
 	seq_printf(seq, "\tper-op statistics\n");
 	for (op = 0; op < maxproc; op++) {