dax: Convert dax_insert_pfn_mkwrite to XArray

Add some XArray-based helper functions to replace the radix tree based
metaphors currently in use.  The biggest change is that converted code
doesn't see its own lock bit; get_unlocked_entry() always returns an
entry with the lock bit clear.  So we don't have to mess around loading
the current entry and clearing the lock bit; we can just store the
unlocked entry that we already have.

Signed-off-by: Matthew Wilcox <willy@infradead.org>
diff --git a/fs/dax.c b/fs/dax.c
index fd111ea..8873e4b 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -38,6 +38,17 @@
 #define CREATE_TRACE_POINTS
 #include <trace/events/fs_dax.h>
 
+static inline unsigned int pe_order(enum page_entry_size pe_size)
+{	/* Translate a page_entry_size into a shift count relative to PAGE_SHIFT. */
+	if (pe_size == PE_SIZE_PTE)
+		return PAGE_SHIFT - PAGE_SHIFT;	/* i.e. 0 */
+	if (pe_size == PE_SIZE_PMD)
+		return PMD_SHIFT - PAGE_SHIFT;
+	if (pe_size == PE_SIZE_PUD)
+		return PUD_SHIFT - PAGE_SHIFT;
+	return ~0;	/* no match: all bits set; NOTE(review): unsafe as a shift count */
+}
+
 /* We choose 4096 entries - same as per-zone page wait tables */
 #define DAX_WAIT_TABLE_BITS 12
 #define DAX_WAIT_TABLE_ENTRIES (1 << DAX_WAIT_TABLE_BITS)
@@ -46,6 +57,9 @@
 #define PG_PMD_COLOUR	((PMD_SIZE >> PAGE_SHIFT) - 1)
 #define PG_PMD_NR	(PMD_SIZE >> PAGE_SHIFT)
 
+/* The order of a PMD entry */
+#define PMD_ORDER	(PMD_SHIFT - PAGE_SHIFT)
+
 static wait_queue_head_t wait_table[DAX_WAIT_TABLE_ENTRIES];
 
 static int __init init_dax_wait_table(void)
@@ -85,10 +99,15 @@ static void *dax_make_locked(unsigned long pfn, unsigned long flags)
 			DAX_LOCKED);
 }
 
+static bool dax_is_locked(void *entry)
+{	/* True when the DAX_LOCKED bit is set in the entry's value. */
+	return xa_to_value(entry) & DAX_LOCKED;
+}
+
 static unsigned int dax_entry_order(void *entry)
 {
 	if (xa_to_value(entry) & DAX_PMD)
-		return PMD_SHIFT - PAGE_SHIFT;
+		return PMD_ORDER;
 	return 0;
 }
 
@@ -181,6 +200,81 @@ static void dax_wake_mapping_entry_waiter(struct xarray *xa,
 		__wake_up(wq, TASK_NORMAL, wake_all ? 0 : 1, &key);
 }
 
+static void dax_wake_entry(struct xa_state *xas, void *entry, bool wake_all)
+{
+	/* Wake waiter(s) on @entry, using the array and index held in @xas. */
+	dax_wake_mapping_entry_waiter(xas->xa, xas->xa_index, entry, wake_all);
+}
+
+/*
+ * Load the entry at @xas, sleeping until its DAX lock bit is clear.  NULL,
+ * internal and non-value entries (the last with a one-time warning) are
+ * returned as-is.  The caller must subsequently call put_unlocked_entry()
+ * if it did not lock the entry or dax_unlock_entry() if it did.
+ *
+ * Must be called with the i_pages lock held; it is dropped while sleeping.
+ */
+static void *get_unlocked_entry(struct xa_state *xas)
+{
+	void *entry;
+	struct wait_exceptional_entry_queue ewait;
+	wait_queue_head_t *wq;
+
+	init_wait(&ewait.wait);
+	ewait.wait.func = wake_exceptional_entry_func;
+
+	for (;;) {
+		entry = xas_load(xas);
+		if (!entry || xa_is_internal(entry) ||
+				WARN_ON_ONCE(!xa_is_value(entry)) ||
+				!dax_is_locked(entry))
+			return entry;	/* nothing to wait for */
+
+		wq = dax_entry_waitqueue(xas->xa, xas->xa_index, entry,
+				&ewait.key);
+		prepare_to_wait_exclusive(wq, &ewait.wait,
+					  TASK_UNINTERRUPTIBLE);
+		xas_unlock_irq(xas);	/* do not hold the lock across schedule() */
+		xas_reset(xas);		/* the walk state is stale once unlocked */
+		schedule();
+		finish_wait(wq, &ewait.wait);
+		xas_lock_irq(xas);	/* retake the lock, then reload the entry */
+	}
+}
+
+static void put_unlocked_entry(struct xa_state *xas, void *entry)
+{	/* Pairs with get_unlocked_entry() when the caller did not lock the entry. */
+	/* If we were the only waiter woken, wake the next one */
+	if (entry)	/* a NULL entry issues no wakeup */
+		dax_wake_entry(xas, entry, false);
+}
+
+/*
+ * Store @entry over the locked entry in the slot, then wake a waiter.
+ * We used the xa_state to get the entry, but then we locked the entry and
+ * dropped the xa_lock, so the xa_state is stale and must be reset before use.
+ */
+static void dax_unlock_entry(struct xa_state *xas, void *entry)
+{
+	void *old;
+
+	xas_reset(xas);
+	xas_lock_irq(xas);
+	old = xas_store(xas, entry);
+	xas_unlock_irq(xas);
+	BUG_ON(!dax_is_locked(old));	/* the slot must have held a locked entry */
+	dax_wake_entry(xas, entry, false);	/* wake_all is false here */
+}
+
+/*
+ * Store @entry with DAX_LOCKED set.  Return: The entry previously stored.
+ */
+static void *dax_lock_entry(struct xa_state *xas, void *entry)
+{
+	unsigned long v = xa_to_value(entry);
+	return xas_store(xas, xa_mk_value(v | DAX_LOCKED));	/* old slot contents */
+}
+
 /*
  * Check whether the given slot is locked.  Must be called with the i_pages
  * lock held.
@@ -1728,50 +1822,46 @@ EXPORT_SYMBOL_GPL(dax_iomap_fault);
 /*
  * dax_insert_pfn_mkwrite - insert PTE or PMD entry into page tables
  * @vmf: The description of the fault
- * @pe_size: Size of entry to be inserted
  * @pfn: PFN to insert
+ * @order: Order of entry to insert (0 for a PTE, PMD_ORDER for a PMD).
  *
  * This function inserts a writeable PTE or PMD entry into the page tables
  * for an mmaped DAX file.  It also marks the page cache entry as dirty.
  */
-static vm_fault_t dax_insert_pfn_mkwrite(struct vm_fault *vmf,
-				  enum page_entry_size pe_size,
-				  pfn_t pfn)
+static vm_fault_t
+dax_insert_pfn_mkwrite(struct vm_fault *vmf, pfn_t pfn, unsigned int order)
 {
 	struct address_space *mapping = vmf->vma->vm_file->f_mapping;
-	void *entry, **slot;
-	pgoff_t index = vmf->pgoff;
+	XA_STATE_ORDER(xas, &mapping->i_pages, vmf->pgoff, order);
+	void *entry;
 	vm_fault_t ret;
 
-	xa_lock_irq(&mapping->i_pages);
-	entry = get_unlocked_mapping_entry(mapping, index, &slot);
+	xas_lock_irq(&xas);
+	entry = get_unlocked_entry(&xas);
 	/* Did we race with someone splitting entry or so? */
 	if (!entry ||
-	    (pe_size == PE_SIZE_PTE && !dax_is_pte_entry(entry)) ||
-	    (pe_size == PE_SIZE_PMD && !dax_is_pmd_entry(entry))) {
-		put_unlocked_mapping_entry(mapping, index, entry);
-		xa_unlock_irq(&mapping->i_pages);
+	    (order == 0 && !dax_is_pte_entry(entry)) ||
+	    (order == PMD_ORDER && (xa_is_internal(entry) ||
+				    !dax_is_pmd_entry(entry)))) {
+		put_unlocked_entry(&xas, entry);
+		xas_unlock_irq(&xas);
 		trace_dax_insert_pfn_mkwrite_no_entry(mapping->host, vmf,
 						      VM_FAULT_NOPAGE);
 		return VM_FAULT_NOPAGE;
 	}
-	radix_tree_tag_set(&mapping->i_pages, index, PAGECACHE_TAG_DIRTY);
-	entry = lock_slot(mapping, slot);
-	xa_unlock_irq(&mapping->i_pages);
-	switch (pe_size) {
-	case PE_SIZE_PTE:
+	xas_set_mark(&xas, PAGECACHE_TAG_DIRTY);
+	dax_lock_entry(&xas, entry);
+	xas_unlock_irq(&xas);
+	if (order == 0)
 		ret = vmf_insert_mixed_mkwrite(vmf->vma, vmf->address, pfn);
-		break;
 #ifdef CONFIG_FS_DAX_PMD
-	case PE_SIZE_PMD:
+	else if (order == PMD_ORDER)
 		ret = vmf_insert_pfn_pmd(vmf->vma, vmf->address, vmf->pmd,
 			pfn, true);
-		break;
 #endif
-	default:
+	else
 		ret = VM_FAULT_FALLBACK;
-	}
-	put_locked_mapping_entry(mapping, index);
+	dax_unlock_entry(&xas, entry);
 	trace_dax_insert_pfn_mkwrite(mapping->host, vmf, ret);
 	return ret;
 }
@@ -1791,17 +1881,12 @@ vm_fault_t dax_finish_sync_fault(struct vm_fault *vmf,
 {
 	int err;
 	loff_t start = ((loff_t)vmf->pgoff) << PAGE_SHIFT;
-	size_t len = 0;
+	unsigned int order = pe_order(pe_size);
+	size_t len = PAGE_SIZE << order;
 
-	if (pe_size == PE_SIZE_PTE)
-		len = PAGE_SIZE;
-	else if (pe_size == PE_SIZE_PMD)
-		len = PMD_SIZE;
-	else
-		WARN_ON_ONCE(1);
 	err = vfs_fsync_range(vmf->vma->vm_file, start, start + len - 1, 1);
 	if (err)
 		return VM_FAULT_SIGBUS;
-	return dax_insert_pfn_mkwrite(vmf, pe_size, pfn);
+	return dax_insert_pfn_mkwrite(vmf, pfn, order);
 }
 EXPORT_SYMBOL_GPL(dax_finish_sync_fault);