KVM: MMU: Update accessed and dirty bits after guest pagetable walk

While unspecified, the behaviour of Intel processors is to first
perform the page table walk, then, if the walk was successful, to
atomically update the accessed and dirty bits of walked paging elements.

While we are not required to follow this exactly, doing so will allow us
to perform the access permissions check after the walk is complete, rather
than after each walk step.

(the tricky case is SMEP: a zero in any pte's U bit makes the referenced
page a supervisor page, so we can't fault on a one bit during the walk
itself).

Reviewed-by: Xiao Guangrong <xiaoguangrong@linux.vnet.ibm.com>
Signed-off-by: Avi Kivity <avi@redhat.com>
diff --git a/arch/x86/kvm/paging_tmpl.h b/arch/x86/kvm/paging_tmpl.h
index 1cbf576..35a05dd 100644
--- a/arch/x86/kvm/paging_tmpl.h
+++ b/arch/x86/kvm/paging_tmpl.h
@@ -63,10 +63,12 @@
  */
 struct guest_walker {
 	int level;
+	unsigned max_level;
 	gfn_t table_gfn[PT_MAX_FULL_LEVELS];
 	pt_element_t ptes[PT_MAX_FULL_LEVELS];
 	pt_element_t prefetch_ptes[PTE_PREFETCH_NUM];
 	gpa_t pte_gpa[PT_MAX_FULL_LEVELS];
+	pt_element_t __user *ptep_user[PT_MAX_FULL_LEVELS];
 	unsigned pt_access;
 	unsigned pte_access;
 	gfn_t gfn;
@@ -119,6 +121,43 @@
 	return false;
 }
 
+static int FNAME(update_accessed_dirty_bits)(struct kvm_vcpu *vcpu,
+					     struct kvm_mmu *mmu,
+					     struct guest_walker *walker,
+					     int write_fault)
+{
+	unsigned level, index;
+	pt_element_t pte, orig_pte;
+	pt_element_t __user *ptep_user;
+	gfn_t table_gfn;
+	int ret;
+
+	for (level = walker->max_level; level >= walker->level; --level) {
+		pte = orig_pte = walker->ptes[level - 1];
+		table_gfn = walker->table_gfn[level - 1];
+		ptep_user = walker->ptep_user[level - 1];
+		index = offset_in_page(ptep_user) / sizeof(pt_element_t);
+		if (!(pte & PT_ACCESSED_MASK)) {
+			trace_kvm_mmu_set_accessed_bit(table_gfn, index, sizeof(pte));
+			pte |= PT_ACCESSED_MASK;
+		}
+		if (level == walker->level && write_fault && !is_dirty_gpte(pte)) {
+			trace_kvm_mmu_set_dirty_bit(table_gfn, index, sizeof(pte));
+			pte |= PT_DIRTY_MASK;
+		}
+		if (pte == orig_pte)	/* no bit to set at this level */
+			continue;
+
+		ret = FNAME(cmpxchg_gpte)(vcpu, mmu, ptep_user, index, orig_pte, pte);
+		if (ret)	/* < 0: caller fails the walk; > 0: caller retries */
+			return ret;
+
+		mark_page_dirty(vcpu->kvm, table_gfn);
+		walker->ptes[level - 1] = pte;	/* ptes[] is indexed by level - 1 */
+	}
+	return 0;
+}
+
 /*
  * Fetch a guest pte for a guest virtual address
  */
@@ -126,6 +165,7 @@
 				    struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
 				    gva_t addr, u32 access)
 {
+	int ret;
 	pt_element_t pte;
 	pt_element_t __user *uninitialized_var(ptep_user);
 	gfn_t table_gfn;
@@ -153,6 +193,7 @@
 		--walker->level;
 	}
 #endif
+	walker->max_level = walker->level;
 	ASSERT((!is_long_mode(vcpu) && is_pae(vcpu)) ||
 	       (mmu->get_cr3(vcpu) & CR3_NONPAE_RESERVED_BITS) == 0);
 
@@ -183,6 +224,7 @@
 		ptep_user = (pt_element_t __user *)((void *)host_addr + offset);
 		if (unlikely(__copy_from_user(&pte, ptep_user, sizeof(pte))))
 			goto error;
+		walker->ptep_user[walker->level - 1] = ptep_user;
 
 		trace_kvm_mmu_paging_element(pte, walker->level);
 
@@ -214,21 +256,6 @@
 					eperm = true;
 		}
 
-		if (!eperm && unlikely(!(pte & PT_ACCESSED_MASK))) {
-			int ret;
-			trace_kvm_mmu_set_accessed_bit(table_gfn, index,
-						       sizeof(pte));
-			ret = FNAME(cmpxchg_gpte)(vcpu, mmu, ptep_user, index,
-						  pte, pte|PT_ACCESSED_MASK);
-			if (unlikely(ret < 0))
-				goto error;
-			else if (ret)
-				goto retry_walk;
-
-			mark_page_dirty(vcpu->kvm, table_gfn);
-			pte |= PT_ACCESSED_MASK;
-		}
-
 		walker->ptes[walker->level - 1] = pte;
 
 		if (last_gpte) {
@@ -268,21 +295,12 @@
 
 	if (!write_fault)
 		protect_clean_gpte(&pte_access, pte);
-	else if (unlikely(!is_dirty_gpte(pte))) {
-		int ret;
 
-		trace_kvm_mmu_set_dirty_bit(table_gfn, index, sizeof(pte));
-		ret = FNAME(cmpxchg_gpte)(vcpu, mmu, ptep_user, index,
-					  pte, pte|PT_DIRTY_MASK);
-		if (unlikely(ret < 0))
-			goto error;
-		else if (ret)
-			goto retry_walk;
-
-		mark_page_dirty(vcpu->kvm, table_gfn);
-		pte |= PT_DIRTY_MASK;
-		walker->ptes[walker->level - 1] = pte;
-	}
+	ret = FNAME(update_accessed_dirty_bits)(vcpu, mmu, walker, write_fault);
+	if (unlikely(ret < 0))
+		goto error;
+	else if (ret)
+		goto retry_walk;
 
 	walker->pt_access = pt_access;
 	walker->pte_access = pte_access;