psi: Fix PSI_MEM_FULL state when tasks are in memstall and doing reclaim

We've noticed cases where tasks in a cgroup are stalled on memory but
little memory FULL pressure is reported, because the tasks stay on the
runqueue while in reclaim.

A simple example involves a single threaded program that keeps leaking
and touching large amounts of memory. It runs in a cgroup with swap
enabled, memory.high set at 10M and cpu.max ratio set at 5%. Though
there is significant CPU pressure and memory SOME, there is barely any
memory FULL since the task enters reclaim and stays on the runqueue.
However, this memory-bound task is effectively stalled on memory and
we expect memory FULL to match memory SOME in this scenario.

The code misinterprets memstall && running as a stalled task plus a
productive task, when there is only one task: a reclaimer that is
counted as both. To fix this, redefine the condition for PSI_MEM_FULL
to check that all running tasks are in an active memstall, instead of
checking that there are no running tasks.

        case PSI_MEM_FULL:
-               return unlikely(tasks[NR_MEMSTALL] && !tasks[NR_RUNNING]);
+               return unlikely(tasks[NR_MEMSTALL] &&
+                       tasks[NR_RUNNING] == tasks[NR_MEMSTALL_RUNNING]);

This will capture reclaimers. It will also capture tasks that called
psi_memstall_enter() and are about to sleep, but this should be
negligible noise.

Signed-off-by: Brian Chen <brianchen118@gmail.com>
Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Acked-by: Johannes Weiner <hannes@cmpxchg.org>
Link: https://lore.kernel.org/r/20211110213312.310243-1-brianchen118@gmail.com
diff --git a/kernel/sched/stats.h b/kernel/sched/stats.h
index cfb0893..3a3c826 100644
--- a/kernel/sched/stats.h
+++ b/kernel/sched/stats.h
@@ -118,6 +118,9 @@ static inline void psi_enqueue(struct task_struct *p, bool wakeup)
 	if (static_branch_likely(&psi_disabled))
 		return;
 
+	if (p->in_memstall)
+		set |= TSK_MEMSTALL_RUNNING;
+
 	if (!wakeup || p->sched_psi_wake_requeue) {
 		if (p->in_memstall)
 			set |= TSK_MEMSTALL;
@@ -148,7 +151,7 @@ static inline void psi_dequeue(struct task_struct *p, bool sleep)
 		return;
 
 	if (p->in_memstall)
-		clear |= TSK_MEMSTALL;
+		clear |= (TSK_MEMSTALL | TSK_MEMSTALL_RUNNING);
 
 	psi_task_change(p, clear, 0);
 }