Merge tag 'fuse-update-5.10' of git://git.kernel.org/pub/scm/linux/kernel/git/mszeredi/fuse

Pull fuse updates from Miklos Szeredi:

 - Support directly accessing host page cache from virtiofs. This can
   improve I/O performance for various workloads, as well as reducing
   the memory requirement by eliminating double caching. Thanks to Vivek
   Goyal for doing most of the work on this.

 - Allow automatic submounting inside virtiofs. This allows unique
   st_dev/st_ino values to be assigned inside the guest to files
   residing on different filesystems on the host. Thanks to Max Reitz
   for the patches.

 - Fix an old use-after-free bug found by Pradeep P V K.

* tag 'fuse-update-5.10' of git://git.kernel.org/pub/scm/linux/kernel/git/mszeredi/fuse: (25 commits)
  virtiofs: calculate number of scatter-gather elements accurately
  fuse: connection remove fix
  fuse: implement crossmounts
  fuse: Allow fuse_fill_super_common() for submounts
  fuse: split fuse_mount off of fuse_conn
  fuse: drop fuse_conn parameter where possible
  fuse: store fuse_conn in fuse_req
  fuse: add submount support to <uapi/linux/fuse.h>
  fuse: fix page dereference after free
  virtiofs: add logic to free up a memory range
  virtiofs: maintain a list of busy elements
  virtiofs: serialize truncate/punch_hole and dax fault path
  virtiofs: define dax address space operations
  virtiofs: add DAX mmap support
  virtiofs: implement dax read/write operations
  virtiofs: introduce setupmapping/removemapping commands
  virtiofs: implement FUSE_INIT map_alignment field
  virtiofs: keep a list of free dax memory ranges
  virtiofs: add a mount option to enable dax
  virtiofs: set up virtio_fs dax_device
  ...
diff --git a/fs/fuse/virtio_fs.c b/fs/fuse/virtio_fs.c
index 104f35d..21a9e53 100644
--- a/fs/fuse/virtio_fs.c
+++ b/fs/fuse/virtio_fs.c
@@ -5,12 +5,17 @@
  */
 
 #include <linux/fs.h>
+#include <linux/dax.h>
+#include <linux/pci.h>
+#include <linux/pfn_t.h>
 #include <linux/module.h>
 #include <linux/virtio.h>
 #include <linux/virtio_fs.h>
 #include <linux/delay.h>
 #include <linux/fs_context.h>
+#include <linux/fs_parser.h>
 #include <linux/highmem.h>
+#include <linux/uio.h>
 #include "fuse_i.h"
 
 /* List of virtio-fs device instances and a lock for the list. Also provides
@@ -24,6 +29,8 @@ enum {
 	VQ_REQUEST
 };
 
+#define VQ_NAME_LEN	24	/* size of virtio_fs_vq::name, incl. terminating NUL */
+
 /* Per-virtqueue state */
 struct virtio_fs_vq {
 	spinlock_t lock;
@@ -36,7 +43,7 @@ struct virtio_fs_vq {
 	bool connected;
 	long in_flight;
 	struct completion in_flight_zero; /* No inflight requests */
-	char name[24];
+	char name[VQ_NAME_LEN];
 } ____cacheline_aligned_in_smp;
 
 /* A virtio-fs device instance */
@@ -47,6 +54,12 @@ struct virtio_fs {
 	struct virtio_fs_vq *vqs;
 	unsigned int nvqs;               /* number of virtqueues */
 	unsigned int num_request_queues; /* number of request queues */
+	struct dax_device *dax_dev;
+
+	/* DAX memory window where file contents are mapped */
+	void *window_kaddr;
+	phys_addr_t window_phys_addr;
+	size_t window_len;
 };
 
 struct virtio_fs_forget_req {
@@ -69,6 +82,44 @@ struct virtio_fs_req_work {
 static int virtio_fs_enqueue_req(struct virtio_fs_vq *fsvq,
 				 struct fuse_req *req, bool in_flight);
 
+enum {	/* mount option ids, as returned by fs_parse() */
+	OPT_DAX,	/* "dax" flag: sets fuse_fs_context::dax */
+};
+
+static const struct fs_parameter_spec virtio_fs_parameters[] = {
+	fsparam_flag("dax", OPT_DAX),	/* flag option, takes no value */
+	{}	/* empty entry terminates the table */
+};
+
+/* ->parse_param(): record recognized virtiofs mount options in fc->fs_private */
+static int virtio_fs_parse_param(struct fs_context *fc,
+				 struct fs_parameter *param)
+{
+	struct fs_parse_result result;
+	struct fuse_fs_context *ctx = fc->fs_private;
+	int opt;
+
+	opt = fs_parse(fc, virtio_fs_parameters, param, &result);
+	if (opt < 0)	/* fs_parse() rejected the parameter: propagate */
+		return opt;
+	switch (opt) {
+	case OPT_DAX:
+		ctx->dax = 1;	/* consumed in virtio_fs_fill_super() */
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	return 0;
+}
+
+/* ->free(): release the context allocated by virtio_fs_init_fs_context() */
+static void virtio_fs_free_fc(struct fs_context *fc)
+{
+	struct fuse_fs_context *ctx = fc->fs_private;
+	kfree(ctx);
+}
+
 static inline struct virtio_fs_vq *vq_to_fsvq(struct virtqueue *vq)
 {
 	struct virtio_fs *fs = vq->vdev->priv;
@@ -289,7 +340,6 @@ static void virtio_fs_request_dispatch_work(struct work_struct *work)
 	struct fuse_req *req;
 	struct virtio_fs_vq *fsvq = container_of(work, struct virtio_fs_vq,
 						 dispatch_work.work);
-	struct fuse_conn *fc = fsvq->fud->fc;
 	int ret;
 
 	pr_debug("virtio-fs: worker %s called.\n", __func__);
@@ -304,7 +354,7 @@ static void virtio_fs_request_dispatch_work(struct work_struct *work)
 
 		list_del_init(&req->list);
 		spin_unlock(&fsvq->lock);
-		fuse_request_end(fc, req);
+		fuse_request_end(req);
 	}
 
 	/* Dispatch pending requests */
@@ -335,7 +385,7 @@ static void virtio_fs_request_dispatch_work(struct work_struct *work)
 			spin_unlock(&fsvq->lock);
 			pr_err("virtio-fs: virtio_fs_enqueue_req() failed %d\n",
 			       ret);
-			fuse_request_end(fc, req);
+			fuse_request_end(req);
 		}
 	}
 }
@@ -495,7 +545,6 @@ static void virtio_fs_request_complete(struct fuse_req *req,
 				       struct virtio_fs_vq *fsvq)
 {
 	struct fuse_pqueue *fpq = &fsvq->fud->pq;
-	struct fuse_conn *fc = fsvq->fud->fc;
 	struct fuse_args *args;
 	struct fuse_args_pages *ap;
 	unsigned int len, i, thislen;
@@ -528,7 +577,7 @@ static void virtio_fs_request_complete(struct fuse_req *req,
 	clear_bit(FR_SENT, &req->flags);
 	spin_unlock(&fpq->lock);
 
-	fuse_request_end(fc, req);
+	fuse_request_end(req);
 	spin_lock(&fsvq->lock);
 	dec_in_flight_req(fsvq);
 	spin_unlock(&fsvq->lock);
@@ -596,6 +645,26 @@ static void virtio_fs_vq_done(struct virtqueue *vq)
 	schedule_work(&fsvq->done_work);
 }
 
+/* Initialize one virtqueue's name, lock, request lists, completion and work items */
+static void virtio_fs_init_vq(struct virtio_fs_vq *fsvq, const char *name,
+			      int vq_type)
+{
+	strscpy(fsvq->name, name, VQ_NAME_LEN);	/* truncates; always NUL-terminates */
+	spin_lock_init(&fsvq->lock);
+	INIT_LIST_HEAD(&fsvq->queued_reqs);
+	INIT_LIST_HEAD(&fsvq->end_reqs);
+	init_completion(&fsvq->in_flight_zero);
+	if (vq_type == VQ_REQUEST) {
+		INIT_WORK(&fsvq->done_work, virtio_fs_requests_done_work);
+		INIT_DELAYED_WORK(&fsvq->dispatch_work,
+				  virtio_fs_request_dispatch_work);
+	} else {
+		INIT_WORK(&fsvq->done_work, virtio_fs_hiprio_done_work);
+		INIT_DELAYED_WORK(&fsvq->dispatch_work,
+				  virtio_fs_hiprio_dispatch_work);
+	}
+}
+
 /* Initialize virtqueues */
 static int virtio_fs_setup_vqs(struct virtio_device *vdev,
 			       struct virtio_fs *fs)
@@ -611,7 +680,7 @@ static int virtio_fs_setup_vqs(struct virtio_device *vdev,
 	if (fs->num_request_queues == 0)
 		return -EINVAL;
 
-	fs->nvqs = 1 + fs->num_request_queues;
+	fs->nvqs = VQ_REQUEST + fs->num_request_queues;
 	fs->vqs = kcalloc(fs->nvqs, sizeof(fs->vqs[VQ_HIPRIO]), GFP_KERNEL);
 	if (!fs->vqs)
 		return -ENOMEM;
@@ -625,29 +694,17 @@ static int virtio_fs_setup_vqs(struct virtio_device *vdev,
 		goto out;
 	}
 
+	/* Initialize the hiprio/forget request virtqueue */
 	callbacks[VQ_HIPRIO] = virtio_fs_vq_done;
-	snprintf(fs->vqs[VQ_HIPRIO].name, sizeof(fs->vqs[VQ_HIPRIO].name),
-			"hiprio");
+	virtio_fs_init_vq(&fs->vqs[VQ_HIPRIO], "hiprio", VQ_HIPRIO);
 	names[VQ_HIPRIO] = fs->vqs[VQ_HIPRIO].name;
-	INIT_WORK(&fs->vqs[VQ_HIPRIO].done_work, virtio_fs_hiprio_done_work);
-	INIT_LIST_HEAD(&fs->vqs[VQ_HIPRIO].queued_reqs);
-	INIT_LIST_HEAD(&fs->vqs[VQ_HIPRIO].end_reqs);
-	INIT_DELAYED_WORK(&fs->vqs[VQ_HIPRIO].dispatch_work,
-			virtio_fs_hiprio_dispatch_work);
-	init_completion(&fs->vqs[VQ_HIPRIO].in_flight_zero);
-	spin_lock_init(&fs->vqs[VQ_HIPRIO].lock);
 
 	/* Initialize the requests virtqueues */
 	for (i = VQ_REQUEST; i < fs->nvqs; i++) {
-		spin_lock_init(&fs->vqs[i].lock);
-		INIT_WORK(&fs->vqs[i].done_work, virtio_fs_requests_done_work);
-		INIT_DELAYED_WORK(&fs->vqs[i].dispatch_work,
-				  virtio_fs_request_dispatch_work);
-		INIT_LIST_HEAD(&fs->vqs[i].queued_reqs);
-		INIT_LIST_HEAD(&fs->vqs[i].end_reqs);
-		init_completion(&fs->vqs[i].in_flight_zero);
-		snprintf(fs->vqs[i].name, sizeof(fs->vqs[i].name),
-			 "requests.%u", i - VQ_REQUEST);
+		char vq_name[VQ_NAME_LEN];
+
+		snprintf(vq_name, VQ_NAME_LEN, "requests.%u", i - VQ_REQUEST);
+		virtio_fs_init_vq(&fs->vqs[i], vq_name, VQ_REQUEST);
 		callbacks[i] = virtio_fs_vq_done;
 		names[i] = fs->vqs[i].name;
 	}
@@ -676,6 +733,130 @@ static void virtio_fs_cleanup_vqs(struct virtio_device *vdev,
 	vdev->config->del_vqs(vdev);
 }
 
+/* Map a window offset to a page frame number.  The window offset will have
+ * been produced by .iomap_begin(), which maps a file offset to a window
+ * offset.  Returns nr_pages, clamped to the pages left in the DAX window.
+ */
+static long virtio_fs_direct_access(struct dax_device *dax_dev, pgoff_t pgoff,
+				    long nr_pages, void **kaddr, pfn_t *pfn)
+{
+	struct virtio_fs *fs = dax_get_private(dax_dev);	/* private data given to alloc_dax() */
+	phys_addr_t offset = PFN_PHYS(pgoff);
+	size_t max_nr_pages = fs->window_len/PAGE_SIZE - pgoff;	/* NOTE(review): wraps if pgoff is past the window end - TODO confirm callers prevent this */
+
+	if (kaddr)	/* either output pointer may be NULL */
+		*kaddr = fs->window_kaddr + offset;
+	if (pfn)
+		*pfn = phys_to_pfn_t(fs->window_phys_addr + offset,
+					PFN_DEV | PFN_MAP);
+	return nr_pages > max_nr_pages ? max_nr_pages : nr_pages;
+}
+
+static size_t virtio_fs_copy_from_iter(struct dax_device *dax_dev,
+				       pgoff_t pgoff, void *addr,
+				       size_t bytes, struct iov_iter *i)
+{
+	return copy_from_iter(addr, bytes, i);	/* bytes copied into addr; dax_dev/pgoff unused */
+}
+
+static size_t virtio_fs_copy_to_iter(struct dax_device *dax_dev,
+				       pgoff_t pgoff, void *addr,
+				       size_t bytes, struct iov_iter *i)
+{
+	return copy_to_iter(addr, bytes, i);	/* bytes copied from addr; dax_dev/pgoff unused */
+}
+
+/* ->zero_page_range(): zero nr_pages pages of the DAX window starting at pgoff */
+static int virtio_fs_zero_page_range(struct dax_device *dax_dev,
+				     pgoff_t pgoff, size_t nr_pages)
+{
+	void *kaddr;
+	long rc = dax_direct_access(dax_dev, pgoff, nr_pages, &kaddr, NULL);
+
+	if (rc < (long)nr_pages)	/* error, or range runs past the window end */
+		return rc < 0 ? rc : -ERANGE;
+	memset(kaddr, 0, nr_pages << PAGE_SHIFT);
+	dax_flush(dax_dev, kaddr, nr_pages << PAGE_SHIFT);
+	return 0;
+}
+
+static const struct dax_operations virtio_fs_dax_ops = {	/* registered via alloc_dax() */
+	.direct_access = virtio_fs_direct_access,
+	.copy_from_iter = virtio_fs_copy_from_iter,
+	.copy_to_iter = virtio_fs_copy_to_iter,
+	.zero_page_range = virtio_fs_zero_page_range,
+};
+
+/* devm action registered by virtio_fs_setup_dax(); data is fs->dax_dev */
+static void virtio_fs_cleanup_dax(void *data)
+{
+	struct dax_device *dax_dev = data;
+	kill_dax(dax_dev);
+	put_dax(dax_dev);
+}
+
+/* Map the cache shared memory region as the DAX window and allocate dax_dev.
+ * Returns 0 (also if CONFIG_FUSE_DAX is off or there is no cache region) or -errno.
+ */
+static int virtio_fs_setup_dax(struct virtio_device *vdev, struct virtio_fs *fs)
+{
+	struct virtio_shm_region cache_reg;
+	struct dev_pagemap *pgmap;
+	bool have_cache;
+
+	if (!IS_ENABLED(CONFIG_FUSE_DAX))
+		return 0;
+
+	/* Get cache region */
+	have_cache = virtio_get_shm_region(vdev, &cache_reg,
+					   (u8)VIRTIO_FS_SHMCAP_ID_CACHE);
+	if (!have_cache) {
+		dev_notice(&vdev->dev, "%s: No cache capability\n", __func__);
+		return 0;
+	}
+
+	if (!devm_request_mem_region(&vdev->dev, cache_reg.addr, cache_reg.len,
+				     dev_name(&vdev->dev))) {
+		dev_warn(&vdev->dev, "could not reserve region addr=0x%llx len=0x%llx\n",
+			 cache_reg.addr, cache_reg.len);
+		return -EBUSY;
+	}
+
+	dev_notice(&vdev->dev, "Cache len: 0x%llx @ 0x%llx\n", cache_reg.len,
+		   cache_reg.addr);
+
+	pgmap = devm_kzalloc(&vdev->dev, sizeof(*pgmap), GFP_KERNEL);
+	if (!pgmap)
+		return -ENOMEM;
+
+	pgmap->type = MEMORY_DEVICE_FS_DAX;
+
+	/* Ideally we would use the PCI BAR resource directly, but
+	 * devm_memremap_pages() wants its own copy of the range in pgmap.
+	 */
+	pgmap->range = (struct range) {
+		.start = (phys_addr_t) cache_reg.addr,
+		.end = (phys_addr_t) cache_reg.addr + cache_reg.len - 1,
+	};
+	pgmap->nr_range = 1;
+
+	fs->window_kaddr = devm_memremap_pages(&vdev->dev, pgmap);
+	if (IS_ERR(fs->window_kaddr))
+		return PTR_ERR(fs->window_kaddr);
+
+	fs->window_phys_addr = (phys_addr_t) cache_reg.addr;
+	fs->window_len = (size_t) cache_reg.len;
+	dev_dbg(&vdev->dev, "%s: window kaddr 0x%px phys_addr 0x%llx len 0x%llx\n",
+		__func__, fs->window_kaddr, cache_reg.addr, cache_reg.len);
+
+	fs->dax_dev = alloc_dax(fs, NULL, &virtio_fs_dax_ops, 0);
+	if (IS_ERR(fs->dax_dev))
+		return PTR_ERR(fs->dax_dev);
+
+	return devm_add_action_or_reset(&vdev->dev, virtio_fs_cleanup_dax,
+					fs->dax_dev);
+}
+
 static int virtio_fs_probe(struct virtio_device *vdev)
 {
 	struct virtio_fs *fs;
@@ -697,6 +878,10 @@ static int virtio_fs_probe(struct virtio_device *vdev)
 
 	/* TODO vq affinity */
 
+	ret = virtio_fs_setup_dax(vdev, fs);
+	if (ret < 0)
+		goto out_vqs;
+
 	/* Bring the device online in case the filesystem is mounted and
 	 * requests need to be sent before we return.
 	 */
@@ -833,18 +1018,37 @@ __releases(fiq->lock)
 	spin_unlock(&fiq->lock);
 }
 
+/* Number of page descriptors (one sg element each) needed to cover total_len */
+static unsigned int sg_count_fuse_pages(struct fuse_page_desc *page_descs,
+				       unsigned int num_pages,
+				       unsigned int total_len)
+{
+	unsigned int n = 0;
+
+	while (n < num_pages && total_len) {
+		/* the last descriptor used may cover only part of its length */
+		total_len -= min(page_descs[n].length, total_len);
+		n++;
+	}
+
+	return n;
+}
+
 /* Return the number of scatter-gather list elements required */
 static unsigned int sg_count_fuse_req(struct fuse_req *req)
 {
 	struct fuse_args *args = req->args;
 	struct fuse_args_pages *ap = container_of(args, typeof(*ap), args);
-	unsigned int total_sgs = 1 /* fuse_in_header */;
+	unsigned int size, total_sgs = 1 /* fuse_in_header */;
 
 	if (args->in_numargs - args->in_pages)
 		total_sgs += 1;
 
-	if (args->in_pages)
-		total_sgs += ap->num_pages;
+	if (args->in_pages) {
+		size = args->in_args[args->in_numargs - 1].size;
+		total_sgs += sg_count_fuse_pages(ap->descs, ap->num_pages,
+						 size);
+	}
 
 	if (!test_bit(FR_ISREPLY, &req->flags))
 		return total_sgs;
@@ -854,8 +1058,11 @@ static unsigned int sg_count_fuse_req(struct fuse_req *req)
 	if (args->out_numargs - args->out_pages)
 		total_sgs += 1;
 
-	if (args->out_pages)
-		total_sgs += ap->num_pages;
+	if (args->out_pages) {
+		size = args->out_args[args->out_numargs - 1].size;
+		total_sgs += sg_count_fuse_pages(ap->descs, ap->num_pages,
+						 size);
+	}
 
 	return total_sgs;
 }
@@ -1071,24 +1278,28 @@ static const struct fuse_iqueue_ops virtio_fs_fiq_ops = {
 	.release			= virtio_fs_fiq_release,
 };
 
-static int virtio_fs_fill_super(struct super_block *sb)
+static inline void virtio_fs_ctx_set_defaults(struct fuse_fs_context *ctx)
 {
-	struct fuse_conn *fc = get_fuse_conn_super(sb);
+	ctx->rootmode = S_IFDIR;
+	ctx->default_permissions = 1;
+	ctx->allow_other = 1;
+	ctx->max_read = UINT_MAX;
+	ctx->blksize = 512;
+	ctx->destroy = true;
+	ctx->no_control = true;
+	ctx->no_force_umount = true;
+}
+
+static int virtio_fs_fill_super(struct super_block *sb, struct fs_context *fsc)
+{
+	struct fuse_mount *fm = get_fuse_mount_super(sb);
+	struct fuse_conn *fc = fm->fc;
 	struct virtio_fs *fs = fc->iq.priv;
+	struct fuse_fs_context *ctx = fsc->fs_private;
 	unsigned int i;
 	int err;
-	struct fuse_fs_context ctx = {
-		.rootmode = S_IFDIR,
-		.default_permissions = 1,
-		.allow_other = 1,
-		.max_read = UINT_MAX,
-		.blksize = 512,
-		.destroy = true,
-		.no_control = true,
-		.no_force_umount = true,
-		.no_mount_options = true,
-	};
 
+	virtio_fs_ctx_set_defaults(ctx);
 	mutex_lock(&virtio_fs_mutex);
 
 	/* After holding mutex, make sure virtiofs device is still there.
@@ -1112,8 +1323,10 @@ static int virtio_fs_fill_super(struct super_block *sb)
 	}
 
 	/* virtiofs allocates and installs its own fuse devices */
-	ctx.fudptr = NULL;
-	err = fuse_fill_super_common(sb, &ctx);
+	ctx->fudptr = NULL;
+	if (ctx->dax)
+		ctx->dax_dev = fs->dax_dev;
+	err = fuse_fill_super_common(sb, ctx);
 	if (err < 0)
 		goto err_free_fuse_devs;
 
@@ -1125,7 +1338,7 @@ static int virtio_fs_fill_super(struct super_block *sb)
 
 	/* Previous unmount will stop all queues. Start these again */
 	virtio_fs_start_all_queues(fs);
-	fuse_send_init(fc);
+	fuse_send_init(fm);
 	mutex_unlock(&virtio_fs_mutex);
 	return 0;
 
@@ -1136,18 +1349,17 @@ static int virtio_fs_fill_super(struct super_block *sb)
 	return err;
 }
 
-static void virtio_kill_sb(struct super_block *sb)
+static void virtio_fs_conn_destroy(struct fuse_mount *fm)
 {
-	struct fuse_conn *fc = get_fuse_conn_super(sb);
-	struct virtio_fs *vfs;
-	struct virtio_fs_vq *fsvq;
+	struct fuse_conn *fc = fm->fc;
+	struct virtio_fs *vfs = fc->iq.priv;
+	struct virtio_fs_vq *fsvq = &vfs->vqs[VQ_HIPRIO];
 
-	/* If mount failed, we can still be called without any fc */
-	if (!fc)
-		return fuse_kill_sb_anon(sb);
-
-	vfs = fc->iq.priv;
-	fsvq = &vfs->vqs[VQ_HIPRIO];
+	/* Stop dax worker. Soon evict_inodes() will be called which
+	 * will free all memory ranges belonging to all inodes.
+	 */
+	if (IS_ENABLED(CONFIG_FUSE_DAX))
+		fuse_dax_cancel_work(fc);
 
 	/* Stop forget queue. Soon destroy will be sent */
 	spin_lock(&fsvq->lock);
@@ -1155,9 +1367,9 @@ static void virtio_kill_sb(struct super_block *sb)
 	spin_unlock(&fsvq->lock);
 	virtio_fs_drain_all_queues(vfs);
 
-	fuse_kill_sb_anon(sb);
+	fuse_conn_destroy(fm);
 
-	/* fuse_kill_sb_anon() must have sent destroy. Stop all queues
+	/* fuse_conn_destroy() must have sent destroy. Stop all queues
 	 * and drain one more time and free fuse devices. Freeing fuse
 	 * devices will drop their reference on fuse_conn and that in
 	 * turn will drop its reference on virtio_fs object.
@@ -1167,12 +1379,27 @@ static void virtio_kill_sb(struct super_block *sb)
 	virtio_fs_free_devs(vfs);
 }
 
+static void virtio_kill_sb(struct super_block *sb)
+{
+	struct fuse_mount *fm = get_fuse_mount_super(sb);
+
+	/* If mount failed, we can still be called without any fuse_mount.
+	 * Otherwise remove this mount and, when fuse_mount_remove() reports
+	 * it was the last one, destroy the connection as well.
+	 */
+	if (fm && fuse_mount_remove(fm))
+		virtio_fs_conn_destroy(fm);
+
+	kill_anon_super(sb);
+}
+
 static int virtio_fs_test_super(struct super_block *sb,
 				struct fs_context *fsc)
 {
-	struct fuse_conn *fc = fsc->s_fs_info;
+	struct fuse_mount *fsc_fm = fsc->s_fs_info;	/* mount set up by virtio_fs_get_tree() */
+	struct fuse_mount *sb_fm = get_fuse_mount_super(sb);	/* mount of the candidate sb */
 
-	return fc->iq.priv == get_fuse_conn_super(sb)->iq.priv;
+	return fsc_fm->fc->iq.priv == sb_fm->fc->iq.priv;	/* match if both use the same virtio_fs */
 }
 
 static int virtio_fs_set_super(struct super_block *sb,
@@ -1182,7 +1409,7 @@ static int virtio_fs_set_super(struct super_block *sb,
 
 	err = get_anon_bdev(&sb->s_dev);
 	if (!err)
-		fuse_conn_get(fsc->s_fs_info);
+		fuse_mount_get(fsc->s_fs_info);
 
 	return err;
 }
@@ -1192,6 +1419,7 @@ static int virtio_fs_get_tree(struct fs_context *fsc)
 	struct virtio_fs *fs;
 	struct super_block *sb;
 	struct fuse_conn *fc;
+	struct fuse_mount *fm;
 	int err;
 
 	/* This gets a reference on virtio_fs object. This ptr gets installed
@@ -1212,19 +1440,29 @@ static int virtio_fs_get_tree(struct fs_context *fsc)
 		return -ENOMEM;
 	}
 
-	fuse_conn_init(fc, get_user_ns(current_user_ns()), &virtio_fs_fiq_ops,
-		       fs);
+	fm = kzalloc(sizeof(struct fuse_mount), GFP_KERNEL);
+	if (!fm) {
+		mutex_lock(&virtio_fs_mutex);
+		virtio_fs_put(fs);
+		mutex_unlock(&virtio_fs_mutex);
+		kfree(fc);
+		return -ENOMEM;
+	}
+
+	fuse_conn_init(fc, fm, get_user_ns(current_user_ns()),
+		       &virtio_fs_fiq_ops, fs);
 	fc->release = fuse_free_conn;
 	fc->delete_stale = true;
+	fc->auto_submounts = true;
 
-	fsc->s_fs_info = fc;
+	fsc->s_fs_info = fm;
 	sb = sget_fc(fsc, virtio_fs_test_super, virtio_fs_set_super);
-	fuse_conn_put(fc);
+	fuse_mount_put(fm);
 	if (IS_ERR(sb))
 		return PTR_ERR(sb);
 
 	if (!sb->s_root) {
-		err = virtio_fs_fill_super(sb);
+		err = virtio_fs_fill_super(sb, fsc);
 		if (err) {
 			deactivate_locked_super(sb);
 			return err;
@@ -1239,11 +1477,19 @@ static int virtio_fs_get_tree(struct fs_context *fsc)
 }
 
 static const struct fs_context_operations virtio_fs_context_ops = {
+	.free		= virtio_fs_free_fc,	/* frees fsc->fs_private */
+	.parse_param	= virtio_fs_parse_param,	/* handles the "dax" option */
 	.get_tree	= virtio_fs_get_tree,
 };
 
 static int virtio_fs_init_fs_context(struct fs_context *fsc)
 {
+	struct fuse_fs_context *ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
+
+	if (!ctx)
+		return -ENOMEM;
+	/* released by virtio_fs_free_fc() through the ->free operation */
+	fsc->fs_private = ctx;
 	fsc->ops = &virtio_fs_context_ops;
 	return 0;
 }