udp: Use hlist_nulls in UDP RCU code

This is a straightforward patch, using hlist_nulls infrastructure.

The RCU conversion of UDP was already done two weeks ago.

Using hlist_nulls permits us to avoid some memory barriers, both
at lookup time and delete time.

The patch is large because it adds new macros to include/net/sock.h.
These macros will be used by TCP & DCCP in the next patch.

Signed-off-by: Eric Dumazet <dada1@cosmosbay.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 54badc9..fea2d87 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -127,9 +127,9 @@
 						 const struct sock *sk2))
 {
 	struct sock *sk2;
-	struct hlist_node *node;
+	struct hlist_nulls_node *node;
 
-	sk_for_each(sk2, node, &hslot->head)
+	sk_nulls_for_each(sk2, node, &hslot->head)
 		if (net_eq(sock_net(sk2), net)			&&
 		    sk2 != sk					&&
 		    sk2->sk_hash == num				&&
@@ -189,12 +189,7 @@
 	inet_sk(sk)->num = snum;
 	sk->sk_hash = snum;
 	if (sk_unhashed(sk)) {
-		/*
-		 * We need that previous write to sk->sk_hash committed
-		 * before write to sk->next done in following add_node() variant
-		 */
-		smp_wmb();
-		sk_add_node_rcu(sk, &hslot->head);
+		sk_nulls_add_node_rcu(sk, &hslot->head);
 		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
 	}
 	error = 0;
@@ -261,7 +256,7 @@
 		int dif, struct udp_table *udptable)
 {
 	struct sock *sk, *result;
-	struct hlist_node *node, *next;
+	struct hlist_nulls_node *node;
 	unsigned short hnum = ntohs(dport);
 	unsigned int hash = udp_hashfn(net, hnum);
 	struct udp_hslot *hslot = &udptable->hash[hash];
@@ -271,13 +266,7 @@
 begin:
 	result = NULL;
 	badness = -1;
-	sk_for_each_rcu_safenext(sk, node, &hslot->head, next) {
-		/*
-		 * lockless reader, and SLAB_DESTROY_BY_RCU items:
-		 * We must check this item was not moved to another chain
-		 */
-		if (udp_hashfn(net, sk->sk_hash) != hash)
-			goto begin;
+	sk_nulls_for_each_rcu(sk, node, &hslot->head) {
 		score = compute_score(sk, net, saddr, hnum, sport,
 				      daddr, dport, dif);
 		if (score > badness) {
@@ -285,6 +274,14 @@
 			badness = score;
 		}
 	}
+	/*
+	 * if the nulls value we got at the end of this lookup is
+	 * not the expected one, we must restart lookup.
+	 * We probably met an item that was moved to another chain.
+	 */
+	if (get_nulls_value(node) != hash)
+		goto begin;
+
 	if (result) {
 		if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt)))
 			result = NULL;
@@ -325,11 +322,11 @@
 					     __be16 rmt_port, __be32 rmt_addr,
 					     int dif)
 {
-	struct hlist_node *node;
+	struct hlist_nulls_node *node;
 	struct sock *s = sk;
 	unsigned short hnum = ntohs(loc_port);
 
-	sk_for_each_from(s, node) {
+	sk_nulls_for_each_from(s, node) {
 		struct inet_sock *inet = inet_sk(s);
 
 		if (!net_eq(sock_net(s), net)				||
@@ -977,7 +974,7 @@
 	struct udp_hslot *hslot = &udptable->hash[hash];
 
 	spin_lock_bh(&hslot->lock);
-	if (sk_del_node_init_rcu(sk)) {
+	if (sk_nulls_del_node_init_rcu(sk)) {
 		inet_sk(sk)->num = 0;
 		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
 	}
@@ -1130,7 +1127,7 @@
 	int dif;
 
 	spin_lock(&hslot->lock);
-	sk = sk_head(&hslot->head);
+	sk = sk_nulls_head(&hslot->head);
 	dif = skb->dev->ifindex;
 	sk = udp_v4_mcast_next(net, sk, uh->dest, daddr, uh->source, saddr, dif);
 	if (sk) {
@@ -1139,7 +1136,7 @@
 		do {
 			struct sk_buff *skb1 = skb;
 
-			sknext = udp_v4_mcast_next(net, sk_next(sk), uh->dest,
+			sknext = udp_v4_mcast_next(net, sk_nulls_next(sk), uh->dest,
 						   daddr, uh->source, saddr,
 						   dif);
 			if (sknext)
@@ -1560,10 +1557,10 @@
 	struct net *net = seq_file_net(seq);
 
 	for (state->bucket = start; state->bucket < UDP_HTABLE_SIZE; ++state->bucket) {
-		struct hlist_node *node;
+		struct hlist_nulls_node *node;
 		struct udp_hslot *hslot = &state->udp_table->hash[state->bucket];
 		spin_lock_bh(&hslot->lock);
-		sk_for_each(sk, node, &hslot->head) {
+		sk_nulls_for_each(sk, node, &hslot->head) {
 			if (!net_eq(sock_net(sk), net))
 				continue;
 			if (sk->sk_family == state->family)
@@ -1582,7 +1579,7 @@
 	struct net *net = seq_file_net(seq);
 
 	do {
-		sk = sk_next(sk);
+		sk = sk_nulls_next(sk);
 	} while (sk && (!net_eq(sock_net(sk), net) || sk->sk_family != state->family));
 
 	if (!sk) {
@@ -1753,7 +1750,7 @@
 	int i;
 
 	for (i = 0; i < UDP_HTABLE_SIZE; i++) {
-		INIT_HLIST_HEAD(&table->hash[i].head);
+		INIT_HLIST_NULLS_HEAD(&table->hash[i].head, i);
 		spin_lock_init(&table->hash[i].lock);
 	}
 }