afs: Fix server reaping

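Fix server reaping and make sure it's all done before we start trying to
purge cells, given that servers currently pin cells (each server record
holds a reference on its cell until it is destroyed).

Do this by replacing the server reaper's delayed work item with a timer that
queues an ordinary work item, and by counting allocated servers, the pending
timer and queued reaper runs in net->servers_outstanding so that
afs_purge_servers() can wait for the count to reach zero.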

Signed-off-by: David Howells <dhowells@redhat.com>
diff --git a/fs/afs/server.c b/fs/afs/server.c
index e47fd9b..33aeb52 100644
--- a/fs/afs/server.c
+++ b/fs/afs/server.c
@@ -15,6 +15,39 @@
 
 static unsigned afs_server_timeout = 10;	/* server timeout in seconds */
 
+/*
+ * Take a count on net->servers_outstanding.  A count is held by each server
+ * record from allocation until destruction, by the server timer while it is
+ * pending and by the reaper work item while it is queued or running.
+ * afs_purge_servers() waits for the count to reach zero.
+ */
+static void afs_inc_servers_outstanding(struct afs_net *net)
+{
+	atomic_inc(&net->servers_outstanding);
+}
+
+/*
+ * Drop a count on net->servers_outstanding, waking up any waiter in
+ * afs_purge_servers() when the count reaches zero.
+ */
+static void afs_dec_servers_outstanding(struct afs_net *net)
+{
+	if (atomic_dec_and_test(&net->servers_outstanding))
+		wake_up_atomic_t(&net->servers_outstanding);
+}
+
+/*
+ * Server timer expiry: pass the timer's count on to the reaper work item.  If
+ * the work item is already queued it holds a count of its own, so drop ours.
+ */
+void afs_server_timer(struct timer_list *timer)
+{
+	struct afs_net *net = container_of(timer, struct afs_net, server_timer);
+
+	if (!queue_work(afs_wq, &net->server_reaper))
+		afs_dec_servers_outstanding(net);
+}
+
 /*
  * install a server record in the master tree
  */
@@ -81,6 +100,7 @@ static struct afs_server *afs_alloc_server(struct afs_cell *cell,
 
 		memcpy(&server->addr, addr, sizeof(struct in_addr));
 		server->addr.s_addr = addr->s_addr;
+		afs_inc_servers_outstanding(cell->net);
 		_leave(" = %p{%d}", server, atomic_read(&server->usage));
 	} else {
 		_leave(" = NULL [nomem]");
@@ -159,6 +179,7 @@ struct afs_server *afs_lookup_server(struct afs_cell *cell,
 server_in_two_cells:
 	write_unlock(&cell->servers_lock);
 	kfree(candidate);
+	afs_dec_servers_outstanding(cell->net);
 	printk(KERN_NOTICE "kAFS: Server %pI4 appears to be in two cells\n",
 	       addr);
 	_leave(" = -EEXIST");
@@ -208,6 +229,27 @@ struct afs_server *afs_find_server(struct afs_net *net,
 	return server;
 }
 
+/*
+ * Schedule the server reaper.  If net->live is set, arm the server timer to
+ * expire in @delay seconds (timer_reduce() only ever brings the expiry of an
+ * already-pending timer forward); otherwise queue the reaper immediately.
+ *
+ * A count on net->servers_outstanding is taken for the timer or work item and
+ * dropped again if the timer was already pending or the work item was already
+ * queued, since that one already holds a count.
+ */
+static void afs_set_server_timer(struct afs_net *net, time64_t delay)
+{
+	afs_inc_servers_outstanding(net);
+	if (net->live) {
+		if (timer_reduce(&net->server_timer, jiffies + delay * HZ))
+			afs_dec_servers_outstanding(net);
+	} else {
+		if (!queue_work(afs_wq, &net->server_reaper))
+			afs_dec_servers_outstanding(net);
+	}
+}
+
 /*
  * destroy a server record
  * - removes from the cell list
@@ -236,8 +269,7 @@ void afs_put_server(struct afs_server *server)
 	if (atomic_read(&server->usage) == 0) {
 		list_move_tail(&server->grave, &net->server_graveyard);
 		server->time_of_death = ktime_get_real_seconds();
-		queue_delayed_work(afs_wq, &net->server_reaper,
-				   net->live ? afs_server_timeout * HZ : 0);
+		afs_set_server_timer(net, afs_server_timeout);
 	}
 	spin_unlock(&net->server_graveyard_lock);
 	_leave(" [dead]");
@@ -246,7 +278,7 @@ void afs_put_server(struct afs_server *server)
 /*
  * destroy a dead server
  */
-static void afs_destroy_server(struct afs_server *server)
+static void afs_destroy_server(struct afs_net *net, struct afs_server *server)
 {
 	_enter("%p", server);
 
@@ -260,6 +292,7 @@ static void afs_destroy_server(struct afs_server *server)
 
 	afs_put_cell(server->cell);
 	kfree(server);
+	afs_dec_servers_outstanding(net);
 }
 
 /*
@@ -269,7 +302,7 @@ void afs_reap_server(struct work_struct *work)
 {
 	LIST_HEAD(corpses);
 	struct afs_server *server;
-	struct afs_net *net = container_of(work, struct afs_net, server_reaper.work);
+	struct afs_net *net = container_of(work, struct afs_net, server_reaper);
 	unsigned long delay, expiry;
 	time64_t now;
 
@@ -284,8 +317,8 @@ void afs_reap_server(struct work_struct *work)
 		if (net->live) {
 			expiry = server->time_of_death + afs_server_timeout;
 			if (expiry > now) {
-				delay = (expiry - now) * HZ;
-				mod_delayed_work(afs_wq, &net->server_reaper, delay);
+				delay = (expiry - now);
+				afs_set_server_timer(net, delay);
 				break;
 			}
 		}
@@ -309,8 +342,10 @@ void afs_reap_server(struct work_struct *work)
 	while (!list_empty(&corpses)) {
 		server = list_entry(corpses.next, struct afs_server, grave);
 		list_del(&server->grave);
-		afs_destroy_server(server);
+		afs_destroy_server(net, server);
 	}
+
+	afs_dec_servers_outstanding(net);
 }
 
 /*
@@ -319,5 +354,23 @@ void afs_reap_server(struct work_struct *work)
  */
 void __net_exit afs_purge_servers(struct afs_net *net)
 {
-	mod_delayed_work(afs_wq, &net->server_reaper, 0);
+	/*
+	 * A pending server timer holds a count, so drop it if we cancel the
+	 * timer.
+	 *
+	 * NOTE(review): this assumes net->live has already been cleared.  If not,
+	 * afs_put_server() could re-arm the timer after del_timer_sync() and the
+	 * wait below could be held up until that timer fires.
+	 */
+	if (del_timer_sync(&net->server_timer))
+		afs_dec_servers_outstanding(net);
+
+	/* Run the reaper now; it holds a count unless already queued. */
+	afs_inc_servers_outstanding(net);
+	if (!queue_work(afs_wq, &net->server_reaper))
+		afs_dec_servers_outstanding(net);
+
+	/* Wait for all servers, timers and reaper runs to be done with. */
+	wait_on_atomic_t(&net->servers_outstanding, atomic_t_wait,
+			 TASK_UNINTERRUPTIBLE);
 }