cifs: keep referral server sessions alive

At every mount, keep alive all sessions that were used for chasing the
DFS referrals for as long as the DFS mounts are active.

Use those sessions in the DFS cache to refresh all active tcons as well
as cached entries.  They are managed by a list of mount_group
structures, each indexed by a UUID randomly generated at mount time, so
that all sessions related to a specific DFS mount can be put together
and none of them are leaked.

Signed-off-by: Paulo Alcantara (SUSE) <pc@cjr.nz>
Reviewed-by: Aurelien Aptel <aaptel@suse.com>
Signed-off-by: Steve French <stfrench@microsoft.com>
diff --git a/fs/cifs/dfs_cache.c b/fs/cifs/dfs_cache.c
index 70383e1..846d670 100644
--- a/fs/cifs/dfs_cache.c
+++ b/fs/cifs/dfs_cache.c
@@ -11,6 +11,7 @@
 #include <linux/proc_fs.h>
 #include <linux/nls.h>
 #include <linux/workqueue.h>
+#include <linux/uuid.h>
 #include "cifsglob.h"
 #include "smb2pdu.h"
 #include "smb2proto.h"
@@ -18,7 +19,6 @@
 #include "cifs_debug.h"
 #include "cifs_unicode.h"
 #include "smb2glob.h"
-#include "fs_context.h"
 
 #include "dfs_cache.h"
 
@@ -48,14 +48,15 @@ struct cache_entry {
 	struct cache_dfs_tgt *tgthint;
 };
 
-struct vol_info {
-	char *fullpath;
-	spinlock_t ctx_lock;
-	struct smb3_fs_context ctx;
-	char *mntdata;
+/* Referral server sessions kept alive for the DFS mount group identified by @id */
+struct mount_group {
 	struct list_head list;
-	struct list_head rlist;
-	struct kref refcnt;
+	uuid_t id;					/* mount group uuid */
+	struct cifs_ses *sessions[CACHE_MAX_ENTRIES];	/* referenced; put on release */
+	int num_sessions;				/* used slots in @sessions */
+	spinlock_t lock;				/* protects @sessions, @num_sessions */
+	struct list_head refresh_list;			/* refresh_cache_worker() local list */
+	struct kref refcount;				/* mount_group_release() on last put */
 };
 
 static struct kmem_cache *cache_slab __read_mostly;
@@ -74,13 +75,124 @@ static atomic_t cache_count;
 static struct hlist_head cache_htable[CACHE_HTABLE_SIZE];
 static DECLARE_RWSEM(htable_rw_lock);
 
-static LIST_HEAD(vol_list);
-static DEFINE_SPINLOCK(vol_list_lock);
+static LIST_HEAD(mount_group_list);
+static DEFINE_MUTEX(mount_group_list_lock);
 
 static void refresh_cache_worker(struct work_struct *work);
 
 static DECLARE_DELAYED_WORK(refresh_task, refresh_cache_worker);
 
+/* Write "\\<host of @ref_path>\IPC$" into @ipc, bounded to @ipclen bytes */
+static void get_ipc_unc(const char *ref_path, char *ipc, size_t ipclen)
+{
+	const char *host;
+	size_t len;
+
+	extract_unc_hostname(ref_path, &host, &len);
+	scnprintf(ipc, ipclen, "\\\\%.*s\\IPC$", (int)len, host);
+}
+
+/*
+ * Return the session from the NULL-terminated @ses array whose IPC tcon tree name is
+ * "\\<host of @path>\IPC$" (case-insensitive), or ERR_PTR(-ENOENT) if there is none.
+ */
+static struct cifs_ses *find_ipc_from_server_path(struct cifs_ses **ses, const char *path)
+{
+	char unc[SERVER_NAME_LENGTH + sizeof("//x/IPC$")] = {0};
+
+	get_ipc_unc(path, unc, sizeof(unc));
+	for (; *ses; ses++) {
+		/* Skip sessions without an IPC tcon rather than dereferencing NULL */
+		if ((*ses)->tcon_ipc && !strcasecmp(unc, (*ses)->tcon_ipc->treeName))
+			return *ses;
+	}
+	return ERR_PTR(-ENOENT);
+}
+
+/* Put all referral server sessions held by @mg and free it; callers unlink @mg first */
+static void __mount_group_release(struct mount_group *mg)
+{
+	int i;
+
+	for (i = 0; i < mg->num_sessions; i++)
+		cifs_put_smb_ses(mg->sessions[i]);
+	kfree(mg);
+}
+
+/*
+ * kref release callback: unlink @mg from mount_group_list and free it.
+ *
+ * NOTE(review): the refcount reaches zero before mount_group_list_lock is taken here, so a
+ * lookup holding that lock (get_mount_group(), refresh_cache_worker()) may still find @mg
+ * and kref_get() it in between; consider kref_put_mutex() - confirm.
+ */
+static void mount_group_release(struct kref *kref)
+{
+	struct mount_group *mg = container_of(kref, struct mount_group, refcount);
+
+	mutex_lock(&mount_group_list_lock);
+	list_del(&mg->list);
+	mutex_unlock(&mount_group_list_lock);
+	__mount_group_release(mg);
+}
+
+/* Look up the mount group with @id; mount_group_list_lock must be held */
+static struct mount_group *find_mount_group_locked(const uuid_t *id)
+{
+	struct mount_group *mg;
+
+	list_for_each_entry(mg, &mount_group_list, list) {
+		if (uuid_equal(&mg->id, id))
+			return mg;
+	}
+	return ERR_PTR(-ENOENT);
+}
+
+/* Find the group with @id, or allocate and link a new one (refcount 1); list lock held */
+static struct mount_group *__get_mount_group_locked(const uuid_t *id)
+{
+	struct mount_group *mg;
+
+	mg = find_mount_group_locked(id);
+	if (!IS_ERR(mg))
+		return mg;
+
+	mg = kmalloc(sizeof(*mg), GFP_KERNEL);
+	if (!mg)
+		return ERR_PTR(-ENOMEM);
+	kref_init(&mg->refcount);
+	uuid_copy(&mg->id, id);
+	mg->num_sessions = 0;
+	spin_lock_init(&mg->lock);
+	list_add(&mg->list, &mount_group_list);
+	return mg;
+}
+
+/* Find or create the mount group with @id and return it with an extra reference taken */
+static struct mount_group *get_mount_group(const uuid_t *id)
+{
+	struct mount_group *mg;
+
+	mutex_lock(&mount_group_list_lock);
+	mg = __get_mount_group_locked(id);
+	if (!IS_ERR(mg))
+		kref_get(&mg->refcount);
+	mutex_unlock(&mount_group_list_lock);
+
+	return mg;
+}
+
+/* Unlink and free every mount group regardless of its refcount (cache teardown) */
+static void free_mount_group_list(void)
+{
+	struct mount_group *mg, *tmp_mg;
+
+	list_for_each_entry_safe(mg, tmp_mg, &mount_group_list, list) {
+		list_del_init(&mg->list);
+		__mount_group_release(mg);
+	}
+}
+
 static int get_normalized_path(const char *path, const char **npath)
 {
 	if (!path || strlen(path) < 3 || (*path != '\\' && *path != '/'))
@@ -284,8 +378,7 @@ int dfs_cache_init(void)
 	int rc;
 	int i;
 
-	dfscache_wq = alloc_workqueue("cifs-dfscache",
-				      WQ_FREEZABLE | WQ_MEM_RECLAIM, 1);
+	dfscache_wq = alloc_workqueue("cifs-dfscache", WQ_FREEZABLE | WQ_UNBOUND, 1);
 	if (!dfscache_wq)
 		return -ENOMEM;
 
@@ -426,8 +519,7 @@ static struct cache_entry *alloc_cache_entry(const char *path,
 	return ce;
 }
 
-/* Must be called with htable_rw_lock held */
-static void remove_oldest_entry(void)
+static void remove_oldest_entry_locked(void)
 {
 	int i;
 	struct cache_entry *ce;
@@ -456,8 +548,8 @@ static void remove_oldest_entry(void)
 }
 
 /* Add a new DFS cache entry */
-static int add_cache_entry(const char *path, unsigned int hash,
-			   struct dfs_info3_param *refs, int numrefs)
+static int add_cache_entry_locked(const char *path, unsigned int hash,
+				  struct dfs_info3_param *refs, int numrefs)
 {
 	struct cache_entry *ce;
 
@@ -475,10 +567,8 @@ static int add_cache_entry(const char *path, unsigned int hash,
 	}
 	spin_unlock(&cache_ttl_lock);
 
-	down_write(&htable_rw_lock);
 	hlist_add_head(&ce->hlist, &cache_htable[hash]);
 	dump_ce(ce);
-	up_write(&htable_rw_lock);
 
 	return 0;
 }
@@ -573,34 +663,6 @@ static struct cache_entry *lookup_cache_entry(const char *path, unsigned int *ha
 	return ce;
 }
 
-static void __vol_release(struct vol_info *vi)
-{
-	kfree(vi->fullpath);
-	kfree(vi->mntdata);
-	smb3_cleanup_fs_context_contents(&vi->ctx);
-	kfree(vi);
-}
-
-static void vol_release(struct kref *kref)
-{
-	struct vol_info *vi = container_of(kref, struct vol_info, refcnt);
-
-	spin_lock(&vol_list_lock);
-	list_del(&vi->list);
-	spin_unlock(&vol_list_lock);
-	__vol_release(vi);
-}
-
-static inline void free_vol_list(void)
-{
-	struct vol_info *vi, *nvi;
-
-	list_for_each_entry_safe(vi, nvi, &vol_list, list) {
-		list_del_init(&vi->list);
-		__vol_release(vi);
-	}
-}
-
 /**
  * dfs_cache_destroy - destroy DFS referral cache
  */
@@ -608,7 +670,7 @@ void dfs_cache_destroy(void)
 {
 	cancel_delayed_work_sync(&refresh_task);
 	unload_nls(cache_nlsc);
-	free_vol_list();
+	free_mount_group_list();
 	flush_cache_ents();
 	kmem_cache_destroy(cache_slab);
 	destroy_workqueue(dfscache_wq);
@@ -616,10 +678,9 @@ void dfs_cache_destroy(void)
 	cifs_dbg(FYI, "%s: destroyed DFS referral cache\n", __func__);
 }
 
-/* Must be called with htable_rw_lock held */
-static int __update_cache_entry(const char *path,
-				const struct dfs_info3_param *refs,
-				int numrefs)
+/* Update the cache entry of @path with @refs; htable_rw_lock must be held for writing */
+static int update_cache_entry_locked(const char *path, const struct dfs_info3_param *refs,
+				     int numrefs)
 {
 	int rc;
 	struct cache_entry *ce;
@@ -665,32 +726,17 @@ static int get_dfs_referral(const unsigned int xid, struct cifs_ses *ses,
 					       nls_codepage, remap);
 }
 
-/* Update an expired cache entry by getting a new DFS referral from server */
-static int update_cache_entry(const char *path,
-			      const struct dfs_info3_param *refs,
-			      int numrefs)
-{
-
-	int rc;
-
-	down_write(&htable_rw_lock);
-	rc = __update_cache_entry(path, refs, numrefs);
-	up_write(&htable_rw_lock);
-
-	return rc;
-}
-
 /*
  * Find, create or update a DFS cache entry.
  *
  * If the entry wasn't found, it will create a new one. Or if it was found but
  * expired, then it will update the entry accordingly.
  *
- * For interlinks, __cifs_dfs_mount() and expand_dfs_referral() are supposed to
+ * For interlinks, cifs_mount() and expand_dfs_referral() are supposed to
  * handle them properly.
  */
-static int __dfs_cache_find(const unsigned int xid, struct cifs_ses *ses,
-			    const struct nls_table *nls_codepage, int remap, const char *path)
+static int cache_refresh_path(const unsigned int xid, struct cifs_ses *ses,
+			      const struct nls_table *nls_codepage, int remap, const char *path)
 {
 	int rc;
 	unsigned int hash;
@@ -701,52 +747,46 @@ static int __dfs_cache_find(const unsigned int xid, struct cifs_ses *ses,
 
 	cifs_dbg(FYI, "%s: search path: %s\n", __func__, path);
 
-	down_read(&htable_rw_lock);
+	down_write(&htable_rw_lock);
 
 	ce = lookup_cache_entry(path, &hash);
 	if (!IS_ERR(ce)) {
 		if (!cache_entry_expired(ce)) {
 			dump_ce(ce);
-			up_read(&htable_rw_lock);
+			up_write(&htable_rw_lock);
 			return 0;
 		}
 	} else {
 		newent = true;
 	}
 
-	up_read(&htable_rw_lock);
-
 	/*
-	 * No entry was found.
-	 *
-	 * Request a new DFS referral in order to create a new cache entry, or
-	 * updating an existing one.
+	 * Either the entry was not found, or it is expired.  Request a new DFS referral to
+	 * create or update it; htable_rw_lock stays held for writing across the request.
 	 */
 	rc = get_dfs_referral(xid, ses, nls_codepage, remap, path,
 			      &refs, &numrefs);
 	if (rc)
-		return rc;
+		goto out_unlock;
 
 	dump_refs(refs, numrefs);
 
 	if (!newent) {
-		rc = update_cache_entry(path, refs, numrefs);
-		goto out_free_refs;
+		rc = update_cache_entry_locked(path, refs, numrefs);
+		goto out_unlock;
 	}
 
 	if (atomic_read(&cache_count) >= CACHE_MAX_ENTRIES) {
-		cifs_dbg(FYI, "%s: reached max cache size (%d)\n",
-			 __func__, CACHE_MAX_ENTRIES);
-		down_write(&htable_rw_lock);
-		remove_oldest_entry();
-		up_write(&htable_rw_lock);
+		cifs_dbg(FYI, "%s: reached max cache size (%d)\n", __func__, CACHE_MAX_ENTRIES);
+		remove_oldest_entry_locked();
 	}
 
-	rc = add_cache_entry(path, hash, refs, numrefs);
+	rc = add_cache_entry_locked(path, hash, refs, numrefs);
 	if (!rc)
 		atomic_inc(&cache_count);
 
-out_free_refs:
+out_unlock:
+	up_write(&htable_rw_lock);
 	free_dfs_info_array(refs, numrefs);
 	return rc;
 }
@@ -868,7 +908,7 @@ int dfs_cache_find(const unsigned int xid, struct cifs_ses *ses,
 	if (rc)
 		return rc;
 
-	rc = __dfs_cache_find(xid, ses, nls_codepage, remap, npath);
+	rc = cache_refresh_path(xid, ses, nls_codepage, remap, npath);
 	if (rc)
 		goto out_free_path;
 
@@ -980,7 +1020,7 @@ int dfs_cache_update_tgthint(const unsigned int xid, struct cifs_ses *ses,
 
 	cifs_dbg(FYI, "%s: update target hint - path: %s\n", __func__, npath);
 
-	rc = __dfs_cache_find(xid, ses, nls_codepage, remap, npath);
+	rc = cache_refresh_path(xid, ses, nls_codepage, remap, npath);
 	if (rc)
 		goto out_free_path;
 
@@ -1122,126 +1162,64 @@ int dfs_cache_get_tgt_referral(const char *path,
 }
 
 /**
- * dfs_cache_add_vol - add a cifs context during mount() that will be handled by
- * DFS cache refresh worker.
+ * dfs_cache_add_refsrv_session - add SMB session of referral server
  *
- * @mntdata: mount data.
- * @ctx: cifs context.
- * @fullpath: origin full path.
- *
- * Return zero if context was set up correctly, otherwise non-zero.
+ * @mount_id: mount group uuid to lookup.
+ * @ses: reference counted SMB session of referral server.
+ *
+ * The mount group takes over the caller's reference to @ses and puts it when the group
+ * is released.  If @ses cannot be stored (the group could not be obtained or is already
+ * full), that reference is put here instead so that it does not leak.
  */
-int dfs_cache_add_vol(char *mntdata, struct smb3_fs_context *ctx, const char *fullpath)
+void dfs_cache_add_refsrv_session(const uuid_t *mount_id, struct cifs_ses *ses)
 {
-	int rc;
-	struct vol_info *vi;
+	struct mount_group *mg;
+	bool added = false;
 
-	if (!ctx || !fullpath || !mntdata)
-		return -EINVAL;
-
-	cifs_dbg(FYI, "%s: fullpath: %s\n", __func__, fullpath);
-
-	vi = kzalloc(sizeof(*vi), GFP_KERNEL);
-	if (!vi)
-		return -ENOMEM;
-
-	vi->fullpath = kstrdup(fullpath, GFP_KERNEL);
-	if (!vi->fullpath) {
-		rc = -ENOMEM;
-		goto err_free_vi;
-	}
-
-	rc = smb3_fs_context_dup(&vi->ctx, ctx);
-	if (rc)
-		goto err_free_fullpath;
-
-	vi->mntdata = mntdata;
-	spin_lock_init(&vi->ctx_lock);
-	kref_init(&vi->refcnt);
-
-	spin_lock(&vol_list_lock);
-	list_add_tail(&vi->list, &vol_list);
-	spin_unlock(&vol_list_lock);
-
-	return 0;
-
-err_free_fullpath:
-	kfree(vi->fullpath);
-err_free_vi:
-	kfree(vi);
-	return rc;
-}
-
-/* Must be called with vol_list_lock held */
-static struct vol_info *find_vol(const char *fullpath)
-{
-	struct vol_info *vi;
-
-	list_for_each_entry(vi, &vol_list, list) {
-		cifs_dbg(FYI, "%s: vi->fullpath: %s\n", __func__, vi->fullpath);
-		if (!strcasecmp(vi->fullpath, fullpath))
-			return vi;
-	}
-	return ERR_PTR(-ENOENT);
-}
-
-/**
- * dfs_cache_update_vol - update vol info in DFS cache after failover
- *
- * @fullpath: fullpath to look up in volume list.
- * @server: TCP ses pointer.
- *
- * Return zero if volume was updated, otherwise non-zero.
- */
-int dfs_cache_update_vol(const char *fullpath, struct TCP_Server_Info *server)
-{
-	struct vol_info *vi;
-
-	if (!fullpath || !server)
-		return -EINVAL;
-
-	cifs_dbg(FYI, "%s: fullpath: %s\n", __func__, fullpath);
-
-	spin_lock(&vol_list_lock);
-	vi = find_vol(fullpath);
-	if (IS_ERR(vi)) {
-		spin_unlock(&vol_list_lock);
-		return PTR_ERR(vi);
-	}
-	kref_get(&vi->refcnt);
-	spin_unlock(&vol_list_lock);
-
-	cifs_dbg(FYI, "%s: updating volume info\n", __func__);
-	spin_lock(&vi->ctx_lock);
-	memcpy(&vi->ctx.dstaddr, &server->dstaddr,
-	       sizeof(vi->ctx.dstaddr));
-	spin_unlock(&vi->ctx_lock);
-
-	kref_put(&vi->refcnt, vol_release);
-
-	return 0;
-}
-
-/**
- * dfs_cache_del_vol - remove volume info in DFS cache during umount()
- *
- * @fullpath: fullpath to look up in volume list.
- */
-void dfs_cache_del_vol(const char *fullpath)
-{
-	struct vol_info *vi;
-
-	if (!fullpath || !*fullpath)
+	if (WARN_ON_ONCE(!mount_id || uuid_is_null(mount_id) || !ses))
 		return;
 
-	cifs_dbg(FYI, "%s: fullpath: %s\n", __func__, fullpath);
+	mg = get_mount_group(mount_id);
+	if (WARN_ON_ONCE(IS_ERR(mg))) {
+		cifs_put_smb_ses(ses);
+		return;
+	}
 
-	spin_lock(&vol_list_lock);
-	vi = find_vol(fullpath);
-	spin_unlock(&vol_list_lock);
+	spin_lock(&mg->lock);
+	if (mg->num_sessions < ARRAY_SIZE(mg->sessions)) {
+		mg->sessions[mg->num_sessions++] = ses;
+		added = true;
+	}
+	spin_unlock(&mg->lock);
+	kref_put(&mg->refcount, mount_group_release);
+
+	/* Not stored in the group: drop the reference the caller handed over */
+	if (!added)
+		cifs_put_smb_ses(ses);
+}
 
-	if (!IS_ERR(vi))
-		kref_put(&vi->refcnt, vol_release);
+/**
+ * dfs_cache_put_refsrv_sessions - put all referral server sessions
+ *
+ * Put all SMB sessions from the given mount group id.
+ *
+ * @mount_id: mount group uuid to lookup.
+ */
+void dfs_cache_put_refsrv_sessions(const uuid_t *mount_id)
+{
+	struct mount_group *mg;
+
+	if (!mount_id || uuid_is_null(mount_id))
+		return;
+
+	mutex_lock(&mount_group_list_lock);
+	mg = find_mount_group_locked(mount_id);
+	if (IS_ERR(mg)) {
+		mutex_unlock(&mount_group_list_lock);
+		return;
+	}
+	mutex_unlock(&mount_group_list_lock);
+	kref_put(&mg->refcount, mount_group_release);
 }
 
 /**
@@ -1310,278 +1275,136 @@ int dfs_cache_get_tgt_share(char *path, const struct dfs_cache_tgt_iterator *it,
 	return 0;
 }
 
-/* Get all tcons that are within a DFS namespace and can be refreshed */
-static void get_tcons(struct TCP_Server_Info *server, struct list_head *head)
+/*
+ * Refresh all active dfs mounts regardless of whether they are in cache or not.
+ * (cache can be cleared)
+ */
+static void refresh_mounts(struct cifs_ses **sessions)
 {
+	struct TCP_Server_Info *server;
 	struct cifs_ses *ses;
-	struct cifs_tcon *tcon;
+	struct cifs_tcon *tcon, *ntcon;
+	struct list_head tcons;
+	unsigned int xid;
 
-	INIT_LIST_HEAD(head);
+	INIT_LIST_HEAD(&tcons);
 
 	spin_lock(&cifs_tcp_ses_lock);
-	list_for_each_entry(ses, &server->smb_ses_list, smb_ses_list) {
-		list_for_each_entry(tcon, &ses->tcon_list, tcon_list) {
-			if (!tcon->need_reconnect && !tcon->need_reopen_files &&
-			    tcon->dfs_path) {
-				tcon->tc_count++;
-				list_add_tail(&tcon->ulist, head);
+	list_for_each_entry(server, &cifs_tcp_ses_list, tcp_ses_list) {
+		list_for_each_entry(ses, &server->smb_ses_list, smb_ses_list) {
+			list_for_each_entry(tcon, &ses->tcon_list, tcon_list) {
+				if (tcon->dfs_path) {
+					tcon->tc_count++;
+					list_add_tail(&tcon->ulist, &tcons);
+				}
 			}
 		}
-		if (ses->tcon_ipc && !ses->tcon_ipc->need_reconnect &&
-		    ses->tcon_ipc->dfs_path) {
-			list_add_tail(&ses->tcon_ipc->ulist, head);
-		}
 	}
 	spin_unlock(&cifs_tcp_ses_lock);
-}
 
-static bool is_dfs_link(const char *path)
-{
-	char *s;
+	list_for_each_entry_safe(tcon, ntcon, &tcons, ulist) {
+		const char *path = tcon->dfs_path + 1;
+		int rc = 0;
 
-	s = strchr(path + 1, '\\');
-	if (!s)
-		return false;
-	return !!strchr(s + 1, '\\');
-}
-
-static char *get_dfs_root(const char *path)
-{
-	char *s, *npath;
-
-	s = strchr(path + 1, '\\');
-	if (!s)
-		return ERR_PTR(-EINVAL);
-
-	s = strchr(s + 1, '\\');
-	if (!s)
-		return ERR_PTR(-EINVAL);
-
-	npath = kstrndup(path, s - path, GFP_KERNEL);
-	if (!npath)
-		return ERR_PTR(-ENOMEM);
-
-	return npath;
-}
-
-static inline void put_tcp_server(struct TCP_Server_Info *server)
-{
-	cifs_put_tcp_session(server, 0);
-}
-
-static struct TCP_Server_Info *get_tcp_server(struct smb3_fs_context *ctx)
-{
-	struct TCP_Server_Info *server;
-
-	server = cifs_find_tcp_session(ctx);
-	if (IS_ERR_OR_NULL(server))
-		return NULL;
-
-	spin_lock(&GlobalMid_Lock);
-	if (server->tcpStatus != CifsGood) {
-		spin_unlock(&GlobalMid_Lock);
-		put_tcp_server(server);
-		return NULL;
-	}
-	spin_unlock(&GlobalMid_Lock);
-
-	return server;
-}
-
-/* Find root SMB session out of a DFS link path */
-static struct cifs_ses *find_root_ses(struct vol_info *vi,
-				      struct cifs_tcon *tcon,
-				      const char *path)
-{
-	char *rpath;
-	int rc;
-	struct cache_entry *ce;
-	struct dfs_info3_param ref = {0};
-	char *mdata = NULL, *devname = NULL;
-	struct TCP_Server_Info *server;
-	struct cifs_ses *ses;
-	struct smb3_fs_context ctx = {NULL};
-
-	rpath = get_dfs_root(path);
-	if (IS_ERR(rpath))
-		return ERR_CAST(rpath);
-
-	down_read(&htable_rw_lock);
-
-	ce = lookup_cache_entry(rpath, NULL);
-	if (IS_ERR(ce)) {
-		up_read(&htable_rw_lock);
-		ses = ERR_CAST(ce);
-		goto out;
-	}
-
-	rc = setup_referral(path, ce, &ref, get_tgt_name(ce));
-	if (rc) {
-		up_read(&htable_rw_lock);
-		ses = ERR_PTR(rc);
-		goto out;
-	}
-
-	up_read(&htable_rw_lock);
-
-	mdata = cifs_compose_mount_options(vi->mntdata, rpath, &ref,
-					   &devname);
-	free_dfs_info_param(&ref);
-
-	if (IS_ERR(mdata)) {
-		ses = ERR_CAST(mdata);
-		mdata = NULL;
-		goto out;
-	}
-
-	rc = cifs_setup_volume_info(&ctx, NULL, devname);
-
-	if (rc) {
-		ses = ERR_PTR(rc);
-		goto out;
-	}
-
-	server = get_tcp_server(&ctx);
-	if (!server) {
-		ses = ERR_PTR(-EHOSTDOWN);
-		goto out;
-	}
-
-	ses = cifs_get_smb_ses(server, &ctx);
-
-out:
-	smb3_cleanup_fs_context_contents(&ctx);
-	kfree(mdata);
-	kfree(rpath);
-	kfree(devname);
-
-	return ses;
-}
-
-/* Refresh DFS cache entry from a given tcon */
-static int refresh_tcon(struct vol_info *vi, struct cifs_tcon *tcon)
-{
-	int rc = 0;
-	unsigned int xid;
-	const char *path, *npath;
-	struct cache_entry *ce;
-	struct cifs_ses *root_ses = NULL, *ses;
-	struct dfs_info3_param *refs = NULL;
-	int numrefs = 0;
-
-	xid = get_xid();
-
-	path = tcon->dfs_path + 1;
-
-	rc = get_normalized_path(path, &npath);
-	if (rc)
-		goto out_free_xid;
-
-	down_read(&htable_rw_lock);
-
-	ce = lookup_cache_entry(npath, NULL);
-	if (IS_ERR(ce)) {
-		rc = PTR_ERR(ce);
-		up_read(&htable_rw_lock);
-		goto out_free_path;
-	}
-
-	if (!cache_entry_expired(ce)) {
-		up_read(&htable_rw_lock);
-		goto out_free_path;
-	}
-
-	up_read(&htable_rw_lock);
-
-	/* If it's a DFS Link, then use root SMB session for refreshing it */
-	if (is_dfs_link(npath)) {
-		ses = root_ses = find_root_ses(vi, tcon, npath);
-		if (IS_ERR(ses)) {
-			rc = PTR_ERR(ses);
-			root_ses = NULL;
-			goto out_free_path;
+		list_del_init(&tcon->ulist);
+		ses = find_ipc_from_server_path(sessions, path);
+		if (!IS_ERR(ses)) {
+			xid = get_xid();
+			cache_refresh_path(xid, ses, cache_nlsc, tcon->remap, path);
+			free_xid(xid);
 		}
-	} else {
-		ses = tcon->ses;
+		cifs_put_tcon(tcon);
 	}
+}
 
-	rc = get_dfs_referral(xid, ses, cache_nlsc, tcon->remap, npath, &refs,
-			      &numrefs);
-	if (!rc) {
-		dump_refs(refs, numrefs);
-		rc = update_cache_entry(npath, refs, numrefs);
-		free_dfs_info_array(refs, numrefs);
+static void refresh_cache(struct cifs_ses **sessions)
+{
+	int i;
+	struct cifs_ses *ses;
+	unsigned int xid;
+	int rc;
+
+	/*
+	 * Refresh all cached entries.
+	 * The cache entries may cover more paths than the active mounts
+	 * (e.g. domain-based DFS referrals or multi tier DFS setups).
+	 */
+	down_write(&htable_rw_lock);
+	for (i = 0; i < CACHE_HTABLE_SIZE; i++) {
+		struct cache_entry *ce;
+		struct hlist_head *l = &cache_htable[i];
+
+		hlist_for_each_entry(ce, l, hlist) {
+			struct dfs_info3_param *refs = NULL;
+			int numrefs = 0;
+
+			if (hlist_unhashed(&ce->hlist) || !cache_entry_expired(ce))
+				continue;
+
+			ses = find_ipc_from_server_path(sessions, ce->path);
+			if (IS_ERR(ses))
+				continue;
+
+			xid = get_xid();
+			rc = get_dfs_referral(xid, ses, cache_nlsc, NO_MAP_UNI_RSVD, ce->path,
+					      &refs, &numrefs);
+			free_xid(xid);
+
+			if (!rc)
+				update_cache_entry_locked(ce->path, refs, numrefs);
+
+			free_dfs_info_array(refs, numrefs);
+		}
 	}
-
-	if (root_ses)
-		cifs_put_smb_ses(root_ses);
-
-out_free_path:
-	free_normalized_path(path, npath);
-
-out_free_xid:
-	free_xid(xid);
-	return rc;
+	up_write(&htable_rw_lock);
 }
 
 /*
- * Worker that will refresh DFS cache based on lowest TTL value from a DFS
+ * Worker that will refresh DFS cache and active mounts based on lowest TTL value from a DFS
  * referral.
  */
 static void refresh_cache_worker(struct work_struct *work)
 {
-	struct vol_info *vi, *nvi;
-	struct TCP_Server_Info *server;
-	LIST_HEAD(vols);
-	LIST_HEAD(tcons);
-	struct cifs_tcon *tcon, *ntcon;
-	int rc;
+	struct list_head mglist;
+	struct mount_group *mg, *tmp_mg;
+	struct cifs_ses *sessions[CACHE_MAX_ENTRIES + 1] = {NULL};
+	int max_sessions = ARRAY_SIZE(sessions) - 1;
+	int i = 0, count;
 
-	/*
-	 * Find SMB volumes that are eligible (server->tcpStatus == CifsGood)
-	 * for refreshing.
-	 */
-	spin_lock(&vol_list_lock);
-	list_for_each_entry(vi, &vol_list, list) {
-		server = get_tcp_server(&vi->ctx);
-		if (!server)
-			continue;
+	INIT_LIST_HEAD(&mglist);
 
-		kref_get(&vi->refcnt);
-		list_add_tail(&vi->rlist, &vols);
-		put_tcp_server(server);
+	/* Get refereces of mount groups */
+	mutex_lock(&mount_group_list_lock);
+	list_for_each_entry(mg, &mount_group_list, list) {
+		kref_get(&mg->refcount);
+		list_add(&mg->refresh_list, &mglist);
 	}
-	spin_unlock(&vol_list_lock);
+	mutex_unlock(&mount_group_list_lock);
 
-	/* Walk through all TCONs and refresh any expired cache entry */
-	list_for_each_entry_safe(vi, nvi, &vols, rlist) {
-		spin_lock(&vi->ctx_lock);
-		server = get_tcp_server(&vi->ctx);
-		spin_unlock(&vi->ctx_lock);
+	/* Fill in local array with an NULL-terminated list of all referral server sessions */
+	list_for_each_entry(mg, &mglist, refresh_list) {
+		if (i >= max_sessions)
+			break;
 
-		if (!server)
-			goto next_vol;
+		spin_lock(&mg->lock);
+		if (i + mg->num_sessions > max_sessions)
+			count = max_sessions - i;
+		else
+			count = mg->num_sessions;
+		memcpy(&sessions[i], mg->sessions, count * sizeof(mg->sessions[0]));
+		spin_unlock(&mg->lock);
+		i += count;
+	}
 
-		get_tcons(server, &tcons);
-		rc = 0;
+	if (sessions[0]) {
+		/* Refresh all active mounts and cached entries */
+		refresh_mounts(sessions);
+		refresh_cache(sessions);
+	}
 
-		list_for_each_entry_safe(tcon, ntcon, &tcons, ulist) {
-			/*
-			 * Skip tcp server if any of its tcons failed to refresh
-			 * (possibily due to reconnects).
-			 */
-			if (!rc)
-				rc = refresh_tcon(vi, tcon);
-
-			list_del_init(&tcon->ulist);
-			cifs_put_tcon(tcon);
-		}
-
-		put_tcp_server(server);
-
-next_vol:
-		list_del_init(&vi->rlist);
-		kref_put(&vi->refcnt, vol_release);
+	list_for_each_entry_safe(mg, tmp_mg, &mglist, refresh_list) {
+		list_del_init(&mg->refresh_list);
+		kref_put(&mg->refcount, mount_group_release);
 	}
 
 	spin_lock(&cache_ttl_lock);