mm/hmm: use device driver encoding for HMM pfn

Users of hmm_vma_fault() and hmm_vma_get_pfns() provide a flags array
and a pfn shift value, allowing them to define their own encoding for
the HMM pfns that are filled into the pfns array of the hmm_range
struct.  With this, device drivers can get pfns out of HMM that already
match their own private encoding, without having to do any conversion.
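
For example, a driver could set up an encoding that mirrors its device
page table format roughly as below (an illustrative sketch only: the
my_* names and bit choices are made up, while the HMM_PFN_* array
indexes and the hmm_range fields are the ones this series introduces):

	static const uint64_t my_flags[] = {
		[HMM_PFN_VALID]          = 1ULL << 0, /* device "valid" bit */
		[HMM_PFN_WRITE]          = 1ULL << 1, /* device "writable" bit */
		[HMM_PFN_DEVICE_PRIVATE] = 1ULL << 2, /* device memory bit */
	};
	static const uint64_t my_values[] = {
		[HMM_PFN_ERROR]   = 1ULL << 3,
		[HMM_PFN_NONE]    = 0,
		[HMM_PFN_SPECIAL] = 1ULL << 4,
	};

	range.flags = my_flags;
	range.values = my_values;
	range.pfn_shift = 12;	/* pfn is stored above the flag bits */

Every pfns[] entry that HMM fills in then already carries the device's
own flag bits, with the pfn itself shifted above them.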

[rcampbell@nvidia.com: don't ignore specific pte fault flag in hmm_vma_fault()]
  Link: http://lkml.kernel.org/r/20180326213009.2460-2-jglisse@redhat.com
[rcampbell@nvidia.com: clarify fault logic for device private memory]
  Link: http://lkml.kernel.org/r/20180326213009.2460-3-jglisse@redhat.com
Link: http://lkml.kernel.org/r/20180323005527.758-16-jglisse@redhat.com
Signed-off-by: Jérôme Glisse <jglisse@redhat.com>
Signed-off-by: Ralph Campbell <rcampbell@nvidia.com>
Cc: Evgeny Baskakov <ebaskakov@nvidia.com>
Cc: Ralph Campbell <rcampbell@nvidia.com>
Cc: Mark Hairgrove <mhairgrove@nvidia.com>
Cc: John Hubbard <jhubbard@nvidia.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/hmm.c b/mm/hmm.c
index 290c872..398d021 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -306,6 +306,7 @@ static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
 {
 	unsigned int flags = FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_REMOTE;
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
+	struct hmm_range *range = hmm_vma_walk->range;
 	struct vm_area_struct *vma = walk->vma;
 	int r;
 
@@ -315,7 +316,7 @@ static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
 	if (r & VM_FAULT_RETRY)
 		return -EBUSY;
 	if (r & VM_FAULT_ERROR) {
-		*pfn = HMM_PFN_ERROR;
+		*pfn = range->values[HMM_PFN_ERROR];
 		return -EFAULT;
 	}
 
@@ -333,7 +334,7 @@ static int hmm_pfns_bad(unsigned long addr,
 
 	i = (addr - range->start) >> PAGE_SHIFT;
 	for (; addr < end; addr += PAGE_SIZE, i++)
-		pfns[i] = HMM_PFN_ERROR;
+		pfns[i] = range->values[HMM_PFN_ERROR];
 
 	return 0;
 }
@@ -362,7 +363,7 @@ static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
 	hmm_vma_walk->last = addr;
 	i = (addr - range->start) >> PAGE_SHIFT;
 	for (; addr < end; addr += PAGE_SIZE, i++) {
-		pfns[i] = 0;
+		pfns[i] = range->values[HMM_PFN_NONE];
 		if (fault || write_fault) {
 			int ret;
 
@@ -380,24 +381,31 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 				      uint64_t pfns, uint64_t cpu_flags,
 				      bool *fault, bool *write_fault)
 {
+	struct hmm_range *range = hmm_vma_walk->range;
+
 	*fault = *write_fault = false;
 	if (!hmm_vma_walk->fault)
 		return;
 
 	/* We aren't asked to do anything ... */
-	if (!(pfns & HMM_PFN_VALID))
+	if (!(pfns & range->flags[HMM_PFN_VALID]))
 		return;
-	/* If CPU page table is not valid then we need to fault */
-	*fault = !(cpu_flags & HMM_PFN_VALID);
-	/* Need to write fault ? */
-	if ((pfns & HMM_PFN_WRITE) && !(cpu_flags & HMM_PFN_WRITE)) {
-		*fault = *write_fault = true;
+	/* If this is device memory then only fault if explicitly requested */
+	if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
+		/* Do we fault on device memory ? */
+		if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
+			*write_fault = pfns & range->flags[HMM_PFN_WRITE];
+			*fault = true;
+		}
 		return;
 	}
-	/* Do we fault on device memory ? */
-	if ((pfns & HMM_PFN_DEVICE_PRIVATE) &&
-	    (cpu_flags & HMM_PFN_DEVICE_PRIVATE)) {
-		*write_fault = pfns & HMM_PFN_WRITE;
+
+	/* If CPU page table is not valid then we need to fault */
+	*fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
+	/* Need to write fault ? */
+	if ((pfns & range->flags[HMM_PFN_WRITE]) &&
+	    !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
+		*write_fault = true;
 		*fault = true;
 	}
 }
@@ -439,13 +447,13 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
 	return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
 }
 
-static inline uint64_t pmd_to_hmm_pfn_flags(pmd_t pmd)
+static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
 {
 	if (pmd_protnone(pmd))
 		return 0;
-	return pmd_write(pmd) ? HMM_PFN_VALID |
-				HMM_PFN_WRITE :
-				HMM_PFN_VALID;
+	return pmd_write(pmd) ? range->flags[HMM_PFN_VALID] |
+				range->flags[HMM_PFN_WRITE] :
+				range->flags[HMM_PFN_VALID];
 }
 
 static int hmm_vma_handle_pmd(struct mm_walk *walk,
@@ -455,12 +463,13 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk,
 			      pmd_t pmd)
 {
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
+	struct hmm_range *range = hmm_vma_walk->range;
 	unsigned long pfn, npages, i;
-	uint64_t flag = 0, cpu_flags;
 	bool fault, write_fault;
+	uint64_t cpu_flags;
 
 	npages = (end - addr) >> PAGE_SHIFT;
-	cpu_flags = pmd_to_hmm_pfn_flags(pmd);
+	cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
 	hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
 			     &fault, &write_fault);
 
@@ -468,20 +477,19 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk,
 		return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
 
 	pfn = pmd_pfn(pmd) + pte_index(addr);
-	flag |= pmd_write(pmd) ? HMM_PFN_WRITE : 0;
 	for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++)
-		pfns[i] = hmm_pfn_from_pfn(pfn) | flag;
+		pfns[i] = hmm_pfn_from_pfn(range, pfn) | cpu_flags;
 	hmm_vma_walk->last = end;
 	return 0;
 }
 
-static inline uint64_t pte_to_hmm_pfn_flags(pte_t pte)
+static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
 {
 	if (pte_none(pte) || !pte_present(pte))
 		return 0;
-	return pte_write(pte) ? HMM_PFN_VALID |
-				HMM_PFN_WRITE :
-				HMM_PFN_VALID;
+	return pte_write(pte) ? range->flags[HMM_PFN_VALID] |
+				range->flags[HMM_PFN_WRITE] :
+				range->flags[HMM_PFN_VALID];
 }
 
 static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
@@ -489,14 +497,16 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 			      uint64_t *pfn)
 {
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
+	struct hmm_range *range = hmm_vma_walk->range;
 	struct vm_area_struct *vma = walk->vma;
 	bool fault, write_fault;
 	uint64_t cpu_flags;
 	pte_t pte = *ptep;
+	uint64_t orig_pfn = *pfn;
 
-	*pfn = 0;
-	cpu_flags = pte_to_hmm_pfn_flags(pte);
-	hmm_pte_need_fault(hmm_vma_walk, *pfn, cpu_flags,
+	*pfn = range->values[HMM_PFN_NONE];
+	cpu_flags = pte_to_hmm_pfn_flags(range, pte);
+	hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
 			   &fault, &write_fault);
 
 	if (pte_none(pte)) {
@@ -519,11 +529,16 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 		 * device and report anything else as error.
 		 */
 		if (is_device_private_entry(entry)) {
-			cpu_flags = HMM_PFN_VALID | HMM_PFN_DEVICE_PRIVATE;
+			cpu_flags = range->flags[HMM_PFN_VALID] |
+				range->flags[HMM_PFN_DEVICE_PRIVATE];
 			cpu_flags |= is_write_device_private_entry(entry) ?
-					HMM_PFN_WRITE : 0;
-			*pfn = hmm_pfn_from_pfn(swp_offset(entry));
-			*pfn |= HMM_PFN_DEVICE_PRIVATE;
+				range->flags[HMM_PFN_WRITE] : 0;
+			hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
+					   &fault, &write_fault);
+			if (fault || write_fault)
+				goto fault;
+			*pfn = hmm_pfn_from_pfn(range, swp_offset(entry));
+			*pfn |= cpu_flags;
 			return 0;
 		}
 
@@ -539,14 +554,14 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 		}
 
 		/* Report error for everything else */
-		*pfn = HMM_PFN_ERROR;
+		*pfn = range->values[HMM_PFN_ERROR];
 		return -EFAULT;
 	}
 
 	if (fault || write_fault)
 		goto fault;
 
-	*pfn = hmm_pfn_from_pfn(pte_pfn(pte)) | cpu_flags;
+	*pfn = hmm_pfn_from_pfn(range, pte_pfn(pte)) | cpu_flags;
 	return 0;
 
 fault:
@@ -615,12 +630,13 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
 	return 0;
 }
 
-static void hmm_pfns_clear(uint64_t *pfns,
+static void hmm_pfns_clear(struct hmm_range *range,
+			   uint64_t *pfns,
 			   unsigned long addr,
 			   unsigned long end)
 {
 	for (; addr < end; addr += PAGE_SIZE, pfns++)
-		*pfns = 0;
+		*pfns = range->values[HMM_PFN_NONE];
 }
 
 static void hmm_pfns_special(struct hmm_range *range)
@@ -628,7 +644,7 @@ static void hmm_pfns_special(struct hmm_range *range)
 	unsigned long addr = range->start, i = 0;
 
 	for (; addr < range->end; addr += PAGE_SIZE, i++)
-		range->pfns[i] = HMM_PFN_SPECIAL;
+		range->pfns[i] = range->values[HMM_PFN_SPECIAL];
 }
 
 /*
@@ -681,7 +697,7 @@ int hmm_vma_get_pfns(struct hmm_range *range)
 		 * write without read access are not supported by HMM, because
 		 * operations such as atomic access would not work.
 		 */
-		hmm_pfns_clear(range->pfns, range->start, range->end);
+		hmm_pfns_clear(range, range->pfns, range->start, range->end);
 		return -EPERM;
 	}
 
@@ -834,7 +850,7 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
 
 	hmm = hmm_register(vma->vm_mm);
 	if (!hmm) {
-		hmm_pfns_clear(range->pfns, range->start, range->end);
+		hmm_pfns_clear(range, range->pfns, range->start, range->end);
 		return -ENOMEM;
 	}
 	/* Caller must have registered a mirror using hmm_mirror_register() */
@@ -854,7 +870,7 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
 		 * write without read access are not supported by HMM, because
 		 * operations such as atomic access would not work.
 		 */
-		hmm_pfns_clear(range->pfns, range->start, range->end);
+		hmm_pfns_clear(range, range->pfns, range->start, range->end);
 		return -EPERM;
 	}
 
@@ -887,7 +903,8 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
 		unsigned long i;
 
 		i = (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
-		hmm_pfns_clear(&range->pfns[i], hmm_vma_walk.last, range->end);
+		hmm_pfns_clear(range, &range->pfns[i], hmm_vma_walk.last,
+			       range->end);
 		hmm_vma_range_done(range);
 	}
 	return ret;
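
On the driver side, consuming a returned entry is then a plain mask and
shift against its own encoding (again just a sketch, reusing the
hypothetical my_flags/my_values from above):

	uint64_t entry = range.pfns[i];

	if (entry == my_values[HMM_PFN_ERROR]) {
		/* this page could not be faulted in */
	} else if (entry & my_flags[HMM_PFN_VALID]) {
		unsigned long pfn = entry >> range.pfn_shift;
		bool writable = entry & my_flags[HMM_PFN_WRITE];
		/* program the device page table entry from pfn/writable */
	}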