x86/nmi/64: Improve nested NMI comments

I found the nested NMI documentation to be difficult to follow.
Improve the comments.

Signed-off-by: Andy Lutomirski <luto@kernel.org>
Reviewed-by: Steven Rostedt <rostedt@goodmis.org>
Cc: Borislav Petkov <bp@suse.de>
Cc: Linus Torvalds <torvalds@linux-foundation.org>
Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: stable@vger.kernel.org
Signed-off-by: Ingo Molnar <mingo@kernel.org>
diff --git a/arch/x86/entry/entry_64.S b/arch/x86/entry/entry_64.S
index 8668bbd..f54d63a 100644
--- a/arch/x86/entry/entry_64.S
+++ b/arch/x86/entry/entry_64.S
@@ -1237,11 +1237,12 @@
 	 *  If the variable is not set and the stack is not the NMI
 	 *  stack then:
 	 *    o Set the special variable on the stack
-	 *    o Copy the interrupt frame into a "saved" location on the stack
-	 *    o Copy the interrupt frame into a "copy" location on the stack
+	 *    o Copy the interrupt frame into an "outermost" location on the
+	 *      stack
+	 *    o Copy the interrupt frame into an "iret" location on the stack
 	 *    o Continue processing the NMI
 	 *  If the variable is set or the previous stack is the NMI stack:
-	 *    o Modify the "copy" location to jump to the repeate_nmi
+	 *    o Modify the "iret" location to jump to the repeat_nmi
 	 *    o return back to the first NMI
 	 *
 	 * Now on exit of the first NMI, we first clear the stack variable
@@ -1317,18 +1318,60 @@
 
 .Lnmi_from_kernel:
 	/*
-	 * Check the special variable on the stack to see if NMIs are
-	 * executing.
+	 * Here's what our stack frame will look like:
+	 * +---------------------------------------------------------+
+	 * | original SS                                             |
+	 * | original Return RSP                                     |
+	 * | original RFLAGS                                         |
+	 * | original CS                                             |
+	 * | original RIP                                            |
+	 * +---------------------------------------------------------+
+	 * | temp storage for rdx                                    |
+	 * +---------------------------------------------------------+
+	 * | "NMI executing" variable                                |
+	 * +---------------------------------------------------------+
+	 * | iret SS          } Copied from "outermost" frame        |
+	 * | iret Return RSP  } on each loop iteration; overwritten  |
+	 * | iret RFLAGS      } by a nested NMI to force another     |
+	 * | iret CS          } iteration if needed.                 |
+	 * | iret RIP         }                                      |
+	 * +---------------------------------------------------------+
+	 * | outermost SS          } initialized in first_nmi;       |
+	 * | outermost Return RSP  } will not be changed before      |
+	 * | outermost RFLAGS      } NMI processing is done.         |
+	 * | outermost CS          } Copied to "iret" frame on each  |
+	 * | outermost RIP         } iteration.                      |
+	 * +---------------------------------------------------------+
+	 * | pt_regs                                                 |
+	 * +---------------------------------------------------------+
+	 *
+	 * The "original" frame is used by hardware.  Before re-enabling
+	 * NMIs, we need to be done with it, and we need to leave enough
+	 * space for the asm code here.
+	 *
+	 * We return by executing IRET while RSP points to the "iret" frame.
+	 * That will either return for real or it will loop back into NMI
+	 * processing.
+	 *
+	 * The "outermost" frame is copied to the "iret" frame on each
+	 * iteration of the loop, so each iteration starts with the "iret"
+	 * frame pointing to the final return target.
+	 */
+
+	/*
+	 * Determine whether we're a nested NMI.
+	 *
+	 * First check "NMI executing".  If it's set, then we're nested.
+	 * This will not detect if we interrupted an outer NMI just
+	 * before IRET.
 	 */
 	cmpl	$1, -8(%rsp)
 	je	nested_nmi
 
 	/*
-	 * Now test if the previous stack was an NMI stack.
-	 * We need the double check. We check the NMI stack to satisfy the
-	 * race when the first NMI clears the variable before returning.
-	 * We check the variable because the first NMI could be in a
-	 * breakpoint routine using a breakpoint stack.
+	 * Now test if the previous stack was an NMI stack.  This covers
+	 * the case where we interrupt an outer NMI after it clears
+	 * "NMI executing" but before IRET.
 	 */
 	lea	6*8(%rsp), %rdx
 	/* Compare the NMI stack (rdx) with the stack we came from (4*8(%rsp)) */
@@ -1344,9 +1387,11 @@
 
 nested_nmi:
 	/*
-	 * Do nothing if we interrupted the fixup in repeat_nmi.
-	 * It's about to repeat the NMI handler, so we are fine
-	 * with ignoring this one.
+	 * If we interrupted an NMI that is between repeat_nmi and
+	 * end_repeat_nmi, then we must not modify the "iret" frame
+	 * because it's being written by the outer NMI.  That's okay;
+	 * the outer NMI handler is about to call do_nmi anyway,
+	 * so we can just resume the outer NMI.
 	 */
 	movq	$repeat_nmi, %rdx
 	cmpq	8(%rsp), %rdx
@@ -1356,7 +1401,10 @@
 	ja	nested_nmi_out
 
 1:
-	/* Set up the interrupted NMIs stack to jump to repeat_nmi */
+	/*
+	 * Modify the "iret" frame to point to repeat_nmi, forcing another
+	 * iteration of NMI handling.
+	 */
 	leaq	-1*8(%rsp), %rdx
 	movq	%rdx, %rsp
 	leaq	-10*8(%rsp), %rdx
@@ -1372,61 +1420,27 @@
 nested_nmi_out:
 	popq	%rdx
 
-	/* No need to check faults here */
+	/* We are returning to kernel mode, so this cannot result in a fault. */
 	INTERRUPT_RETURN
 
 first_nmi:
-	/*
-	 * Because nested NMIs will use the pushed location that we
-	 * stored in rdx, we must keep that space available.
-	 * Here's what our stack frame will look like:
-	 * +-------------------------+
-	 * | original SS             |
-	 * | original Return RSP     |
-	 * | original RFLAGS         |
-	 * | original CS             |
-	 * | original RIP            |
-	 * +-------------------------+
-	 * | temp storage for rdx    |
-	 * +-------------------------+
-	 * | NMI executing variable  |
-	 * +-------------------------+
-	 * | copied SS               |
-	 * | copied Return RSP       |
-	 * | copied RFLAGS           |
-	 * | copied CS               |
-	 * | copied RIP              |
-	 * +-------------------------+
-	 * | Saved SS                |
-	 * | Saved Return RSP        |
-	 * | Saved RFLAGS            |
-	 * | Saved CS                |
-	 * | Saved RIP               |
-	 * +-------------------------+
-	 * | pt_regs                 |
-	 * +-------------------------+
-	 *
-	 * The saved stack frame is used to fix up the copied stack frame
-	 * that a nested NMI may change to make the interrupted NMI iret jump
-	 * to the repeat_nmi. The original stack frame and the temp storage
-	 * is also used by nested NMIs and can not be trusted on exit.
-	 */
-	/* Do not pop rdx, nested NMIs will corrupt that part of the stack */
+	/* Restore rdx. */
 	movq	(%rsp), %rdx
 
-	/* Set the NMI executing variable on the stack. */
+	/* Set "NMI executing" on the stack. */
 	pushq	$1
 
-	/* Leave room for the "copied" frame */
+	/* Leave room for the "iret" frame */
 	subq	$(5*8), %rsp
 
-	/* Copy the stack frame to the Saved frame */
+	/* Copy the "original" frame to the "outermost" frame */
 	.rept 5
 	pushq	11*8(%rsp)
 	.endr
 
 	/* Everything up to here is safe from nested NMIs */
 
+repeat_nmi:
 	/*
 	 * If there was a nested NMI, the first NMI's iret will return
 	 * here. But NMIs are still enabled and we can take another
@@ -1435,16 +1449,21 @@
 	 * it will just return, as we are about to repeat an NMI anyway.
 	 * This makes it safe to copy to the stack frame that a nested
 	 * NMI will update.
-	 */
-repeat_nmi:
-	/*
-	 * Update the stack variable to say we are still in NMI (the update
-	 * is benign for the non-repeat case, where 1 was pushed just above
-	 * to this very stack slot).
+	 *
+	 * RSP is pointing to "outermost RIP".  gsbase is unknown, but, if
+	 * we're repeating an NMI, gsbase has the same value that it had on
+	 * the first iteration.  paranoid_entry will load the kernel
+	 * gsbase if needed before we call do_nmi.
+	 *
+	 * Set "NMI executing" in case we came back here via IRET.
 	 */
 	movq	$1, 10*8(%rsp)
 
-	/* Make another copy, this one may be modified by nested NMIs */
+	/*
+	 * Copy the "outermost" frame to the "iret" frame.  NMIs that nest
+	 * here must not modify the "iret" frame while we're writing to
+	 * it or it will end up containing garbage.
+	 */
 	addq	$(10*8), %rsp
 	.rept 5
 	pushq	-6*8(%rsp)
@@ -1453,9 +1472,9 @@
 end_repeat_nmi:
 
 	/*
-	 * Everything below this point can be preempted by a nested
-	 * NMI if the first NMI took an exception and reset our iret stack
-	 * so that we repeat another NMI.
+	 * Everything below this point can be preempted by a nested NMI.
+	 * If this happens, then the inner NMI will change the "iret"
+	 * frame to point back to repeat_nmi.
 	 */
 	pushq	$-1				/* ORIG_RAX: no syscall to restart */
 	ALLOC_PT_GPREGS_ON_STACK
@@ -1481,11 +1500,18 @@
 nmi_restore:
 	RESTORE_EXTRA_REGS
 	RESTORE_C_REGS
-	/* Pop the extra iret frame at once */
+
+	/* Point RSP at the "iret" frame. */
 	REMOVE_PT_GPREGS_FROM_STACK 6*8
 
-	/* Clear the NMI executing stack variable */
+	/* Clear "NMI executing". */
 	movq	$0, 5*8(%rsp)
+
+	/*
+	 * INTERRUPT_RETURN reads the "iret" frame and exits the NMI
+	 * stack in a single instruction.  We are returning to kernel
+	 * mode, so this cannot result in a fault.
+	 */
 	INTERRUPT_RETURN
 END(nmi)