x86/kvm/mmu: make vcpu->arch.mmu a pointer to the current MMU

In preparation for a full MMU split between L1 and L2, make vcpu->arch.mmu
a pointer to the currently used MMU. The MMU that was previously embedded
as vcpu->arch.mmu is now vcpu->arch.root_mmu, and for now vcpu->arch.mmu
always points to it. No functional change.

Signed-off-by: Vitaly Kuznetsov <vkuznets@redhat.com>
Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
Reviewed-by: Sean Christopherson <sean.j.christopherson@intel.com>
diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index c255686..d243eba 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -5111,9 +5111,10 @@ static inline void __vmx_flush_tlb(struct kvm_vcpu *vcpu, int vpid,
 				bool invalidate_gpa)
 {
 	if (enable_ept && (invalidate_gpa || !enable_vpid)) {
-		if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
+		if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
 			return;
-		ept_sync_context(construct_eptp(vcpu, vcpu->arch.mmu.root_hpa));
+		ept_sync_context(construct_eptp(vcpu,
+						vcpu->arch.mmu->root_hpa));
 	} else {
 		vpid_sync_context(vpid);
 	}
@@ -9122,7 +9123,7 @@ static int handle_invpcid(struct kvm_vcpu *vcpu)
 		}
 
 		for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
-			if (kvm_get_pcid(vcpu, vcpu->arch.mmu.prev_roots[i].cr3)
+			if (kvm_get_pcid(vcpu, vcpu->arch.mmu->prev_roots[i].cr3)
 			    == operand.pcid)
 				roots_to_free |= KVM_MMU_ROOT_PREVIOUS(i);
 
@@ -11304,16 +11305,16 @@ static void nested_ept_init_mmu_context(struct kvm_vcpu *vcpu)
 			VMX_EPT_EXECUTE_ONLY_BIT,
 			nested_ept_ad_enabled(vcpu),
 			nested_ept_get_cr3(vcpu));
-	vcpu->arch.mmu.set_cr3           = vmx_set_cr3;
-	vcpu->arch.mmu.get_cr3           = nested_ept_get_cr3;
-	vcpu->arch.mmu.inject_page_fault = nested_ept_inject_page_fault;
+	vcpu->arch.mmu->set_cr3           = vmx_set_cr3;
+	vcpu->arch.mmu->get_cr3           = nested_ept_get_cr3;
+	vcpu->arch.mmu->inject_page_fault = nested_ept_inject_page_fault;
 
 	vcpu->arch.walk_mmu              = &vcpu->arch.nested_mmu;
 }
 
 static void nested_ept_uninit_mmu_context(struct kvm_vcpu *vcpu)
 {
-	vcpu->arch.walk_mmu = &vcpu->arch.mmu;
+	vcpu->arch.walk_mmu = &vcpu->arch.root_mmu;
 }
 
 static bool nested_vmx_is_page_fault_vmexit(struct vmcs12 *vmcs12,