mm/mmu_notifiers: add a get/put scheme for the registration

Many places in the kernel have a flow where userspace will create some
object and that object will need to connect to the subsystem's
mmu_notifier subscription for the duration of its lifetime.

In this case the subsystem is usually tracking multiple mm_structs and it
is difficult to keep track of what struct mmu_notifier's have been
allocated for what mm's.

Since this has been open coded in a variety of exciting ways, provide core
functionality to do this safely.

This approach uses the struct mmu_notifier_ops * as a key to determine if
the subsystem has a notifier registered on the mm or not. If there is a
registration then the existing notifier struct is returned, otherwise the
ops->alloc_notifiers() is used to create a new per-subsystem notifier for
the mm.

The destroy side incorporates an async call_srcu based destruction which
will avoid bugs in the callers such as commit 6d7c3cde93c1 ("mm/hmm: fix
use after free with struct hmm in the mmu notifiers").

Since we are inside the mmu notifier core locking is fairly simple, the
allocation uses the same approach as for mmu_notifier_mm, the write side
of the mmap_sem makes everything deterministic and we only need to do
hlist_add_head_rcu() under the mm_take_all_locks(). The new users count
and the discoverability in the hlist is fully serialized by the
mmu_notifier_mm->lock.

Link: https://lore.kernel.org/r/20190806231548.25242-4-jgg@ziepe.ca
Co-developed-by: Christoph Hellwig <hch@infradead.org>
Signed-off-by: Christoph Hellwig <hch@infradead.org>
Reviewed-by: Ralph Campbell <rcampbell@nvidia.com>
Tested-by: Ralph Campbell <rcampbell@nvidia.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
diff --git a/include/linux/mmu_notifier.h b/include/linux/mmu_notifier.h
index b6c004b..31aa971 100644
--- a/include/linux/mmu_notifier.h
+++ b/include/linux/mmu_notifier.h
@@ -211,6 +211,19 @@ struct mmu_notifier_ops {
 	 */
 	void (*invalidate_range)(struct mmu_notifier *mn, struct mm_struct *mm,
 				 unsigned long start, unsigned long end);
+
+	/*
+	 * These callbacks are used with the get/put interface to manage the
+	 * lifetime of the mmu_notifier memory. alloc_notifier() returns a new
+	 * notifier for use with the mm.
+	 *
+	 * free_notifier() is only called after the mmu_notifier has been
+	 * fully put, calls to any ops callback are prevented and no ops
+	 * callbacks are currently running. It is called from a SRCU callback
+	 * and cannot sleep.
+	 */
+	struct mmu_notifier *(*alloc_notifier)(struct mm_struct *mm);
+	void (*free_notifier)(struct mmu_notifier *mn);
 };
 
 /*
@@ -227,6 +240,9 @@ struct mmu_notifier_ops {
 struct mmu_notifier {
 	struct hlist_node hlist;
 	const struct mmu_notifier_ops *ops;
+	struct mm_struct *mm;
+	struct rcu_head rcu;
+	unsigned int users;
 };
 
 static inline int mm_has_notifiers(struct mm_struct *mm)
@@ -234,6 +250,21 @@ static inline int mm_has_notifiers(struct mm_struct *mm)
 	return unlikely(mm->mmu_notifier_mm);
 }
 
+struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
+					     struct mm_struct *mm);
+static inline struct mmu_notifier *
+mmu_notifier_get(const struct mmu_notifier_ops *ops, struct mm_struct *mm)
+{
+	struct mmu_notifier *ret;
+
+	down_write(&mm->mmap_sem);
+	ret = mmu_notifier_get_locked(ops, mm);
+	up_write(&mm->mmap_sem);
+	return ret;
+}
+void mmu_notifier_put(struct mmu_notifier *mn);
+void mmu_notifier_synchronize(void);
+
 extern int mmu_notifier_register(struct mmu_notifier *mn,
 				 struct mm_struct *mm);
 extern int __mmu_notifier_register(struct mmu_notifier *mn,
@@ -581,6 +612,10 @@ static inline void mmu_notifier_mm_destroy(struct mm_struct *mm)
 #define pudp_huge_clear_flush_notify pudp_huge_clear_flush
 #define set_pte_at_notify set_pte_at
 
+static inline void mmu_notifier_synchronize(void)
+{
+}
+
 #endif /* CONFIG_MMU_NOTIFIER */
 
 #endif /* _LINUX_MMU_NOTIFIER_H */
diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index 696810f..9e92ec8 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -248,6 +248,9 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 	lockdep_assert_held_write(&mm->mmap_sem);
 	BUG_ON(atomic_read(&mm->mm_users) <= 0);
 
+	mn->mm = mm;
+	mn->users = 1;
+
 	if (!mm->mmu_notifier_mm) {
 		/*
 		 * kmalloc cannot be called under mm_take_all_locks(), but we
@@ -295,18 +298,24 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 }
 EXPORT_SYMBOL_GPL(__mmu_notifier_register);
 
-/*
+/**
+ * mmu_notifier_register - Register a notifier on a mm
+ * @mn: The notifier to attach
+ * @mm: The mm to attach the notifier to
+ *
  * Must not hold mmap_sem nor any other VM related lock when calling
  * this registration function. Must also ensure mm_users can't go down
  * to zero while this runs to avoid races with mmu_notifier_release,
  * so mm has to be current->mm or the mm should be pinned safely such
  * as with get_task_mm(). If the mm is not current->mm, the mm_users
  * pin should be released by calling mmput after mmu_notifier_register
- * returns. mmu_notifier_unregister must be always called to
- * unregister the notifier. mm_count is automatically pinned to allow
- * mmu_notifier_unregister to safely run at any time later, before or
- * after exit_mmap. ->release will always be called before exit_mmap
- * frees the pages.
+ * returns.
+ *
+ * mmu_notifier_unregister() or mmu_notifier_put() must be always called to
+ * unregister the notifier.
+ *
+ * While the caller has a mmu_notifier get the mn->mm pointer will remain
+ * valid, and can be converted to an active mm pointer via mmget_not_zero().
  */
 int mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 {
@@ -319,6 +328,72 @@ int mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_register);
 
+static struct mmu_notifier *
+find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops)
+{
+	struct mmu_notifier *mn;
+
+	spin_lock(&mm->mmu_notifier_mm->lock);
+	hlist_for_each_entry_rcu (mn, &mm->mmu_notifier_mm->list, hlist) {
+		if (mn->ops != ops)
+			continue;
+
+		if (likely(mn->users != UINT_MAX))
+			mn->users++;
+		else
+			mn = ERR_PTR(-EOVERFLOW);
+		spin_unlock(&mm->mmu_notifier_mm->lock);
+		return mn;
+	}
+	spin_unlock(&mm->mmu_notifier_mm->lock);
+	return NULL;
+}
+
+/**
+ * mmu_notifier_get_locked - Return the single struct mmu_notifier for
+ *                           the mm & ops
+ * @ops: The operations struct being subscribe with
+ * @mm : The mm to attach notifiers too
+ *
+ * This function either allocates a new mmu_notifier via
+ * ops->alloc_notifier(), or returns an already existing notifier on the
+ * list. The value of the ops pointer is used to determine when two notifiers
+ * are the same.
+ *
+ * Each call to mmu_notifier_get() must be paired with a call to
+ * mmu_notifier_put(). The caller must hold the write side of mm->mmap_sem.
+ *
+ * While the caller has a mmu_notifier get the mm pointer will remain valid,
+ * and can be converted to an active mm pointer via mmget_not_zero().
+ */
+struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
+					     struct mm_struct *mm)
+{
+	struct mmu_notifier *mn;
+	int ret;
+
+	lockdep_assert_held_write(&mm->mmap_sem);
+
+	if (mm->mmu_notifier_mm) {
+		mn = find_get_mmu_notifier(mm, ops);
+		if (mn)
+			return mn;
+	}
+
+	mn = ops->alloc_notifier(mm);
+	if (IS_ERR(mn))
+		return mn;
+	mn->ops = ops;
+	ret = __mmu_notifier_register(mn, mm);
+	if (ret)
+		goto out_free;
+	return mn;
+out_free:
+	mn->ops->free_notifier(mn);
+	return ERR_PTR(ret);
+}
+EXPORT_SYMBOL_GPL(mmu_notifier_get_locked);
+
 /* this is called after the last mmu_notifier_unregister() returned */
 void __mmu_notifier_mm_destroy(struct mm_struct *mm)
 {
@@ -397,6 +472,75 @@ void mmu_notifier_unregister_no_release(struct mmu_notifier *mn,
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_unregister_no_release);
 
+static void mmu_notifier_free_rcu(struct rcu_head *rcu)
+{
+	struct mmu_notifier *mn = container_of(rcu, struct mmu_notifier, rcu);
+	struct mm_struct *mm = mn->mm;
+
+	mn->ops->free_notifier(mn);
+	/* Pairs with the get in __mmu_notifier_register() */
+	mmdrop(mm);
+}
+
+/**
+ * mmu_notifier_put - Release the reference on the notifier
+ * @mn: The notifier to act on
+ *
+ * This function must be paired with each mmu_notifier_get(), it releases the
+ * reference obtained by the get. If this is the last reference then process
+ * to free the notifier will be run asynchronously.
+ *
+ * Unlike mmu_notifier_unregister() the get/put flow only calls ops->release
+ * when the mm_struct is destroyed. Instead free_notifier is always called to
+ * release any resources held by the user.
+ *
+ * As ops->release is not guaranteed to be called, the user must ensure that
+ * all sptes are dropped, and no new sptes can be established before
+ * mmu_notifier_put() is called.
+ *
+ * This function can be called from the ops->release callback, however the
+ * caller must still ensure it is called pairwise with mmu_notifier_get().
+ *
+ * Modules calling this function must call mmu_notifier_synchronize() in
+ * their __exit functions to ensure the async work is completed.
+ */
+void mmu_notifier_put(struct mmu_notifier *mn)
+{
+	struct mm_struct *mm = mn->mm;
+
+	spin_lock(&mm->mmu_notifier_mm->lock);
+	if (WARN_ON(!mn->users) || --mn->users)
+		goto out_unlock;
+	hlist_del_init_rcu(&mn->hlist);
+	spin_unlock(&mm->mmu_notifier_mm->lock);
+
+	call_srcu(&srcu, &mn->rcu, mmu_notifier_free_rcu);
+	return;
+
+out_unlock:
+	spin_unlock(&mm->mmu_notifier_mm->lock);
+}
+EXPORT_SYMBOL_GPL(mmu_notifier_put);
+
+/**
+ * mmu_notifier_synchronize - Ensure all mmu_notifiers are freed
+ *
+ * This function ensures that all outstanding async SRU work from
+ * mmu_notifier_put() is completed. After it returns any mmu_notifier_ops
+ * associated with an unused mmu_notifier will no longer be called.
+ *
+ * Before using the caller must ensure that all of its mmu_notifiers have been
+ * fully released via mmu_notifier_put().
+ *
+ * Modules using the mmu_notifier_put() API should call this in their __exit
+ * function to avoid module unloading races.
+ */
+void mmu_notifier_synchronize(void)
+{
+	synchronize_srcu(&srcu);
+}
+EXPORT_SYMBOL_GPL(mmu_notifier_synchronize);
+
 bool
 mmu_notifier_range_update_to_read_only(const struct mmu_notifier_range *range)
 {