cifs: avoid starvation when refreshing dfs cache

When refreshing the DFS cache, keep SMB2 IOCTL calls outside of
critical sections as much as possible, so that fetching new DFS
referrals over broken or slow connections does not starve other
readers and writers of the cache.

Signed-off-by: Paulo Alcantara (SUSE) <pc@cjr.nz>
Reviewed-by: Aurelien Aptel <aaptel@suse.com>
Signed-off-by: Steve French <stfrench@microsoft.com>
diff --git a/fs/cifs/dfs_cache.c b/fs/cifs/dfs_cache.c
index d9c2c75..775dbc7 100644
--- a/fs/cifs/dfs_cache.c
+++ b/fs/cifs/dfs_cache.c
@@ -554,6 +554,8 @@ static void remove_oldest_entry_locked(void)
 	struct cache_entry *ce;
 	struct cache_entry *to_del = NULL;
 
+	WARN_ON(!rwsem_is_locked(&htable_rw_lock));
+
 	for (i = 0; i < CACHE_HTABLE_SIZE; i++) {
 		struct hlist_head *l = &cache_htable[i];
 
@@ -583,7 +585,13 @@ static int add_cache_entry_locked(struct dfs_info3_param *refs, int numrefs)
 	struct cache_entry *ce;
 	unsigned int hash;
 
-	convert_delimiter(refs[0].path_name, '\\');
+	WARN_ON(!rwsem_is_locked(&htable_rw_lock));
+
+	if (atomic_read(&cache_count) >= CACHE_MAX_ENTRIES) {
+		cifs_dbg(FYI, "%s: reached max cache size (%d)\n", __func__, CACHE_MAX_ENTRIES);
+		remove_oldest_entry_locked();
+	}
+
 	rc = cache_entry_hash(refs[0].path_name, strlen(refs[0].path_name), &hash);
 	if (rc)
 		return rc;
@@ -605,6 +613,8 @@ static int add_cache_entry_locked(struct dfs_info3_param *refs, int numrefs)
 	hlist_add_head(&ce->hlist, &cache_htable[hash]);
 	dump_ce(ce);
 
+	atomic_inc(&cache_count);
+
 	return 0;
 }
 
@@ -719,16 +729,13 @@ void dfs_cache_destroy(void)
 }
 
 /* Update a cache entry with the new referral in @refs */
-static int update_cache_entry_locked(const char *path, const struct dfs_info3_param *refs,
+static int update_cache_entry_locked(struct cache_entry *ce, const struct dfs_info3_param *refs,
 				     int numrefs)
 {
 	int rc;
-	struct cache_entry *ce;
 	char *s, *th = NULL;
 
-	ce = lookup_cache_entry(path);
-	if (IS_ERR(ce))
-		return PTR_ERR(ce);
+	WARN_ON(!rwsem_is_locked(&htable_rw_lock));
 
 	if (ce->tgthint) {
 		s = ce->tgthint->name;
@@ -750,18 +757,28 @@ static int update_cache_entry_locked(const char *path, const struct dfs_info3_pa
 static int get_dfs_referral(const unsigned int xid, struct cifs_ses *ses, const char *path,
 			    struct dfs_info3_param **refs, int *numrefs)
 {
+	int rc;
+	int i;
+
 	cifs_dbg(FYI, "%s: get an DFS referral for %s\n", __func__, path);
 
+	*refs = NULL;
+	*numrefs = 0;
+
 	if (!ses || !ses->server || !ses->server->ops->get_dfs_refer)
 		return -EOPNOTSUPP;
 	if (unlikely(!cache_cp))
 		return -EINVAL;
 
-	*refs = NULL;
-	*numrefs = 0;
+	rc =  ses->server->ops->get_dfs_refer(xid, ses, path, refs, numrefs, cache_cp,
+					      NO_MAP_UNI_RSVD);
+	if (!rc) {
+		struct dfs_info3_param *ref = *refs;
 
-	return ses->server->ops->get_dfs_refer(xid, ses, path, refs, numrefs, cache_cp,
-					       NO_MAP_UNI_RSVD);
+		for (i = 0; i < *numrefs; i++)
+			convert_delimiter(ref[i].path_name, '\\');
+	}
+	return rc;
 }
 
 /*
@@ -807,18 +824,11 @@ static int cache_refresh_path(const unsigned int xid, struct cifs_ses *ses, cons
 	dump_refs(refs, numrefs);
 
 	if (!newent) {
-		rc = update_cache_entry_locked(path, refs, numrefs);
+		rc = update_cache_entry_locked(ce, refs, numrefs);
 		goto out_unlock;
 	}
 
-	if (atomic_read(&cache_count) >= CACHE_MAX_ENTRIES) {
-		cifs_dbg(FYI, "%s: reached max cache size (%d)\n", __func__, CACHE_MAX_ENTRIES);
-		remove_oldest_entry_locked();
-	}
-
 	rc = add_cache_entry_locked(refs, numrefs);
-	if (!rc)
-		atomic_inc(&cache_count);
 
 out_unlock:
 	up_write(&htable_rw_lock);
@@ -1313,15 +1323,43 @@ static void refresh_mounts(struct cifs_ses **sessions)
 
 	list_for_each_entry_safe(tcon, ntcon, &tcons, ulist) {
 		const char *path = tcon->dfs_path + 1;
+		struct cache_entry *ce;
+		struct dfs_info3_param *refs = NULL;
+		int numrefs = 0;
+		bool needs_refresh = false;
 		int rc = 0;
 
 		list_del_init(&tcon->ulist);
+
 		ses = find_ipc_from_server_path(sessions, path);
-		if (!IS_ERR(ses)) {
-			xid = get_xid();
-			cache_refresh_path(xid, ses, path);
-			free_xid(xid);
+		if (IS_ERR(ses))
+			goto next_tcon;
+
+		down_read(&htable_rw_lock);
+		ce = lookup_cache_entry(path);
+		needs_refresh = IS_ERR(ce) || cache_entry_expired(ce);
+		up_read(&htable_rw_lock);
+
+		if (!needs_refresh)
+			goto next_tcon;
+
+		xid = get_xid();
+		rc = get_dfs_referral(xid, ses, path, &refs, &numrefs);
+		free_xid(xid);
+
+		/* Create or update a cache entry with the new referral */
+		if (!rc) {
+			down_write(&htable_rw_lock);
+			ce = lookup_cache_entry(path);
+			if (IS_ERR(ce))
+				add_cache_entry_locked(refs, numrefs);
+			else if (cache_entry_expired(ce))
+				update_cache_entry_locked(ce, refs, numrefs);
+			up_write(&htable_rw_lock);
 		}
+
+next_tcon:
+		free_dfs_info_array(refs, numrefs);
 		cifs_put_tcon(tcon);
 	}
 }
@@ -1331,40 +1369,67 @@ static void refresh_cache(struct cifs_ses **sessions)
 	int i;
 	struct cifs_ses *ses;
 	unsigned int xid;
-	int rc;
+	char *ref_paths[CACHE_MAX_ENTRIES];
+	int count = 0;
+	struct cache_entry *ce;
 
 	/*
-	 * Refresh all cached entries.
+	 * Refresh all cached entries.  Get all new referrals outside critical section to avoid
+	 * starvation while performing SMB2 IOCTL on broken or slow connections.
+
 	 * The cache entries may cover more paths than the active mounts
 	 * (e.g. domain-based DFS referrals or multi tier DFS setups).
 	 */
-	down_write(&htable_rw_lock);
+	down_read(&htable_rw_lock);
 	for (i = 0; i < CACHE_HTABLE_SIZE; i++) {
-		struct cache_entry *ce;
 		struct hlist_head *l = &cache_htable[i];
 
 		hlist_for_each_entry(ce, l, hlist) {
-			struct dfs_info3_param *refs = NULL;
-			int numrefs = 0;
-
-			if (hlist_unhashed(&ce->hlist) || !cache_entry_expired(ce))
+			if (count == ARRAY_SIZE(ref_paths))
+				goto out_unlock;
+			if (hlist_unhashed(&ce->hlist) || !cache_entry_expired(ce) ||
+			    IS_ERR(find_ipc_from_server_path(sessions, ce->path)))
 				continue;
-
-			ses = find_ipc_from_server_path(sessions, ce->path);
-			if (IS_ERR(ses))
-				continue;
-
-			xid = get_xid();
-			rc = get_dfs_referral(xid, ses, ce->path, &refs, &numrefs);
-			free_xid(xid);
-
-			if (!rc)
-				update_cache_entry_locked(ce->path, refs, numrefs);
-
-			free_dfs_info_array(refs, numrefs);
+			ref_paths[count++] = kstrdup(ce->path, GFP_ATOMIC);
 		}
 	}
-	up_write(&htable_rw_lock);
+
+out_unlock:
+	up_read(&htable_rw_lock);
+
+	for (i = 0; i < count; i++) {
+		char *path = ref_paths[i];
+		struct dfs_info3_param *refs = NULL;
+		int numrefs = 0;
+		int rc = 0;
+
+		if (!path)
+			continue;
+
+		ses = find_ipc_from_server_path(sessions, path);
+		if (IS_ERR(ses))
+			goto next_referral;
+
+		xid = get_xid();
+		rc = get_dfs_referral(xid, ses, path, &refs, &numrefs);
+		free_xid(xid);
+
+		if (!rc) {
+			down_write(&htable_rw_lock);
+			ce = lookup_cache_entry(path);
+			/*
+			 * We need to re-check it because other tasks might have it deleted or
+			 * updated.
+			 */
+			if (!IS_ERR(ce) && cache_entry_expired(ce))
+				update_cache_entry_locked(ce, refs, numrefs);
+			up_write(&htable_rw_lock);
+		}
+
+next_referral:
+		kfree(path);
+		free_dfs_info_array(refs, numrefs);
+	}
 }
 
 /*