mm/page_table_check: check entries at pmd levels

syzbot detected a case where the page table counters were not properly
updated.

  syzkaller login:  ------------[ cut here ]------------
  kernel BUG at mm/page_table_check.c:162!
  invalid opcode: 0000 [#1] PREEMPT SMP KASAN
  CPU: 0 PID: 3099 Comm: pasha Not tainted 5.16.0+ #48
  Hardware name: QEMU Standard PC (i440FX + PIIX, 1996), BIO4
  RIP: 0010:__page_table_check_zero+0x159/0x1a0
  Call Trace:
   free_pcp_prepare+0x3be/0xaa0
   free_unref_page+0x1c/0x650
   free_compound_page+0xec/0x130
   free_transhuge_page+0x1be/0x260
   __put_compound_page+0x90/0xd0
   release_pages+0x54c/0x1060
   __pagevec_release+0x7c/0x110
   shmem_undo_range+0x85e/0x1250
  ...

The repro involved having a huge page that is split due to uprobe event
temporarily replacing one of the pages in the huge page.  Later the huge
page was combined again, but the counters were off, as the PTE level was
not properly updated.

Make sure that when PMD is cleared and prior to freeing the level the
PTEs are updated.

Link: https://lkml.kernel.org/r/20220131203249.2832273-5-pasha.tatashin@soleen.com
Fixes: df4e817b7108 ("mm: page table check")
Signed-off-by: Pasha Tatashin <pasha.tatashin@soleen.com>
Acked-by: David Rientjes <rientjes@google.com>
Cc: Aneesh Kumar K.V <aneesh.kumar@linux.ibm.com>
Cc: Anshuman Khandual <anshuman.khandual@arm.com>
Cc: Dave Hansen <dave.hansen@linux.intel.com>
Cc: Greg Thelen <gthelen@google.com>
Cc: H. Peter Anvin <hpa@zytor.com>
Cc: Hugh Dickins <hughd@google.com>
Cc: Ingo Molnar <mingo@redhat.com>
Cc: Jiri Slaby <jirislaby@kernel.org>
Cc: Mike Rapoport <rppt@kernel.org>
Cc: Muchun Song <songmuchun@bytedance.com>
Cc: Paul Turner <pjt@google.com>
Cc: Wei Xu <weixugc@google.com>
Cc: Will Deacon <will@kernel.org>
Cc: Zi Yan <ziy@nvidia.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/include/linux/page_table_check.h b/include/linux/page_table_check.h
index 38cace1..01e16c7 100644
--- a/include/linux/page_table_check.h
+++ b/include/linux/page_table_check.h
@@ -26,6 +26,9 @@ void __page_table_check_pmd_set(struct mm_struct *mm, unsigned long addr,
 				pmd_t *pmdp, pmd_t pmd);
 void __page_table_check_pud_set(struct mm_struct *mm, unsigned long addr,
 				pud_t *pudp, pud_t pud);
+void __page_table_check_pte_clear_range(struct mm_struct *mm,
+					unsigned long addr,
+					pmd_t pmd);
 
 static inline void page_table_check_alloc(struct page *page, unsigned int order)
 {
@@ -100,6 +103,16 @@ static inline void page_table_check_pud_set(struct mm_struct *mm,
 	__page_table_check_pud_set(mm, addr, pudp, pud);
 }
 
+static inline void page_table_check_pte_clear_range(struct mm_struct *mm,
+						    unsigned long addr,
+						    pmd_t pmd)
+{
+	if (static_branch_likely(&page_table_check_disabled))
+		return;
+
+	__page_table_check_pte_clear_range(mm, addr, pmd);
+}
+
 #else
 
 static inline void page_table_check_alloc(struct page *page, unsigned int order)
@@ -143,5 +156,11 @@ static inline void page_table_check_pud_set(struct mm_struct *mm,
 {
 }
 
+static inline void page_table_check_pte_clear_range(struct mm_struct *mm,
+						    unsigned long addr,
+						    pmd_t pmd)
+{
+}
+
 #endif /* CONFIG_PAGE_TABLE_CHECK */
 #endif /* __LINUX_PAGE_TABLE_CHECK_H */
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index 30e59e4..131492f 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -16,6 +16,7 @@
 #include <linux/hashtable.h>
 #include <linux/userfaultfd_k.h>
 #include <linux/page_idle.h>
+#include <linux/page_table_check.h>
 #include <linux/swapops.h>
 #include <linux/shmem_fs.h>
 
@@ -1422,10 +1423,12 @@ static void collapse_and_free_pmd(struct mm_struct *mm, struct vm_area_struct *v
 	spinlock_t *ptl;
 	pmd_t pmd;
 
+	mmap_assert_write_locked(mm);
 	ptl = pmd_lock(vma->vm_mm, pmdp);
 	pmd = pmdp_collapse_flush(vma, addr, pmdp);
 	spin_unlock(ptl);
 	mm_dec_nr_ptes(mm);
+	page_table_check_pte_clear_range(mm, addr, pmd);
 	pte_free(mm, pmd_pgtable(pmd));
 }
 
diff --git a/mm/page_table_check.c b/mm/page_table_check.c
index c61d7eb..3763bd0 100644
--- a/mm/page_table_check.c
+++ b/mm/page_table_check.c
@@ -247,3 +247,23 @@ void __page_table_check_pud_set(struct mm_struct *mm, unsigned long addr,
 	}
 }
 EXPORT_SYMBOL(__page_table_check_pud_set);
+
+void __page_table_check_pte_clear_range(struct mm_struct *mm,
+					unsigned long addr,
+					pmd_t pmd)
+{
+	if (&init_mm == mm)
+		return;
+
+	if (!pmd_bad(pmd) && !pmd_leaf(pmd)) {
+		pte_t *ptep = pte_offset_map(&pmd, addr);
+		unsigned long i;
+
+		pte_unmap(ptep);
+		for (i = 0; i < PTRS_PER_PTE; i++) {
+			__page_table_check_pte_clear(mm, addr, *ptep);
+			addr += PAGE_SIZE;
+			ptep++;
+		}
+	}
+}