ipmr: RCU conversion of mroute_sk

Use RCU and RTNL to protect (struct mr_table)->mroute_sk

Readers use RCU, writers use RTNL.

ip_ra_control() already uses an RCU grace period before
ip_ra_destroy_rcu(), so we don't need synchronize_rcu() in
mrtsock_destruct().

Signed-off-by: Eric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c
index 1a92ebd..e2db2ea 100644
--- a/net/ipv4/ipmr.c
+++ b/net/ipv4/ipmr.c
@@ -75,7 +75,7 @@
 	struct net		*net;
 #endif
 	u32			id;
-	struct sock		*mroute_sk;
+	struct sock __rcu	*mroute_sk;
 	struct timer_list	ipmr_expire_timer;
 	struct list_head	mfc_unres_queue;
 	struct list_head	mfc_cache_array[MFC_LINES];
@@ -867,6 +867,7 @@
 	const int ihl = ip_hdrlen(pkt);
 	struct igmphdr *igmp;
 	struct igmpmsg *msg;
+	struct sock *mroute_sk;
 	int ret;
 
 #ifdef CONFIG_IP_PIMSM
@@ -925,7 +926,10 @@
 	skb->transport_header = skb->network_header;
 	}
 
-	if (mrt->mroute_sk == NULL) {
+	rcu_read_lock();
+	mroute_sk = rcu_dereference(mrt->mroute_sk);
+	if (mroute_sk == NULL) {
+		rcu_read_unlock();
 		kfree_skb(skb);
 		return -EINVAL;
 	}
@@ -933,7 +937,8 @@
 	/*
 	 *	Deliver to mrouted
 	 */
-	ret = sock_queue_rcv_skb(mrt->mroute_sk, skb);
+	ret = sock_queue_rcv_skb(mroute_sk, skb);
+	rcu_read_unlock();
 	if (ret < 0) {
 		if (net_ratelimit())
 			printk(KERN_WARNING "mroute: pending queue full, dropping entries.\n");
@@ -1164,6 +1169,9 @@
 	}
 }
 
+/* Called from ip_ra_control(), before an RCU grace period,
+ * so we don't need to call synchronize_rcu() here.
+ */
 static void mrtsock_destruct(struct sock *sk)
 {
 	struct net *net = sock_net(sk);
@@ -1171,13 +1179,9 @@
 
 	rtnl_lock();
 	ipmr_for_each_table(mrt, net) {
-		if (sk == mrt->mroute_sk) {
+		if (sk == rtnl_dereference(mrt->mroute_sk)) {
 			IPV4_DEVCONF_ALL(net, MC_FORWARDING)--;
-
-			write_lock_bh(&mrt_lock);
-			mrt->mroute_sk = NULL;
-			write_unlock_bh(&mrt_lock);
-
+			rcu_assign_pointer(mrt->mroute_sk, NULL);
 			mroute_clean_tables(mrt);
 		}
 	}
@@ -1204,7 +1208,8 @@
 		return -ENOENT;
 
 	if (optname != MRT_INIT) {
-		if (sk != mrt->mroute_sk && !capable(CAP_NET_ADMIN))
+		if (sk != rcu_dereference_raw(mrt->mroute_sk) &&
+		    !capable(CAP_NET_ADMIN))
 			return -EACCES;
 	}
 
@@ -1217,23 +1222,20 @@
 			return -ENOPROTOOPT;
 
 		rtnl_lock();
-		if (mrt->mroute_sk) {
+		if (rtnl_dereference(mrt->mroute_sk)) {
 			rtnl_unlock();
 			return -EADDRINUSE;
 		}
 
 		ret = ip_ra_control(sk, 1, mrtsock_destruct);
 		if (ret == 0) {
-			write_lock_bh(&mrt_lock);
-			mrt->mroute_sk = sk;
-			write_unlock_bh(&mrt_lock);
-
+			rcu_assign_pointer(mrt->mroute_sk, sk);
 			IPV4_DEVCONF_ALL(net, MC_FORWARDING)++;
 		}
 		rtnl_unlock();
 		return ret;
 	case MRT_DONE:
-		if (sk != mrt->mroute_sk)
+		if (sk != rcu_dereference_raw(mrt->mroute_sk))
 			return -EACCES;
 		return ip_ra_control(sk, 0, NULL);
 	case MRT_ADD_VIF:
@@ -1246,7 +1248,8 @@
 			return -ENFILE;
 		rtnl_lock();
 		if (optname == MRT_ADD_VIF) {
-			ret = vif_add(net, mrt, &vif, sk == mrt->mroute_sk);
+			ret = vif_add(net, mrt, &vif,
+				      sk == rtnl_dereference(mrt->mroute_sk));
 		} else {
 			ret = vif_delete(mrt, vif.vifc_vifi, 0, NULL);
 		}
@@ -1267,7 +1270,8 @@
 		if (optname == MRT_DEL_MFC)
 			ret = ipmr_mfc_delete(mrt, &mfc);
 		else
-			ret = ipmr_mfc_add(net, mrt, &mfc, sk == mrt->mroute_sk);
+			ret = ipmr_mfc_add(net, mrt, &mfc,
+					   sk == rtnl_dereference(mrt->mroute_sk));
 		rtnl_unlock();
 		return ret;
 		/*
@@ -1309,14 +1313,16 @@
 			return -EINVAL;
 		if (get_user(v, (u32 __user *)optval))
 			return -EFAULT;
-		if (sk == mrt->mroute_sk)
-			return -EBUSY;
 
 		rtnl_lock();
 		ret = 0;
-		if (!ipmr_new_table(net, v))
-			ret = -ENOMEM;
-		raw_sk(sk)->ipmr_table = v;
+		if (sk == rtnl_dereference(mrt->mroute_sk)) {
+			ret = -EBUSY;
+		} else {
+			if (!ipmr_new_table(net, v))
+				ret = -ENOMEM;
+			raw_sk(sk)->ipmr_table = v;
+		}
 		rtnl_unlock();
 		return ret;
 	}
@@ -1713,6 +1719,7 @@
 
 /*
  *	Multicast packets for forwarding arrive here
+ *	Called with rcu_read_lock() held.
  */
 
 int ip_mr_input(struct sk_buff *skb)
@@ -1726,7 +1733,7 @@
 	/* Packet is looped back after forward, it should not be
 	   forwarded second time, but still can be delivered locally.
 	 */
-	if (IPCB(skb)->flags&IPSKB_FORWARDED)
+	if (IPCB(skb)->flags & IPSKB_FORWARDED)
 		goto dont_forward;
 
 	err = ipmr_fib_lookup(net, &skb_rtable(skb)->fl, &mrt);
@@ -1736,24 +1743,24 @@
 	}
 
 	if (!local) {
-		    if (IPCB(skb)->opt.router_alert) {
-			    if (ip_call_ra_chain(skb))
-				    return 0;
-		    } else if (ip_hdr(skb)->protocol == IPPROTO_IGMP){
-			    /* IGMPv1 (and broken IGMPv2 implementations sort of
-			       Cisco IOS <= 11.2(8)) do not put router alert
-			       option to IGMP packets destined to routable
-			       groups. It is very bad, because it means
-			       that we can forward NO IGMP messages.
-			     */
-			    read_lock(&mrt_lock);
-			    if (mrt->mroute_sk) {
-				    nf_reset(skb);
-				    raw_rcv(mrt->mroute_sk, skb);
-				    read_unlock(&mrt_lock);
-				    return 0;
-			    }
-			    read_unlock(&mrt_lock);
+		if (IPCB(skb)->opt.router_alert) {
+			if (ip_call_ra_chain(skb))
+				return 0;
+		} else if (ip_hdr(skb)->protocol == IPPROTO_IGMP) {
+			/* IGMPv1 (and broken IGMPv2 implementations sort of
+			 * Cisco IOS <= 11.2(8)) do not put router alert
+			 * option to IGMP packets destined to routable
+			 * groups. It is very bad, because it means
+			 * that we can forward NO IGMP messages.
+			 */
+			struct sock *mroute_sk;
+
+			mroute_sk = rcu_dereference(mrt->mroute_sk);
+			if (mroute_sk) {
+				nf_reset(skb);
+				raw_rcv(mroute_sk, skb);
+				return 0;
+			}
 		    }
 	}