KVM: x86: Load SMRAM in a single shot when leaving SMM

RSM emulation is currently broken on VMX when the interrupted guest has
CR4.VMXE=1.  Rather than dance around the issue of HF_SMM_MASK being set
when loading SMSTATE into architectural state, ideally RSM emulation
itself would be reworked to clear HF_SMM_MASK prior to loading non-SMM
architectural state.

Ostensibly, the only motivation for having HF_SMM_MASK set throughout
the loading of state from the SMRAM save state area is so that the
memory accesses from GET_SMSTATE() are tagged with role.smm.  Load
all of the SMRAM save state area from guest memory at the beginning of
RSM emulation, and load state from the buffer instead of reading guest
memory one field at a time.

This paves the way for clearing HF_SMM_MASK prior to loading state,
and also aligns RSM with the enter_smm() behavior, which fills a
buffer and writes SMRAM save state in a single go.

Signed-off-by: Sean Christopherson <sean.j.christopherson@intel.com>
Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
diff --git a/arch/x86/kvm/emulate.c b/arch/x86/kvm/emulate.c
index c338984..ae0d289 100644
--- a/arch/x86/kvm/emulate.c
+++ b/arch/x86/kvm/emulate.c
@@ -2339,16 +2339,6 @@ static int emulator_has_longmode(struct x86_emulate_ctxt *ctxt)
 	return edx & bit(X86_FEATURE_LM);
 }
 
-#define GET_SMSTATE(type, smbase, offset)				  \
-	({								  \
-	 type __val;							  \
-	 int r = ctxt->ops->read_phys(ctxt, smbase + offset, &__val,      \
-				      sizeof(__val));			  \
-	 if (r != X86EMUL_CONTINUE)					  \
-		 return X86EMUL_UNHANDLEABLE;				  \
-	 __val;								  \
-	})
-
 static void rsm_set_desc_flags(struct desc_struct *desc, u32 flags)
 {
 	desc->g    = (flags >> 23) & 1;
@@ -2361,27 +2351,29 @@ static void rsm_set_desc_flags(struct desc_struct *desc, u32 flags)
 	desc->type = (flags >>  8) & 15;
 }
 
-static int rsm_load_seg_32(struct x86_emulate_ctxt *ctxt, u64 smbase, int n)
+static int rsm_load_seg_32(struct x86_emulate_ctxt *ctxt, const char *smstate,
+			   int n)
 {
 	struct desc_struct desc;
 	int offset;
 	u16 selector;
 
-	selector = GET_SMSTATE(u32, smbase, 0x7fa8 + n * 4);
+	selector = GET_SMSTATE(u32, smstate, 0x7fa8 + n * 4);
 
 	if (n < 3)
 		offset = 0x7f84 + n * 12;
 	else
 		offset = 0x7f2c + (n - 3) * 12;
 
-	set_desc_base(&desc,      GET_SMSTATE(u32, smbase, offset + 8));
-	set_desc_limit(&desc,     GET_SMSTATE(u32, smbase, offset + 4));
-	rsm_set_desc_flags(&desc, GET_SMSTATE(u32, smbase, offset));
+	set_desc_base(&desc,      GET_SMSTATE(u32, smstate, offset + 8));
+	set_desc_limit(&desc,     GET_SMSTATE(u32, smstate, offset + 4));
+	rsm_set_desc_flags(&desc, GET_SMSTATE(u32, smstate, offset));
 	ctxt->ops->set_segment(ctxt, selector, &desc, 0, n);
 	return X86EMUL_CONTINUE;
 }
 
-static int rsm_load_seg_64(struct x86_emulate_ctxt *ctxt, u64 smbase, int n)
+static int rsm_load_seg_64(struct x86_emulate_ctxt *ctxt, const char *smstate,
+			   int n)
 {
 	struct desc_struct desc;
 	int offset;
@@ -2390,11 +2382,11 @@ static int rsm_load_seg_64(struct x86_emulate_ctxt *ctxt, u64 smbase, int n)
 
 	offset = 0x7e00 + n * 16;
 
-	selector =                GET_SMSTATE(u16, smbase, offset);
-	rsm_set_desc_flags(&desc, GET_SMSTATE(u16, smbase, offset + 2) << 8);
-	set_desc_limit(&desc,     GET_SMSTATE(u32, smbase, offset + 4));
-	set_desc_base(&desc,      GET_SMSTATE(u32, smbase, offset + 8));
-	base3 =                   GET_SMSTATE(u32, smbase, offset + 12);
+	selector =                GET_SMSTATE(u16, smstate, offset);
+	rsm_set_desc_flags(&desc, GET_SMSTATE(u16, smstate, offset + 2) << 8);
+	set_desc_limit(&desc,     GET_SMSTATE(u32, smstate, offset + 4));
+	set_desc_base(&desc,      GET_SMSTATE(u32, smstate, offset + 8));
+	base3 =                   GET_SMSTATE(u32, smstate, offset + 12);
 
 	ctxt->ops->set_segment(ctxt, selector, &desc, base3, n);
 	return X86EMUL_CONTINUE;
@@ -2445,7 +2437,8 @@ static int rsm_enter_protected_mode(struct x86_emulate_ctxt *ctxt,
 	return X86EMUL_CONTINUE;
 }
 
-static int rsm_load_state_32(struct x86_emulate_ctxt *ctxt, u64 smbase)
+static int rsm_load_state_32(struct x86_emulate_ctxt *ctxt,
+			     const char *smstate)
 {
 	struct desc_struct desc;
 	struct desc_ptr dt;
@@ -2453,53 +2446,54 @@ static int rsm_load_state_32(struct x86_emulate_ctxt *ctxt, u64 smbase)
 	u32 val, cr0, cr3, cr4;
 	int i;
 
-	cr0 =                      GET_SMSTATE(u32, smbase, 0x7ffc);
-	cr3 =                      GET_SMSTATE(u32, smbase, 0x7ff8);
-	ctxt->eflags =             GET_SMSTATE(u32, smbase, 0x7ff4) | X86_EFLAGS_FIXED;
-	ctxt->_eip =               GET_SMSTATE(u32, smbase, 0x7ff0);
+	cr0 =                      GET_SMSTATE(u32, smstate, 0x7ffc);
+	cr3 =                      GET_SMSTATE(u32, smstate, 0x7ff8);
+	ctxt->eflags =             GET_SMSTATE(u32, smstate, 0x7ff4) | X86_EFLAGS_FIXED;
+	ctxt->_eip =               GET_SMSTATE(u32, smstate, 0x7ff0);
 
 	for (i = 0; i < 8; i++)
-		*reg_write(ctxt, i) = GET_SMSTATE(u32, smbase, 0x7fd0 + i * 4);
+		*reg_write(ctxt, i) = GET_SMSTATE(u32, smstate, 0x7fd0 + i * 4);
 
-	val = GET_SMSTATE(u32, smbase, 0x7fcc);
+	val = GET_SMSTATE(u32, smstate, 0x7fcc);
 	ctxt->ops->set_dr(ctxt, 6, (val & DR6_VOLATILE) | DR6_FIXED_1);
-	val = GET_SMSTATE(u32, smbase, 0x7fc8);
+	val = GET_SMSTATE(u32, smstate, 0x7fc8);
 	ctxt->ops->set_dr(ctxt, 7, (val & DR7_VOLATILE) | DR7_FIXED_1);
 
-	selector =                 GET_SMSTATE(u32, smbase, 0x7fc4);
-	set_desc_base(&desc,       GET_SMSTATE(u32, smbase, 0x7f64));
-	set_desc_limit(&desc,      GET_SMSTATE(u32, smbase, 0x7f60));
-	rsm_set_desc_flags(&desc,  GET_SMSTATE(u32, smbase, 0x7f5c));
+	selector =                 GET_SMSTATE(u32, smstate, 0x7fc4);
+	set_desc_base(&desc,       GET_SMSTATE(u32, smstate, 0x7f64));
+	set_desc_limit(&desc,      GET_SMSTATE(u32, smstate, 0x7f60));
+	rsm_set_desc_flags(&desc,  GET_SMSTATE(u32, smstate, 0x7f5c));
 	ctxt->ops->set_segment(ctxt, selector, &desc, 0, VCPU_SREG_TR);
 
-	selector =                 GET_SMSTATE(u32, smbase, 0x7fc0);
-	set_desc_base(&desc,       GET_SMSTATE(u32, smbase, 0x7f80));
-	set_desc_limit(&desc,      GET_SMSTATE(u32, smbase, 0x7f7c));
-	rsm_set_desc_flags(&desc,  GET_SMSTATE(u32, smbase, 0x7f78));
+	selector =                 GET_SMSTATE(u32, smstate, 0x7fc0);
+	set_desc_base(&desc,       GET_SMSTATE(u32, smstate, 0x7f80));
+	set_desc_limit(&desc,      GET_SMSTATE(u32, smstate, 0x7f7c));
+	rsm_set_desc_flags(&desc,  GET_SMSTATE(u32, smstate, 0x7f78));
 	ctxt->ops->set_segment(ctxt, selector, &desc, 0, VCPU_SREG_LDTR);
 
-	dt.address =               GET_SMSTATE(u32, smbase, 0x7f74);
-	dt.size =                  GET_SMSTATE(u32, smbase, 0x7f70);
+	dt.address =               GET_SMSTATE(u32, smstate, 0x7f74);
+	dt.size =                  GET_SMSTATE(u32, smstate, 0x7f70);
 	ctxt->ops->set_gdt(ctxt, &dt);
 
-	dt.address =               GET_SMSTATE(u32, smbase, 0x7f58);
-	dt.size =                  GET_SMSTATE(u32, smbase, 0x7f54);
+	dt.address =               GET_SMSTATE(u32, smstate, 0x7f58);
+	dt.size =                  GET_SMSTATE(u32, smstate, 0x7f54);
 	ctxt->ops->set_idt(ctxt, &dt);
 
 	for (i = 0; i < 6; i++) {
-		int r = rsm_load_seg_32(ctxt, smbase, i);
+		int r = rsm_load_seg_32(ctxt, smstate, i);
 		if (r != X86EMUL_CONTINUE)
 			return r;
 	}
 
-	cr4 = GET_SMSTATE(u32, smbase, 0x7f14);
+	cr4 = GET_SMSTATE(u32, smstate, 0x7f14);
 
-	ctxt->ops->set_smbase(ctxt, GET_SMSTATE(u32, smbase, 0x7ef8));
+	ctxt->ops->set_smbase(ctxt, GET_SMSTATE(u32, smstate, 0x7ef8));
 
 	return rsm_enter_protected_mode(ctxt, cr0, cr3, cr4);
 }
 
-static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt, u64 smbase)
+static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt,
+			     const char *smstate)
 {
 	struct desc_struct desc;
 	struct desc_ptr dt;
@@ -2509,43 +2503,43 @@ static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt, u64 smbase)
 	int i, r;
 
 	for (i = 0; i < 16; i++)
-		*reg_write(ctxt, i) = GET_SMSTATE(u64, smbase, 0x7ff8 - i * 8);
+		*reg_write(ctxt, i) = GET_SMSTATE(u64, smstate, 0x7ff8 - i * 8);
 
-	ctxt->_eip   = GET_SMSTATE(u64, smbase, 0x7f78);
-	ctxt->eflags = GET_SMSTATE(u32, smbase, 0x7f70) | X86_EFLAGS_FIXED;
+	ctxt->_eip   = GET_SMSTATE(u64, smstate, 0x7f78);
+	ctxt->eflags = GET_SMSTATE(u32, smstate, 0x7f70) | X86_EFLAGS_FIXED;
 
-	val = GET_SMSTATE(u32, smbase, 0x7f68);
+	val = GET_SMSTATE(u32, smstate, 0x7f68);
 	ctxt->ops->set_dr(ctxt, 6, (val & DR6_VOLATILE) | DR6_FIXED_1);
-	val = GET_SMSTATE(u32, smbase, 0x7f60);
+	val = GET_SMSTATE(u32, smstate, 0x7f60);
 	ctxt->ops->set_dr(ctxt, 7, (val & DR7_VOLATILE) | DR7_FIXED_1);
 
-	cr0 =                       GET_SMSTATE(u64, smbase, 0x7f58);
-	cr3 =                       GET_SMSTATE(u64, smbase, 0x7f50);
-	cr4 =                       GET_SMSTATE(u64, smbase, 0x7f48);
-	ctxt->ops->set_smbase(ctxt, GET_SMSTATE(u32, smbase, 0x7f00));
-	val =                       GET_SMSTATE(u64, smbase, 0x7ed0);
+	cr0 =                       GET_SMSTATE(u64, smstate, 0x7f58);
+	cr3 =                       GET_SMSTATE(u64, smstate, 0x7f50);
+	cr4 =                       GET_SMSTATE(u64, smstate, 0x7f48);
+	ctxt->ops->set_smbase(ctxt, GET_SMSTATE(u32, smstate, 0x7f00));
+	val =                       GET_SMSTATE(u64, smstate, 0x7ed0);
 	ctxt->ops->set_msr(ctxt, MSR_EFER, val & ~EFER_LMA);
 
-	selector =                  GET_SMSTATE(u32, smbase, 0x7e90);
-	rsm_set_desc_flags(&desc,   GET_SMSTATE(u32, smbase, 0x7e92) << 8);
-	set_desc_limit(&desc,       GET_SMSTATE(u32, smbase, 0x7e94));
-	set_desc_base(&desc,        GET_SMSTATE(u32, smbase, 0x7e98));
-	base3 =                     GET_SMSTATE(u32, smbase, 0x7e9c);
+	selector =                  GET_SMSTATE(u32, smstate, 0x7e90);
+	rsm_set_desc_flags(&desc,   GET_SMSTATE(u32, smstate, 0x7e92) << 8);
+	set_desc_limit(&desc,       GET_SMSTATE(u32, smstate, 0x7e94));
+	set_desc_base(&desc,        GET_SMSTATE(u32, smstate, 0x7e98));
+	base3 =                     GET_SMSTATE(u32, smstate, 0x7e9c);
 	ctxt->ops->set_segment(ctxt, selector, &desc, base3, VCPU_SREG_TR);
 
-	dt.size =                   GET_SMSTATE(u32, smbase, 0x7e84);
-	dt.address =                GET_SMSTATE(u64, smbase, 0x7e88);
+	dt.size =                   GET_SMSTATE(u32, smstate, 0x7e84);
+	dt.address =                GET_SMSTATE(u64, smstate, 0x7e88);
 	ctxt->ops->set_idt(ctxt, &dt);
 
-	selector =                  GET_SMSTATE(u32, smbase, 0x7e70);
-	rsm_set_desc_flags(&desc,   GET_SMSTATE(u32, smbase, 0x7e72) << 8);
-	set_desc_limit(&desc,       GET_SMSTATE(u32, smbase, 0x7e74));
-	set_desc_base(&desc,        GET_SMSTATE(u32, smbase, 0x7e78));
-	base3 =                     GET_SMSTATE(u32, smbase, 0x7e7c);
+	selector =                  GET_SMSTATE(u32, smstate, 0x7e70);
+	rsm_set_desc_flags(&desc,   GET_SMSTATE(u32, smstate, 0x7e72) << 8);
+	set_desc_limit(&desc,       GET_SMSTATE(u32, smstate, 0x7e74));
+	set_desc_base(&desc,        GET_SMSTATE(u32, smstate, 0x7e78));
+	base3 =                     GET_SMSTATE(u32, smstate, 0x7e7c);
 	ctxt->ops->set_segment(ctxt, selector, &desc, base3, VCPU_SREG_LDTR);
 
-	dt.size =                   GET_SMSTATE(u32, smbase, 0x7e64);
-	dt.address =                GET_SMSTATE(u64, smbase, 0x7e68);
+	dt.size =                   GET_SMSTATE(u32, smstate, 0x7e64);
+	dt.address =                GET_SMSTATE(u64, smstate, 0x7e68);
 	ctxt->ops->set_gdt(ctxt, &dt);
 
 	r = rsm_enter_protected_mode(ctxt, cr0, cr3, cr4);
@@ -2553,7 +2547,7 @@ static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt, u64 smbase)
 		return r;
 
 	for (i = 0; i < 6; i++) {
-		r = rsm_load_seg_64(ctxt, smbase, i);
+		r = rsm_load_seg_64(ctxt, smstate, i);
 		if (r != X86EMUL_CONTINUE)
 			return r;
 	}
@@ -2564,12 +2558,19 @@ static int rsm_load_state_64(struct x86_emulate_ctxt *ctxt, u64 smbase)
 static int em_rsm(struct x86_emulate_ctxt *ctxt)
 {
 	unsigned long cr0, cr4, efer;
+	char buf[512];
 	u64 smbase;
 	int ret;
 
 	if ((ctxt->ops->get_hflags(ctxt) & X86EMUL_SMM_MASK) == 0)
 		return emulate_ud(ctxt);
 
+	smbase = ctxt->ops->get_smbase(ctxt);
+
+	ret = ctxt->ops->read_phys(ctxt, smbase + 0xfe00, buf, sizeof(buf));
+	if (ret != X86EMUL_CONTINUE)
+		return X86EMUL_UNHANDLEABLE;
+
 	/*
 	 * Get back to real mode, to prepare a safe state in which to load
 	 * CR0/CR3/CR4/EFER.  It's all a bit more complicated if the vCPU
@@ -2605,20 +2606,18 @@ static int em_rsm(struct x86_emulate_ctxt *ctxt)
 	efer = 0;
 	ctxt->ops->set_msr(ctxt, MSR_EFER, efer);
 
-	smbase = ctxt->ops->get_smbase(ctxt);
-
 	/*
 	 * Give pre_leave_smm() a chance to make ISA-specific changes to the
 	 * vCPU state (e.g. enter guest mode) before loading state from the SMM
 	 * state-save area.
 	 */
-	if (ctxt->ops->pre_leave_smm(ctxt, smbase))
+	if (ctxt->ops->pre_leave_smm(ctxt, buf))
 		return X86EMUL_UNHANDLEABLE;
 
 	if (emulator_has_longmode(ctxt))
-		ret = rsm_load_state_64(ctxt, smbase + 0x8000);
+		ret = rsm_load_state_64(ctxt, buf);
 	else
-		ret = rsm_load_state_32(ctxt, smbase + 0x8000);
+		ret = rsm_load_state_32(ctxt, buf);
 
 	if (ret != X86EMUL_CONTINUE) {
 		/* FIXME: should triple fault */