fib_trie: Push tnode flushing down to inflate/halve

This change pushes the tnode freeing down into the inflate and halve
functions.  It makes more sense to do the freeing there, since those
functions have a better grasp of what is going on and of when a given
cluster of nodes is ready to be freed.

I believe this may address a bug in the freeing logic as well.  For some
reason, once the freelist got to a certain size we would call
synchronize_rcu().  I'm assuming the intent was to call synchronize_rcu()
after that much memory had been handed off via call_rcu(), so that is
the behavior I have implemented here.

Signed-off-by: Alexander Duyck <alexander.h.duyck@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/fib_trie.c b/net/ipv4/fib_trie.c
index 7fbd2c5..485983e 100644
--- a/net/ipv4/fib_trie.c
+++ b/net/ipv4/fib_trie.c
@@ -147,8 +147,6 @@
 };
 
 static void resize(struct trie *t, struct tnode *tn);
-/* tnodes to free after resize(); protected by RTNL */
-static struct callback_head *tnode_free_head;
 static size_t tnode_free_size;
 
 /*
@@ -307,32 +305,6 @@
 		return vzalloc(size);
 }
 
-static void tnode_free_safe(struct tnode *tn)
-{
-	BUG_ON(IS_LEAF(tn));
-	tn->rcu.next = tnode_free_head;
-	tnode_free_head = &tn->rcu;
-}
-
-static void tnode_free_flush(void)
-{
-	struct callback_head *head;
-
-	while ((head = tnode_free_head)) {
-		struct tnode *tn = container_of(head, struct tnode, rcu);
-
-		tnode_free_head = head->next;
-		tnode_free_size += offsetof(struct tnode, child[1 << tn->bits]);
-
-		node_free(tn);
-	}
-
-	if (tnode_free_size >= PAGE_SIZE * sync_pages) {
-		tnode_free_size = 0;
-		synchronize_rcu();
-	}
-}
-
 static struct tnode *leaf_new(t_key key)
 {
 	struct tnode *l = kmem_cache_alloc(trie_leaf_kmem, GFP_KERNEL);
@@ -433,17 +405,33 @@
 		rcu_assign_pointer(t->trie, n);
 }
 
-static void tnode_clean_free(struct tnode *tn)
+static inline void tnode_free_init(struct tnode *tn)
 {
-	struct tnode *tofree;
-	unsigned long i;
+	tn->rcu.next = NULL;
+}
 
-	for (i = 0; i < tnode_child_length(tn); i++) {
-		tofree = tnode_get_child(tn, i);
-		if (tofree)
-			node_free(tofree);
+static inline void tnode_free_append(struct tnode *tn, struct tnode *n)
+{
+	n->rcu.next = tn->rcu.next;
+	tn->rcu.next = &n->rcu;
+}
+
+static void tnode_free(struct tnode *tn)
+{
+	struct callback_head *head = &tn->rcu;
+
+	while (head) {
+		head = head->next;
+		tnode_free_size += offsetof(struct tnode, child[1 << tn->bits]);
+		node_free(tn);
+
+		tn = container_of(head, struct tnode, rcu);
 	}
-	node_free(tn);
+
+	if (tnode_free_size >= PAGE_SIZE * sync_pages) {
+		tnode_free_size = 0;
+		synchronize_rcu();
+	}
 }
 
 static int inflate(struct trie *t, struct tnode *oldtnode)
@@ -476,20 +464,23 @@
 					 inode->bits - 1);
 			if (!left)
 				goto nomem;
+			tnode_free_append(tn, left);
 
 			right = tnode_new(inode->key | m, inode->pos,
 					  inode->bits - 1);
 
-			if (!right) {
-				node_free(left);
+			if (!right)
 				goto nomem;
-			}
+			tnode_free_append(tn, right);
 
 			put_child(tn, 2*i, left);
 			put_child(tn, 2*i+1, right);
 		}
 	}
 
+	/* prepare oldtnode to be freed */
+	tnode_free_init(oldtnode);
+
 	for (i = 0; i < olen; i++) {
 		struct tnode *inode = tnode_get_child(oldtnode, i);
 		struct tnode *left, *right;
@@ -505,12 +496,13 @@
 			continue;
 		}
 
+		/* drop the node in the old tnode free list */
+		tnode_free_append(oldtnode, inode);
+
 		/* An internal node with two children */
 		if (inode->bits == 1) {
 			put_child(tn, 2*i, rtnl_dereference(inode->child[0]));
 			put_child(tn, 2*i+1, rtnl_dereference(inode->child[1]));
-
-			tnode_free_safe(inode);
 			continue;
 		}
 
@@ -556,17 +548,19 @@
 		put_child(tn, 2 * i, left);
 		put_child(tn, 2 * i + 1, right);
 
-		tnode_free_safe(inode);
-
+		/* resize child nodes */
 		resize(t, left);
 		resize(t, right);
 	}
 
 	put_child_root(tp, t, tn->key, tn);
-	tnode_free_safe(oldtnode);
+
+	/* we completed without error, prepare to free old node */
+	tnode_free(oldtnode);
 	return 0;
 nomem:
-	tnode_clean_free(tn);
+	/* all pointers should be clean so we are done */
+	tnode_free(tn);
 	return -ENOMEM;
 }
 
@@ -599,17 +593,20 @@
 			struct tnode *newn;
 
 			newn = tnode_new(left->key, oldtnode->pos, 1);
-
 			if (!newn) {
-				tnode_clean_free(tn);
+				tnode_free(tn);
 				return -ENOMEM;
 			}
+			tnode_free_append(tn, newn);
 
 			put_child(tn, i/2, newn);
 		}
 
 	}
 
+	/* prepare oldtnode to be freed */
+	tnode_free_init(oldtnode);
+
 	for (i = 0; i < olen; i += 2) {
 		struct tnode *newBinNode;
 
@@ -636,11 +633,14 @@
 
 		put_child(tn, i / 2, newBinNode);
 
+		/* resize child node */
 		resize(t, newBinNode);
 	}
 
 	put_child_root(tp, t, tn->key, tn);
-	tnode_free_safe(oldtnode);
+
+	/* all pointers should be clean so we are done */
+	tnode_free(oldtnode);
 
 	return 0;
 }
@@ -798,7 +798,8 @@
 		node_set_parent(n, tp);
 
 		/* drop dead node */
-		tnode_free_safe(tn);
+		tnode_free_init(tn);
+		tnode_free(tn);
 	}
 }
 
@@ -884,16 +885,12 @@
 
 	while ((tp = node_parent(tn)) != NULL) {
 		resize(t, tn);
-
-		tnode_free_flush();
 		tn = tp;
 	}
 
 	/* Handle last (top) tnode */
 	if (IS_TNODE(tn))
 		resize(t, tn);
-
-	tnode_free_flush();
 }
 
 /* only used from updater-side */