netfilter: nf_conncount: move all list iterations under spinlock

Two CPUs may race to remove a connection from the list, the existing
conn->dead will result in a use-after-free. Use the per-list spinlock to
protect list iterations.

As all accesses to the list now happen while holding the per-list lock,
we no longer need to delay free operations with rcu.

Joint work with Florian.

Fixes: 5c789e131cbb9 ("netfilter: nf_conncount: Add list lock and gc worker, and RCU for init tree search")
Reviewed-by: Shawn Bohrer <sbohrer@cloudflare.com>
Signed-off-by: Florian Westphal <fw@strlen.de>
Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
diff --git a/net/netfilter/nf_conncount.c b/net/netfilter/nf_conncount.c
index ce7f7d1..d0fd195 100644
--- a/net/netfilter/nf_conncount.c
+++ b/net/netfilter/nf_conncount.c
@@ -43,8 +43,6 @@ struct nf_conncount_tuple {
 	struct nf_conntrack_zone	zone;
 	int				cpu;
 	u32				jiffies32;
-	bool				dead;
-	struct rcu_head			rcu_head;
 };
 
 struct nf_conncount_rb {
@@ -83,36 +81,21 @@ static int key_diff(const u32 *a, const u32 *b, unsigned int klen)
 	return memcmp(a, b, klen * sizeof(u32));
 }
 
-static void __conn_free(struct rcu_head *h)
-{
-	struct nf_conncount_tuple *conn;
-
-	conn = container_of(h, struct nf_conncount_tuple, rcu_head);
-	kmem_cache_free(conncount_conn_cachep, conn);
-}
-
 static bool conn_free(struct nf_conncount_list *list,
 		      struct nf_conncount_tuple *conn)
 {
 	bool free_entry = false;
 
-	spin_lock_bh(&list->list_lock);
-
-	if (conn->dead) {
-		spin_unlock_bh(&list->list_lock);
-		return free_entry;
-	}
+	lockdep_assert_held(&list->list_lock);
 
 	list->count--;
-	conn->dead = true;
-	list_del_rcu(&conn->node);
+	list_del(&conn->node);
 	if (list->count == 0) {
 		list->dead = true;
 		free_entry = true;
 	}
 
-	spin_unlock_bh(&list->list_lock);
-	call_rcu(&conn->rcu_head, __conn_free);
+	kmem_cache_free(conncount_conn_cachep, conn);
 	return free_entry;
 }
 
@@ -242,7 +225,7 @@ void nf_conncount_list_init(struct nf_conncount_list *list)
 }
 EXPORT_SYMBOL_GPL(nf_conncount_list_init);
 
-/* Return true if the list is empty */
+/* Return true if the list is empty. Must be called with BH disabled. */
 bool nf_conncount_gc_list(struct net *net,
 			  struct nf_conncount_list *list)
 {
@@ -253,12 +236,18 @@ bool nf_conncount_gc_list(struct net *net,
 	bool free_entry = false;
 	bool ret = false;
 
+	/* don't bother if other cpu is already doing GC */
+	if (!spin_trylock(&list->list_lock))
+		return false;
+
 	list_for_each_entry_safe(conn, conn_n, &list->head, node) {
 		found = find_or_evict(net, list, conn, &free_entry);
 		if (IS_ERR(found)) {
 			if (PTR_ERR(found) == -ENOENT)  {
-				if (free_entry)
+				if (free_entry) {
+					spin_unlock(&list->list_lock);
 					return true;
+				}
 				collected++;
 			}
 			continue;
@@ -271,23 +260,24 @@ bool nf_conncount_gc_list(struct net *net,
 			 * closed already -> ditch it
 			 */
 			nf_ct_put(found_ct);
-			if (conn_free(list, conn))
+			if (conn_free(list, conn)) {
+				spin_unlock(&list->list_lock);
 				return true;
+			}
 			collected++;
 			continue;
 		}
 
 		nf_ct_put(found_ct);
 		if (collected > CONNCOUNT_GC_MAX_NODES)
-			return false;
+			break;
 	}
 
-	spin_lock_bh(&list->list_lock);
 	if (!list->count) {
 		list->dead = true;
 		ret = true;
 	}
-	spin_unlock_bh(&list->list_lock);
+	spin_unlock(&list->list_lock);
 
 	return ret;
 }
@@ -478,6 +468,7 @@ static void tree_gc_worker(struct work_struct *work)
 	tree = data->gc_tree % CONNCOUNT_SLOTS;
 	root = &data->root[tree];
 
+	local_bh_disable();
 	rcu_read_lock();
 	for (node = rb_first(root); node != NULL; node = rb_next(node)) {
 		rbconn = rb_entry(node, struct nf_conncount_rb, node);
@@ -485,6 +476,9 @@ static void tree_gc_worker(struct work_struct *work)
 			gc_count++;
 	}
 	rcu_read_unlock();
+	local_bh_enable();
+
+	cond_resched();
 
 	spin_lock_bh(&nf_conncount_locks[tree]);
 	if (gc_count < ARRAY_SIZE(gc_nodes))