kvm: x86: Support selectively freeing either current or previous MMU root

kvm_mmu_free_roots() now takes a mask specifying which roots to free, so
that either one of the roots (active/previous) can be individually freed
when needed.

Signed-off-by: Junaid Shahid <junaids@google.com>
Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 0b77c23..262b0bc 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1287,6 +1287,10 @@ static inline int __kvm_irq_line_state(unsigned long *irq_state,
 	return !!(*irq_state);
 }
 
+#define KVM_MMU_ROOT_CURRENT	BIT(0)
+#define KVM_MMU_ROOT_PREVIOUS	BIT(1)
+#define KVM_MMU_ROOTS_ALL	(~0UL)
+
 int kvm_pic_set_irq(struct kvm_pic *pic, int irq, int irq_source_id, int level);
 void kvm_pic_clear_all(struct kvm_pic *pic, int irq_source_id);
 
@@ -1298,7 +1302,7 @@ void __kvm_mmu_free_some_pages(struct kvm_vcpu *vcpu);
 int kvm_mmu_load(struct kvm_vcpu *vcpu);
 void kvm_mmu_unload(struct kvm_vcpu *vcpu);
 void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu);
-void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, bool free_prev_root);
+void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, ulong roots_to_free);
 gpa_t translate_nested_gpa(struct kvm_vcpu *vcpu, gpa_t gpa, u32 access,
 			   struct x86_exception *exception);
 gpa_t kvm_mmu_gva_to_gpa_read(struct kvm_vcpu *vcpu, gva_t gva,
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 6eeca91..0f6965c 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -3438,14 +3438,18 @@ static void mmu_free_root_page(struct kvm *kvm, hpa_t *root_hpa,
 	*root_hpa = INVALID_PAGE;
 }
 
-void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, bool free_prev_root)
+/* roots_to_free must be some combination of the KVM_MMU_ROOT_* flags */
+void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, ulong roots_to_free)
 {
 	int i;
 	LIST_HEAD(invalid_list);
 	struct kvm_mmu *mmu = &vcpu->arch.mmu;
+	bool free_active_root = roots_to_free & KVM_MMU_ROOT_CURRENT;
+	bool free_prev_root = roots_to_free & KVM_MMU_ROOT_PREVIOUS;
 
-	if (!VALID_PAGE(mmu->root_hpa) &&
-	    (!VALID_PAGE(mmu->prev_root.hpa) || !free_prev_root))
+	/* Before acquiring the MMU lock, see if we need to do any real work. */
+	if (!(free_active_root && VALID_PAGE(mmu->root_hpa)) &&
+	    !(free_prev_root && VALID_PAGE(mmu->prev_root.hpa)))
 		return;
 
 	spin_lock(&vcpu->kvm->mmu_lock);
@@ -3454,15 +3458,19 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, bool free_prev_root)
 		mmu_free_root_page(vcpu->kvm, &mmu->prev_root.hpa,
 				   &invalid_list);
 
-	if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
-	    (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
-		mmu_free_root_page(vcpu->kvm, &mmu->root_hpa, &invalid_list);
-	} else {
-		for (i = 0; i < 4; ++i)
-			if (mmu->pae_root[i] != 0)
-				mmu_free_root_page(vcpu->kvm, &mmu->pae_root[i],
-						   &invalid_list);
-		mmu->root_hpa = INVALID_PAGE;
+	if (free_active_root) {
+		if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
+		    (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
+			mmu_free_root_page(vcpu->kvm, &mmu->root_hpa,
+					   &invalid_list);
+		} else {
+			for (i = 0; i < 4; ++i)
+				if (mmu->pae_root[i] != 0)
+					mmu_free_root_page(vcpu->kvm,
+							   &mmu->pae_root[i],
+							   &invalid_list);
+			mmu->root_hpa = INVALID_PAGE;
+		}
 	}
 
 	kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
@@ -4109,7 +4117,7 @@ static void __kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3,
 			      bool skip_tlb_flush)
 {
 	if (!fast_cr3_switch(vcpu, new_cr3, new_role, skip_tlb_flush))
-		kvm_mmu_free_roots(vcpu, false);
+		kvm_mmu_free_roots(vcpu, KVM_MMU_ROOT_CURRENT);
 }
 
 void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3, bool skip_tlb_flush)
@@ -4885,7 +4893,7 @@ EXPORT_SYMBOL_GPL(kvm_mmu_load);
 
 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
-	kvm_mmu_free_roots(vcpu, true);
+	kvm_mmu_free_roots(vcpu, KVM_MMU_ROOTS_ALL);
 	WARN_ON(VALID_PAGE(vcpu->arch.mmu.root_hpa));
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_unload);