arm64: Mask all exceptions during kernel_exit

To take RAS Exceptions as quickly as possible we need to keep SError
unmasked as much as possible. However, we need to mask it during
kernel_exit, as taking an error from this code will overwrite the
exception-registers.

Adding a naked 'disable_daif' to kernel_exit causes a performance problem
for micro-benchmarks that do no real work (e.g. calling getpid() in a
loop). This is because the ret_to_user loop has already masked IRQs so
that the TIF_WORK_MASK thread flags can't change underneath it; adding
disable_daif is an additional self-synchronising operation.

In the future, the RAS APEI code may need to modify the TIF_WORK_MASK
flags from an SError, in which case the ret_to_user loop must mask SError
while it examines the flags.

Disable all exceptions for return to EL1. For return to EL0, get the
ret_to_user loop to leave all exceptions masked once it has done its
work; this avoids an extra pstate-write.

Signed-off-by: James Morse <james.morse@arm.com>
Reviewed-by: Julien Thierry <julien.thierry@arm.com>
Reviewed-by: Catalin Marinas <catalin.marinas@arm.com>
Signed-off-by: Will Deacon <will.deacon@arm.com>
diff --git a/arch/arm64/kernel/signal.c b/arch/arm64/kernel/signal.c
index 0bdc96c..8e6500c 100644
--- a/arch/arm64/kernel/signal.c
+++ b/arch/arm64/kernel/signal.c
@@ -31,6 +31,7 @@
 #include <linux/ratelimit.h>
 #include <linux/syscalls.h>
 
+#include <asm/daifflags.h>
 #include <asm/debug-monitors.h>
 #include <asm/elf.h>
 #include <asm/cacheflush.h>
@@ -756,9 +757,12 @@ asmlinkage void do_notify_resume(struct pt_regs *regs,
 		addr_limit_user_check();
 
 		if (thread_flags & _TIF_NEED_RESCHED) {
+			/* Unmask Debug and SError for the next task */
+			local_daif_restore(DAIF_PROCCTX_NOIRQ);
+
 			schedule();
 		} else {
-			local_irq_enable();
+			local_daif_restore(DAIF_PROCCTX);
 
 			if (thread_flags & _TIF_UPROBE)
 				uprobe_notify_resume(regs);
@@ -775,7 +779,7 @@ asmlinkage void do_notify_resume(struct pt_regs *regs,
 				fpsimd_restore_current_state();
 		}
 
-		local_irq_disable();
+		local_daif_mask();
 		thread_flags = READ_ONCE(current_thread_info()->flags);
 	} while (thread_flags & _TIF_WORK_MASK);
 }