powerpc/pkeys: Avoid using lockless page table walk

Fetch pkey from vma instead of linux page table. Also document the fact that in
some cases the pkey returned in siginfo won't be the same as the one we took
keyfault on. Even with linux page table walk, we can end up in a similar scenario.

Signed-off-by: Aneesh Kumar K.V <aneesh.kumar@linux.ibm.com>
Signed-off-by: Michael Ellerman <mpe@ellerman.id.au>
Link: https://lore.kernel.org/r/20200505071729.54912-2-aneesh.kumar@linux.ibm.com
diff --git a/arch/powerpc/include/asm/mmu.h b/arch/powerpc/include/asm/mmu.h
index 0699cfe..cf2a08b 100644
--- a/arch/powerpc/include/asm/mmu.h
+++ b/arch/powerpc/include/asm/mmu.h
@@ -291,15 +291,6 @@ static inline bool early_radix_enabled(void)
 }
 #endif
 
-#ifdef CONFIG_PPC_MEM_KEYS
-extern u16 get_mm_addr_key(struct mm_struct *mm, unsigned long address);
-#else
-static inline u16 get_mm_addr_key(struct mm_struct *mm, unsigned long address)
-{
-	return 0;
-}
-#endif /* CONFIG_PPC_MEM_KEYS */
-
 #ifdef CONFIG_STRICT_KERNEL_RWX
 static inline bool strict_kernel_rwx_enabled(void)
 {
diff --git a/arch/powerpc/mm/book3s64/hash_utils.c b/arch/powerpc/mm/book3s64/hash_utils.c
index 8ed2411..e951e87 100644
--- a/arch/powerpc/mm/book3s64/hash_utils.c
+++ b/arch/powerpc/mm/book3s64/hash_utils.c
@@ -1671,30 +1671,6 @@ void update_mmu_cache(struct vm_area_struct *vma, unsigned long address,
 	hash_preload(vma->vm_mm, address, is_exec, trap);
 }
 
-#ifdef CONFIG_PPC_MEM_KEYS
-/*
- * Return the protection key associated with the given address and the
- * mm_struct.
- */
-u16 get_mm_addr_key(struct mm_struct *mm, unsigned long address)
-{
-	pte_t *ptep;
-	u16 pkey = 0;
-	unsigned long flags;
-
-	if (!mm || !mm->pgd)
-		return 0;
-
-	local_irq_save(flags);
-	ptep = find_linux_pte(mm->pgd, address, NULL, NULL);
-	if (ptep)
-		pkey = pte_to_pkey_bits(pte_val(READ_ONCE(*ptep)));
-	local_irq_restore(flags);
-
-	return pkey;
-}
-#endif /* CONFIG_PPC_MEM_KEYS */
-
 #ifdef CONFIG_PPC_TRANSACTIONAL_MEM
 static inline void tm_flush_hash_page(int local)
 {
diff --git a/arch/powerpc/mm/fault.c b/arch/powerpc/mm/fault.c
index 84af6c8..8e529e4 100644
--- a/arch/powerpc/mm/fault.c
+++ b/arch/powerpc/mm/fault.c
@@ -118,9 +118,34 @@ static noinline int bad_area(struct pt_regs *regs, unsigned long address)
 	return __bad_area(regs, address, SEGV_MAPERR);
 }
 
-static int bad_key_fault_exception(struct pt_regs *regs, unsigned long address,
-				    int pkey)
+#ifdef CONFIG_PPC_MEM_KEYS
+static noinline int bad_access_pkey(struct pt_regs *regs, unsigned long address,
+				    struct vm_area_struct *vma)
 {
+	struct mm_struct *mm = current->mm;
+	int pkey;
+
+	/*
+	 * We don't try to fetch the pkey from page table because reading
+	 * page table without locking doesn't guarantee stable pte value.
+	 * Hence the pkey value that we return to userspace can be different
+	 * from the pkey that actually caused access error.
+	 *
+	 * It does *not* guarantee that the VMA we find here
+	 * was the one that we faulted on.
+	 *
+	 * 1. T1   : mprotect_key(foo, PAGE_SIZE, pkey=4);
+	 * 2. T1   : set AMR to deny access to pkey=4, touches, page
+	 * 3. T1   : faults...
+	 * 4.    T2: mprotect_key(foo, PAGE_SIZE, pkey=5);
+	 * 5. T1   : enters fault handler, takes mmap_sem, etc...
+	 * 6. T1   : reaches here, sees vma_pkey(vma)=5, when we really
+	 *	     faulted on a pte with its pkey=4.
+	 */
+	pkey = vma_pkey(vma);
+
+	up_read(&mm->mmap_sem);
+
 	/*
 	 * If we are in kernel mode, bail out with a SEGV, this will
 	 * be caught by the assembly which will restore the non-volatile
@@ -133,6 +158,7 @@ static int bad_key_fault_exception(struct pt_regs *regs, unsigned long address,
 
 	return 0;
 }
+#endif
 
 static noinline int bad_access(struct pt_regs *regs, unsigned long address)
 {
@@ -289,8 +315,31 @@ static bool bad_stack_expansion(struct pt_regs *regs, unsigned long address,
 	return false;
 }
 
-static bool access_error(bool is_write, bool is_exec,
-			 struct vm_area_struct *vma)
+#ifdef CONFIG_PPC_MEM_KEYS
+static bool access_pkey_error(bool is_write, bool is_exec, bool is_pkey,
+			      struct vm_area_struct *vma)
+{
+	/*
+	 * Read or write was blocked by protection keys.  This is
+	 * always an unconditional error and can never result in
+	 * a follow-up action to resolve the fault, like a COW.
+	 */
+	if (is_pkey)
+		return true;
+
+	/*
+	 * Make sure to check the VMA so that we do not perform
+	 * faults just to hit a pkey fault as soon as we fill in a
+	 * page. Only called for current mm, hence foreign == 0
+	 */
+	if (!arch_vma_access_permitted(vma, is_write, is_exec, 0))
+		return true;
+
+	return false;
+}
+#endif
+
+static bool access_error(bool is_write, bool is_exec, struct vm_area_struct *vma)
 {
 	/*
 	 * Allow execution from readable areas if the MMU does not
@@ -483,10 +532,6 @@ static int __do_page_fault(struct pt_regs *regs, unsigned long address,
 
 	perf_sw_event(PERF_COUNT_SW_PAGE_FAULTS, 1, regs, address);
 
-	if (error_code & DSISR_KEYFAULT)
-		return bad_key_fault_exception(regs, address,
-					       get_mm_addr_key(mm, address));
-
 	/*
 	 * We want to do this outside mmap_sem, because reading code around nip
 	 * can result in fault, which will cause a deadlock when called with
@@ -555,6 +600,13 @@ static int __do_page_fault(struct pt_regs *regs, unsigned long address,
 		return bad_area(regs, address);
 
 good_area:
+
+#ifdef CONFIG_PPC_MEM_KEYS
+	if (unlikely(access_pkey_error(is_write, is_exec,
+				       (error_code & DSISR_KEYFAULT), vma)))
+		return bad_access_pkey(regs, address, vma);
+#endif /* CONFIG_PPC_MEM_KEYS */
+
 	if (unlikely(access_error(is_write, is_exec, vma)))
 		return bad_access(regs, address);
 
@@ -565,21 +617,6 @@ static int __do_page_fault(struct pt_regs *regs, unsigned long address,
 	 */
 	fault = handle_mm_fault(vma, address, flags);
 
-#ifdef CONFIG_PPC_MEM_KEYS
-	/*
-	 * we skipped checking for access error due to key earlier.
-	 * Check that using handle_mm_fault error return.
-	 */
-	if (unlikely(fault & VM_FAULT_SIGSEGV) &&
-		!arch_vma_access_permitted(vma, is_write, is_exec, 0)) {
-
-		int pkey = vma_pkey(vma);
-
-		up_read(&mm->mmap_sem);
-		return bad_key_fault_exception(regs, address, pkey);
-	}
-#endif /* CONFIG_PPC_MEM_KEYS */
-
 	major |= fault & VM_FAULT_MAJOR;
 
 	if (fault_signal_pending(fault, regs))