mm/hmm: improve driver API to work and wait over a range

A common use case for HMM mirror is a user trying to mirror a range, only
to have it invalidated by some core mm event before they can program the
hardware.  Instead of having the user retry right away to mirror the range,
provide a completion mechanism that lets them wait for any active
invalidation affecting the range.

This also changes how hmm_range_snapshot() and hmm_range_fault() work: they
no longer rely on the vma, so that we can drop the mmap_sem when waiting and
look up the vma again on retry.

Link: http://lkml.kernel.org/r/20190403193318.16478-7-jglisse@redhat.com
Signed-off-by: Jérôme Glisse <jglisse@redhat.com>
Reviewed-by: Ralph Campbell <rcampbell@nvidia.com>
Cc: John Hubbard <jhubbard@nvidia.com>
Cc: Dan Williams <dan.j.williams@intel.com>
Cc: Dan Carpenter <dan.carpenter@oracle.com>
Cc: Matthew Wilcox <willy@infradead.org>
Cc: Arnd Bergmann <arnd@arndb.de>
Cc: Balbir Singh <bsingharora@gmail.com>
Cc: Ira Weiny <ira.weiny@intel.com>
Cc: Souptick Joarder <jrdr.linux@gmail.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/include/linux/hmm.h b/include/linux/hmm.h
index e9afd23..ec4bfa9 100644
--- a/include/linux/hmm.h
+++ b/include/linux/hmm.h
@@ -77,8 +77,34 @@
 #include <linux/migrate.h>
 #include <linux/memremap.h>
 #include <linux/completion.h>
+#include <linux/mmu_notifier.h>
 
-struct hmm;
+
+/*
+ * struct hmm - HMM per mm struct
+ * @mm: mm struct this HMM struct is bound to
+ * @kref: kref reference counter embedded in this HMM struct
+ * @lock: lock protecting the ranges list
+ * @ranges: list of ranges being snapshotted
+ * @mirrors: list of mirrors for this mm
+ * @mmu_notifier: mmu notifier to track updates to the CPU page table
+ * @mirrors_sem: read/write semaphore protecting the mirrors list
+ * @wq: wait queue for users waiting on a range invalidation
+ * @notifiers: count of active mmu notifiers
+ * @dead: whether the mm is dead
+ */
+struct hmm {
+	struct mm_struct	*mm;
+	struct kref		kref;
+	struct mutex		lock;
+	struct list_head	ranges;
+	struct list_head	mirrors;
+	struct mmu_notifier	mmu_notifier;
+	struct rw_semaphore	mirrors_sem;
+	wait_queue_head_t	wq;
+	long			notifiers;
+	bool			dead;
+};
 
 /*
  * hmm_pfn_flag_e - HMM flag enums
@@ -156,6 +182,38 @@ struct hmm_range {
 };
 
 /*
+ * hmm_range_wait_until_valid() - wait for range to be valid
+ * @range: range affected by invalidation to wait on
+ * @timeout: maximum time to wait, in ms (the wait is abandoned after that)
+ * Returns: true if the range is valid when the wait ends, false otherwise.
+ */
+static inline bool hmm_range_wait_until_valid(struct hmm_range *range,
+					      unsigned long timeout)
+{
+	/* No hmm, a dead hmm or no mm: mark the range invalid and bail. */
+	if (range->hmm == NULL || range->hmm->dead || range->hmm->mm == NULL) {
+		range->valid = false;
+		return false;
+	}
+	if (range->valid)
+		return true;
+	wait_event_timeout(range->hmm->wq, range->valid || range->hmm->dead,
+			   msecs_to_jiffies(timeout));
+	/* Re-read valid: the wait can also end on timeout or a dead hmm. */
+	return range->valid;
+}
+
+/*
+ * hmm_range_valid() - test if a range is valid or not
+ * @range: range whose valid flag is read
+ * Returns: true if the range is valid, false otherwise.
+ */
+static inline bool hmm_range_valid(struct hmm_range *range)
+{
+	return range->valid;
+}
+
+/*
  * hmm_pfn_to_page() - return struct page pointed to by a valid HMM pfn
  * @range: range use to decode HMM pfn value
  * @pfn: HMM pfn value to get corresponding struct page from
@@ -357,51 +415,66 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror);
 
 
 /*
- * To snapshot the CPU page table, call hmm_vma_get_pfns(), then take a device
- * driver lock that serializes device page table updates, then call
- * hmm_vma_range_done(), to check if the snapshot is still valid. The same
- * device driver page table update lock must also be used in the
- * hmm_mirror_ops.sync_cpu_device_pagetables() callback, so that CPU page
- * table invalidation serializes on it.
- *
- * YOU MUST CALL hmm_vma_range_done() ONCE AND ONLY ONCE EACH TIME YOU CALL
- * hmm_range_snapshot() WITHOUT ERROR !
- *
- * IF YOU DO NOT FOLLOW THE ABOVE RULE THE SNAPSHOT CONTENT MIGHT BE INVALID !
+ * Please see Documentation/vm/hmm.rst for how to use the range API.
  */
+int hmm_range_register(struct hmm_range *range,
+		       struct mm_struct *mm,
+		       unsigned long start,
+		       unsigned long end);
+void hmm_range_unregister(struct hmm_range *range);
 long hmm_range_snapshot(struct hmm_range *range);
-bool hmm_vma_range_done(struct hmm_range *range);
-
+long hmm_range_fault(struct hmm_range *range, bool block);
 
 /*
- * Fault memory on behalf of device driver. Unlike handle_mm_fault(), this will
- * not migrate any device memory back to system memory. The HMM pfn array will
- * be updated with the fault result and current snapshot of the CPU page table
- * for the range.
+ * HMM_RANGE_DEFAULT_TIMEOUT - default timeout (ms) when waiting for a range
  *
- * The mmap_sem must be taken in read mode before entering and it might be
- * dropped by the function if the block argument is false. In that case, the
- * function returns -EAGAIN.
- *
- * Return value does not reflect if the fault was successful for every single
- * address or not. Therefore, the caller must to inspect the HMM pfn array to
- * determine fault status for each address.
- *
- * Trying to fault inside an invalid vma will result in -EINVAL.
- *
- * See the function description in mm/hmm.c for further documentation.
+ * When waiting for mmu notifiers we need some kind of timeout, otherwise we
+ * could potentially wait forever; 1000ms (i.e. 1s) already sounds like a long
+ * time to wait.
  */
-long hmm_range_fault(struct hmm_range *range, bool block);
+#define HMM_RANGE_DEFAULT_TIMEOUT 1000
+
+/* This is a temporary helper to avoid merge conflicts between trees. */
+static inline bool hmm_vma_range_done(struct hmm_range *range)
+{
+	bool ret = hmm_range_valid(range);
+
+	hmm_range_unregister(range);
+	return ret;
+}
 
 /* This is a temporary helper to avoid merge conflict between trees. */
 static inline int hmm_vma_fault(struct hmm_range *range, bool block)
 {
-	long ret = hmm_range_fault(range, block);
-	if (ret == -EBUSY)
-		ret = -EAGAIN;
-	else if (ret == -EAGAIN)
-		ret = -EBUSY;
-	return ret < 0 ? ret : 0;
+	long ret;
+
+	ret = hmm_range_register(range, range->vma->vm_mm,
+				 range->start, range->end);
+	if (ret)
+		return (int)ret;
+
+	if (!hmm_range_wait_until_valid(range, HMM_RANGE_DEFAULT_TIMEOUT)) {
+		/*
+		 * Match the old API: drop the driver's mmap_sem and return
+		 * -EAGAIN. Unregister too, like the fault failure path below.
+		 */
+		up_read(&range->vma->vm_mm->mmap_sem);
+		hmm_range_unregister(range);
+		return -EAGAIN;
+	}
+
+	ret = hmm_range_fault(range, block);
+	if (ret <= 0) {
+		if (ret == -EBUSY || !ret) {
+			/* Same as above: drop mmap_sem to match old API. */
+			up_read(&range->vma->vm_mm->mmap_sem);
+			ret = -EBUSY;
+		} else if (ret == -EAGAIN)
+			ret = -EBUSY;
+		hmm_range_unregister(range);
+		return ret;
+	}
+	return 0;
 }
 
 /* Below are for HMM internal use only! Not to be used by device driver! */