uacce: Remove mm_exit() op

The mm_exit() op will be removed from the SVA API. When a process dies
and its mm goes away, the IOMMU driver won't notify device drivers
anymore. Drivers should expect to handle a lot more aborted DMA. On the
upside, it does greatly simplify the queue management.

The uacce_mm struct, that tracks all queues bound to an mm, was only
used by the mm_exit() callback. Remove it.

Signed-off-by: Jean-Philippe Brucker <jean-philippe@linaro.org>
Acked-by: Jacob Pan <jacob.jun.pan@linux.intel.com>
Acked-by: Lu Baolu <baolu.lu@linux.intel.com>
Acked-by: Zhangfei Gao <zhangfei.gao@linaro.org>
Link: https://lore.kernel.org/r/20200423125329.782066-2-jean-philippe@linaro.org
Signed-off-by: Joerg Roedel <jroedel@suse.de>
diff --git a/drivers/misc/uacce/uacce.c b/drivers/misc/uacce/uacce.c
index d39307f..107028e 100644
--- a/drivers/misc/uacce/uacce.c
+++ b/drivers/misc/uacce/uacce.c
@@ -90,109 +90,39 @@ static long uacce_fops_compat_ioctl(struct file *filep,
 }
 #endif
 
-static int uacce_sva_exit(struct device *dev, struct iommu_sva *handle,
-			  void *data)
+static int uacce_bind_queue(struct uacce_device *uacce, struct uacce_queue *q)
 {
-	struct uacce_mm *uacce_mm = data;
-	struct uacce_queue *q;
+	int pasid;
+	struct iommu_sva *handle;
 
-	/*
-	 * No new queue can be added concurrently because no caller can have a
-	 * reference to this mm. But there may be concurrent calls to
-	 * uacce_mm_put(), so we need the lock.
-	 */
-	mutex_lock(&uacce_mm->lock);
-	list_for_each_entry(q, &uacce_mm->queues, list)
-		uacce_put_queue(q);
-	uacce_mm->mm = NULL;
-	mutex_unlock(&uacce_mm->lock);
+	if (!(uacce->flags & UACCE_DEV_SVA))
+		return 0;
 
+	handle = iommu_sva_bind_device(uacce->parent, current->mm, NULL);
+	if (IS_ERR(handle))
+		return PTR_ERR(handle);
+
+	pasid = iommu_sva_get_pasid(handle);
+	if (pasid == IOMMU_PASID_INVALID) {
+		iommu_sva_unbind_device(handle);
+		return -ENODEV;
+	}
+
+	q->handle = handle;
+	q->pasid = pasid;
 	return 0;
 }
 
-static struct iommu_sva_ops uacce_sva_ops = {
-	.mm_exit = uacce_sva_exit,
-};
-
-static struct uacce_mm *uacce_mm_get(struct uacce_device *uacce,
-				     struct uacce_queue *q,
-				     struct mm_struct *mm)
+static void uacce_unbind_queue(struct uacce_queue *q)
 {
-	struct uacce_mm *uacce_mm = NULL;
-	struct iommu_sva *handle = NULL;
-	int ret;
-
-	lockdep_assert_held(&uacce->mm_lock);
-
-	list_for_each_entry(uacce_mm, &uacce->mm_list, list) {
-		if (uacce_mm->mm == mm) {
-			mutex_lock(&uacce_mm->lock);
-			list_add(&q->list, &uacce_mm->queues);
-			mutex_unlock(&uacce_mm->lock);
-			return uacce_mm;
-		}
-	}
-
-	uacce_mm = kzalloc(sizeof(*uacce_mm), GFP_KERNEL);
-	if (!uacce_mm)
-		return NULL;
-
-	if (uacce->flags & UACCE_DEV_SVA) {
-		/*
-		 * Safe to pass an incomplete uacce_mm, since mm_exit cannot
-		 * fire while we hold a reference to the mm.
-		 */
-		handle = iommu_sva_bind_device(uacce->parent, mm, uacce_mm);
-		if (IS_ERR(handle))
-			goto err_free;
-
-		ret = iommu_sva_set_ops(handle, &uacce_sva_ops);
-		if (ret)
-			goto err_unbind;
-
-		uacce_mm->pasid = iommu_sva_get_pasid(handle);
-		if (uacce_mm->pasid == IOMMU_PASID_INVALID)
-			goto err_unbind;
-	}
-
-	uacce_mm->mm = mm;
-	uacce_mm->handle = handle;
-	INIT_LIST_HEAD(&uacce_mm->queues);
-	mutex_init(&uacce_mm->lock);
-	list_add(&q->list, &uacce_mm->queues);
-	list_add(&uacce_mm->list, &uacce->mm_list);
-
-	return uacce_mm;
-
-err_unbind:
-	if (handle)
-		iommu_sva_unbind_device(handle);
-err_free:
-	kfree(uacce_mm);
-	return NULL;
-}
-
-static void uacce_mm_put(struct uacce_queue *q)
-{
-	struct uacce_mm *uacce_mm = q->uacce_mm;
-
-	lockdep_assert_held(&q->uacce->mm_lock);
-
-	mutex_lock(&uacce_mm->lock);
-	list_del(&q->list);
-	mutex_unlock(&uacce_mm->lock);
-
-	if (list_empty(&uacce_mm->queues)) {
-		if (uacce_mm->handle)
-			iommu_sva_unbind_device(uacce_mm->handle);
-		list_del(&uacce_mm->list);
-		kfree(uacce_mm);
-	}
+	if (!q->handle)
+		return;
+	iommu_sva_unbind_device(q->handle);
+	q->handle = NULL;
 }
 
 static int uacce_fops_open(struct inode *inode, struct file *filep)
 {
-	struct uacce_mm *uacce_mm = NULL;
 	struct uacce_device *uacce;
 	struct uacce_queue *q;
 	int ret = 0;
@@ -205,21 +135,16 @@ static int uacce_fops_open(struct inode *inode, struct file *filep)
 	if (!q)
 		return -ENOMEM;
 
-	mutex_lock(&uacce->mm_lock);
-	uacce_mm = uacce_mm_get(uacce, q, current->mm);
-	mutex_unlock(&uacce->mm_lock);
-	if (!uacce_mm) {
-		ret = -ENOMEM;
+	ret = uacce_bind_queue(uacce, q);
+	if (ret)
 		goto out_with_mem;
-	}
 
 	q->uacce = uacce;
-	q->uacce_mm = uacce_mm;
 
 	if (uacce->ops->get_queue) {
-		ret = uacce->ops->get_queue(uacce, uacce_mm->pasid, q);
+		ret = uacce->ops->get_queue(uacce, q->pasid, q);
 		if (ret < 0)
-			goto out_with_mm;
+			goto out_with_bond;
 	}
 
 	init_waitqueue_head(&q->wait);
@@ -227,12 +152,14 @@ static int uacce_fops_open(struct inode *inode, struct file *filep)
 	uacce->inode = inode;
 	q->state = UACCE_Q_INIT;
 
+	mutex_lock(&uacce->queues_lock);
+	list_add(&q->list, &uacce->queues);
+	mutex_unlock(&uacce->queues_lock);
+
 	return 0;
 
-out_with_mm:
-	mutex_lock(&uacce->mm_lock);
-	uacce_mm_put(q);
-	mutex_unlock(&uacce->mm_lock);
+out_with_bond:
+	uacce_unbind_queue(q);
 out_with_mem:
 	kfree(q);
 	return ret;
@@ -241,14 +168,12 @@ static int uacce_fops_open(struct inode *inode, struct file *filep)
 static int uacce_fops_release(struct inode *inode, struct file *filep)
 {
 	struct uacce_queue *q = filep->private_data;
-	struct uacce_device *uacce = q->uacce;
 
+	mutex_lock(&q->uacce->queues_lock);
+	list_del(&q->list);
+	mutex_unlock(&q->uacce->queues_lock);
 	uacce_put_queue(q);
-
-	mutex_lock(&uacce->mm_lock);
-	uacce_mm_put(q);
-	mutex_unlock(&uacce->mm_lock);
-
+	uacce_unbind_queue(q);
 	kfree(q);
 
 	return 0;
@@ -513,8 +438,8 @@ struct uacce_device *uacce_alloc(struct device *parent,
 	if (ret < 0)
 		goto err_with_uacce;
 
-	INIT_LIST_HEAD(&uacce->mm_list);
-	mutex_init(&uacce->mm_lock);
+	INIT_LIST_HEAD(&uacce->queues);
+	mutex_init(&uacce->queues_lock);
 	device_initialize(&uacce->dev);
 	uacce->dev.devt = MKDEV(MAJOR(uacce_devt), uacce->dev_id);
 	uacce->dev.class = uacce_class;
@@ -561,8 +486,7 @@ EXPORT_SYMBOL_GPL(uacce_register);
  */
 void uacce_remove(struct uacce_device *uacce)
 {
-	struct uacce_mm *uacce_mm;
-	struct uacce_queue *q;
+	struct uacce_queue *q, *next_q;
 
 	if (!uacce)
 		return;
@@ -574,24 +498,12 @@ void uacce_remove(struct uacce_device *uacce)
 		unmap_mapping_range(uacce->inode->i_mapping, 0, 0, 1);
 
 	/* ensure no open queue remains */
-	mutex_lock(&uacce->mm_lock);
-	list_for_each_entry(uacce_mm, &uacce->mm_list, list) {
-		/*
-		 * We don't take the uacce_mm->lock here. Since we hold the
-		 * device's mm_lock, no queue can be added to or removed from
-		 * this uacce_mm. We may run concurrently with mm_exit, but
-		 * uacce_put_queue() is serialized and iommu_sva_unbind_device()
-		 * waits for the lock that mm_exit is holding.
-		 */
-		list_for_each_entry(q, &uacce_mm->queues, list)
-			uacce_put_queue(q);
-
-		if (uacce->flags & UACCE_DEV_SVA) {
-			iommu_sva_unbind_device(uacce_mm->handle);
-			uacce_mm->handle = NULL;
-		}
+	mutex_lock(&uacce->queues_lock);
+	list_for_each_entry_safe(q, next_q, &uacce->queues, list) {
+		uacce_put_queue(q);
+		uacce_unbind_queue(q);
 	}
-	mutex_unlock(&uacce->mm_lock);
+	mutex_unlock(&uacce->queues_lock);
 
 	/* disable sva now since no opened queues */
 	if (uacce->flags & UACCE_DEV_SVA)
diff --git a/include/linux/uacce.h b/include/linux/uacce.h
index 0e215e6..454c2f6 100644
--- a/include/linux/uacce.h
+++ b/include/linux/uacce.h
@@ -68,19 +68,21 @@ enum uacce_q_state {
  * @uacce: pointer to uacce
  * @priv: private pointer
  * @wait: wait queue head
- * @list: index into uacce_mm
- * @uacce_mm: the corresponding mm
+ * @list: index into uacce queues list
  * @qfrs: pointer of qfr regions
  * @state: queue state machine
+ * @pasid: pasid associated to the mm
+ * @handle: iommu_sva handle returned by iommu_sva_bind_device()
  */
 struct uacce_queue {
 	struct uacce_device *uacce;
 	void *priv;
 	wait_queue_head_t wait;
 	struct list_head list;
-	struct uacce_mm *uacce_mm;
 	struct uacce_qfile_region *qfrs[UACCE_MAX_REGION];
 	enum uacce_q_state state;
+	int pasid;
+	struct iommu_sva *handle;
 };
 
 /**
@@ -96,8 +98,8 @@ struct uacce_queue {
  * @cdev: cdev of the uacce
  * @dev: dev of the uacce
  * @priv: private pointer of the uacce
- * @mm_list: list head of uacce_mm->list
- * @mm_lock: lock for mm_list
+ * @queues: list of queues
+ * @queues_lock: lock for queues list
  * @inode: core vfs
  */
 struct uacce_device {
@@ -112,27 +114,9 @@ struct uacce_device {
 	struct cdev *cdev;
 	struct device dev;
 	void *priv;
-	struct list_head mm_list;
-	struct mutex mm_lock;
-	struct inode *inode;
-};
-
-/**
- * struct uacce_mm - keep track of queues bound to a process
- * @list: index into uacce_device
- * @queues: list of queues
- * @mm: the mm struct
- * @lock: protects the list of queues
- * @pasid: pasid of the uacce_mm
- * @handle: iommu_sva handle return from iommu_sva_bind_device
- */
-struct uacce_mm {
-	struct list_head list;
 	struct list_head queues;
-	struct mm_struct *mm;
-	struct mutex lock;
-	int pasid;
-	struct iommu_sva *handle;
+	struct mutex queues_lock;
+	struct inode *inode;
 };
 
 #if IS_ENABLED(CONFIG_UACCE)