afs: Make vnode->cb_interest RCU safe

Use RCU-based freeing for afs_cb_interest struct objects and use RCU on
vnode->cb_interest.  Use that change to allow afs_check_validity() to use
read_seqbegin_or_lock() instead of read_seqlock_excl().

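This also requires the caller of afs_check_validity() to hold the RCU read
lock across the call, because the function now picks up vnode->cb_interest
and cbi->server with rcu_dereference() instead of always holding cb_lock
exclusively.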

Signed-off-by: David Howells <dhowells@redhat.com>
diff --git a/fs/afs/inode.c b/fs/afs/inode.c
index 5c95da7..ba35b48 100644
--- a/fs/afs/inode.c
+++ b/fs/afs/inode.c
@@ -139,9 +139,10 @@ static int afs_inode_init_from_status(struct afs_vnode *vnode, struct key *key,
 		vnode->cb_expires_at = ktime_get_real_seconds();
 	} else {
 		vnode->cb_expires_at = scb->callback.expires_at;
-		old_cbi = vnode->cb_interest;
+		old_cbi = rcu_dereference_protected(vnode->cb_interest,
+						    lockdep_is_held(&vnode->cb_lock.lock));
 		if (cbi != old_cbi)
-			vnode->cb_interest = afs_get_cb_interest(cbi);
+			rcu_assign_pointer(vnode->cb_interest, afs_get_cb_interest(cbi));
 		else
 			old_cbi = NULL;
 		set_bit(AFS_VNODE_CB_PROMISED, &vnode->flags);
@@ -245,9 +246,10 @@ static void afs_apply_callback(struct afs_fs_cursor *fc,
 
 	if (!afs_cb_is_broken(cb_break, vnode, fc->cbi)) {
 		vnode->cb_expires_at	= cb->expires_at;
-		old = vnode->cb_interest;
+		old = rcu_dereference_protected(vnode->cb_interest,
+						lockdep_is_held(&vnode->cb_lock.lock));
 		if (old != fc->cbi) {
-			vnode->cb_interest = afs_get_cb_interest(fc->cbi);
+			rcu_assign_pointer(vnode->cb_interest, afs_get_cb_interest(fc->cbi));
 			afs_put_cb_interest(afs_v2net(vnode), old);
 		}
 		set_bit(AFS_VNODE_CB_PROMISED, &vnode->flags);
@@ -565,36 +567,46 @@ void afs_zap_data(struct afs_vnode *vnode)
  */
 bool afs_check_validity(struct afs_vnode *vnode)
 {
+	struct afs_cb_interest *cbi;
+	struct afs_server *server;
+	struct afs_volume *volume = vnode->volume;
 	time64_t now = ktime_get_real_seconds();
 	bool valid;
+	unsigned int cb_break, cb_s_break, cb_v_break;
+	int seq = 0;	/* even => first pass is lockless */
 
-	/* Quickly check the callback state.  Ideally, we'd use read_seqbegin
-	 * here, but we have no way to pass the net namespace to the RCU
-	 * cleanup for the server record.
-	 */
-	read_seqlock_excl(&vnode->cb_lock);
+	do {	/* RCU read lock is held by the caller */
+		read_seqbegin_or_lock(&vnode->cb_lock, &seq);
+		cb_v_break = READ_ONCE(volume->cb_v_break);
+		cb_break = vnode->cb_break;	/* NOTE(review): not read again below */
 
-	if (test_bit(AFS_VNODE_CB_PROMISED, &vnode->flags)) {
-		if (vnode->cb_s_break != vnode->cb_interest->server->cb_s_break ||
-		    vnode->cb_v_break != vnode->volume->cb_v_break) {
-			vnode->cb_s_break = vnode->cb_interest->server->cb_s_break;
-			vnode->cb_v_break = vnode->volume->cb_v_break;
-			valid = false;
-		} else if (test_bit(AFS_VNODE_ZAP_DATA, &vnode->flags)) {
-			valid = false;
-		} else if (vnode->cb_expires_at - 10 <= now) {
-			valid = false;
-		} else {
+		if (test_bit(AFS_VNODE_CB_PROMISED, &vnode->flags)) {
+			cbi = rcu_dereference(vnode->cb_interest);
+			server = rcu_dereference(cbi->server);	/* NOTE(review): cbi NULL-safe? */
+			cb_s_break = READ_ONCE(server->cb_s_break);
+
+			if (vnode->cb_s_break != cb_s_break ||
+			    vnode->cb_v_break != cb_v_break) {
+				vnode->cb_s_break = cb_s_break;	/* may run on lockless pass */
+				vnode->cb_v_break = cb_v_break;
+				valid = false;
+			} else if (test_bit(AFS_VNODE_ZAP_DATA, &vnode->flags)) {
+				valid = false;
+			} else if (vnode->cb_expires_at - 10 <= now) {
+				valid = false;
+			} else {
+				valid = true;
+			}
+		} else if (test_bit(AFS_VNODE_DELETED, &vnode->flags)) {
 			valid = true;
+		} else {
+			vnode->cb_v_break = cb_v_break;	/* no promise: resync, report invalid */
+			valid = false;
 		}
-	} else if (test_bit(AFS_VNODE_DELETED, &vnode->flags)) {
-		valid = true;
-	} else {
-		vnode->cb_v_break = vnode->volume->cb_v_break;
-		valid = false;
-	}
 
-	read_sequnlock_excl(&vnode->cb_lock);
+	} while (need_seqretry(&vnode->cb_lock, seq));	/* redo if a writer raced us */
+
+	done_seqretry(&vnode->cb_lock, seq);	/* unlocks if last pass took the lock */
 	return valid;
 }
 
@@ -616,7 +628,9 @@ int afs_validate(struct afs_vnode *vnode, struct key *key)
 	       vnode->fid.vid, vnode->fid.vnode, vnode->flags,
 	       key_serial(key));
 
+	rcu_read_lock();
 	valid = afs_check_validity(vnode);
+	rcu_read_unlock();
 
 	if (test_bit(AFS_VNODE_DELETED, &vnode->flags))
 		clear_nlink(&vnode->vfs_inode);
@@ -703,6 +717,7 @@ int afs_drop_inode(struct inode *inode)
  */
 void afs_evict_inode(struct inode *inode)
 {
+	struct afs_cb_interest *cbi;
 	struct afs_vnode *vnode;
 
 	vnode = AFS_FS_I(inode);
@@ -719,10 +734,14 @@ void afs_evict_inode(struct inode *inode)
 	truncate_inode_pages_final(&inode->i_data);
 	clear_inode(inode);
 
-	if (vnode->cb_interest) {
-		afs_put_cb_interest(afs_i2net(inode), vnode->cb_interest);
-		vnode->cb_interest = NULL;
+	write_seqlock(&vnode->cb_lock);
+	cbi = rcu_dereference_protected(vnode->cb_interest,
+					lockdep_is_held(&vnode->cb_lock.lock));
+	if (cbi) {
+		afs_put_cb_interest(afs_i2net(inode), cbi);
+		rcu_assign_pointer(vnode->cb_interest, NULL);
 	}
+	write_sequnlock(&vnode->cb_lock);
 
 	while (!list_empty(&vnode->wb_keys)) {
 		struct afs_wb_key *wbk = list_entry(vnode->wb_keys.next,