Rewrite thread root flip synchronization.

Mark a pending "flip function" with a `ThreadFlag` for faster
JNI transitions.

Use two more `ThreadFlag`s to fix potential race conditions.
Some checkpoints were previously running the flip function
on behalf of a suspended thread because they relied on the
thread roots being flipped, but doing so without any
synchronization meant that two such checkpoints could race
and one could start executing its own checkpoint code while
the other was still running the flip function. Other
checkpoints that were performing a stack walk did not run
the flip function at all, so they could have seen from-space
references. We now check for a pending or running flip
function at the start of `Thread::RunCheckpointFunction()`
and proceed only after it has completed; holding the mutator
lock for the duration of the whole function prevents a new
flip function from being installed until the checkpoint
finishes.
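
For reference, the following is a minimal, standalone sketch of the
three-flag handshake described above. It is not the ART code: the flag
values are illustrative and std::atomic/std::mutex/
std::condition_variable stand in for ART's StateAndFlags,
thread_suspend_count_lock_ and resume_cond_.

  #include <atomic>
  #include <condition_variable>
  #include <cstdint>
  #include <functional>
  #include <mutex>

  // Illustrative stand-ins for the three new ThreadFlags.
  constexpr uint32_t kPendingFlipFunction    = 1u << 4;
  constexpr uint32_t kRunningFlipFunction    = 1u << 5;
  constexpr uint32_t kWaitingForFlipFunction = 1u << 6;

  struct FlipSync {
    std::atomic<uint32_t> flags{0};
    std::function<void()> flip_function;  // Installed while the thread is suspended.
    std::mutex lock;                      // Stands in for thread_suspend_count_lock_.
    std::condition_variable resume_cond;  // Stands in for resume_cond_.

    // Claim the pending flip function and run it, or return if another
    // thread has already claimed it.
    void EnsureFlipFunctionStarted() {
      while (true) {
        uint32_t old_flags = flags.load(std::memory_order_relaxed);
        if ((old_flags & kPendingFlipFunction) == 0u) {
          return;  // Claimed (and possibly already finished) elsewhere.
        }
        uint32_t new_flags =
            (old_flags & ~kPendingFlipFunction) | kRunningFlipFunction;
        if (flags.compare_exchange_weak(old_flags, new_flags,
                                        std::memory_order_acquire)) {
          flip_function();  // Runs exactly once, on the thread that won the CAS.
          uint32_t prev = flags.fetch_and(
              ~(kRunningFlipFunction | kWaitingForFlipFunction),
              std::memory_order_release);
          if ((prev & kWaitingForFlipFunction) != 0u) {
            std::lock_guard<std::mutex> lg(lock);
            resume_cond.notify_all();  // Wake any threads waiting for completion.
          }
          return;
        }
      }
    }

    // Block until a flip function claimed by another thread has completed.
    void WaitForFlipFunction() {
      std::unique_lock<std::mutex> ul(lock);
      while (true) {
        uint32_t old_flags = flags.load(std::memory_order_acquire);
        if ((old_flags & kRunningFlipFunction) == 0u) {
          return;  // Done; the broadcast cannot be missed while holding the lock.
        }
        if ((old_flags & kWaitingForFlipFunction) == 0u &&
            !flags.compare_exchange_weak(old_flags,
                                         old_flags | kWaitingForFlipFunction,
                                         std::memory_order_relaxed)) {
          continue;  // Flags changed under us; re-check.
        }
        resume_cond.wait(ul);  // Re-check after (possibly spurious) wakeup.
      }
    }
  };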

Golem results for art-opt-cc (higher is better):
linux-ia32                       before after
NativeDowncallStaticNormal       46.581 46.813 (+0.4980%)
NativeDowncallStaticNormal6      42.247 42.268 (+0.0497%)
NativeDowncallStaticNormalRefs6  40.918 41.355 (+1.068%)
NativeDowncallVirtualNormal      46.292 46.361 (+0.1496%)
NativeDowncallVirtualNormal6     41.791 41.791 (0%)
NativeDowncallVirtualNormalRefs6 40.500 40.500 (0%)
linux-x64                        before after
NativeDowncallStaticNormal       44.169 43.956 (-0.4815%)
NativeDowncallStaticNormal6      43.198 43.198 (0%)
NativeDowncallStaticNormalRefs6  38.481 38.481 (0%)
NativeDowncallVirtualNormal      43.672 43.672 (0%)
NativeDowncallVirtualNormal6     42.247 42.268 (+0.0479%)
NativeDowncallVirtualNormalRefs6 41.355 41.355 (0%)
linux-armv7                      before after
NativeDowncallStaticNormal       9.9701 10.443 (+4.739%)
NativeDowncallStaticNormal6      9.2457 9.6525 (+4.400%)
NativeDowncallStaticNormalRefs6  8.3868 8.7209 (+3.984%)
NativeDowncallVirtualNormal      9.8377 10.304 (+4.742%)
NativeDowncallVirtualNormal6     9.3596 9.7752 (+4.440%)
NativeDowncallVirtualNormalRefs6 8.4367 8.7719 (+3.973%)
linux-armv8                      before after
NativeDowncallStaticNormal       9.8571 10.685 (+8.397%)
NativeDowncallStaticNormal6      9.4905 10.249 (+7.991%)
NativeDowncallStaticNormalRefs6  8.6705 9.3000 (+7.261%)
NativeDowncallVirtualNormal      9.3183 10.053 (+7.881%)
NativeDowncallVirtualNormal6     9.2638 9.9850 (+7.786%)
NativeDowncallVirtualNormalRefs6 8.2967 8.8714 (+6.926%)
(The x86 and x86-64 differences seem to be lost in noise.)

Test: m test-art-host-gtest
Test: testrunner.py --host --optimizing
Bug: 172332525
Change-Id: I9c2227142010f7fe6ecf07e92273bc65d728c5c6
diff --git a/runtime/thread-inl.h b/runtime/thread-inl.h
index 67e2e6a..960a870 100644
--- a/runtime/thread-inl.h
+++ b/runtime/thread-inl.h
@@ -49,7 +49,7 @@
 inline void Thread::CheckSuspend() {
   DCHECK_EQ(Thread::Current(), this);
   while (true) {
-    StateAndFlags state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+    StateAndFlags state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
     if (LIKELY(!state_and_flags.IsAnyOfFlagsSet(SuspendOrCheckpointRequestFlags()))) {
       break;
     } else if (state_and_flags.IsFlagSet(ThreadFlag::kCheckpointRequest)) {
@@ -113,11 +113,10 @@
   }
 
   while (true) {
-    StateAndFlags old_state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+    StateAndFlags old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
     CHECK_NE(old_state_and_flags.GetState(), ThreadState::kRunnable)
         << new_state << " " << *this << " " << *Thread::Current();
-    StateAndFlags new_state_and_flags = old_state_and_flags;
-    new_state_and_flags.SetState(new_state);
+    StateAndFlags new_state_and_flags = old_state_and_flags.WithState(new_state);
     bool done =
         tls32_.state_and_flags.CompareAndSetWeakRelaxed(old_state_and_flags.GetValue(),
                                                         new_state_and_flags.GetValue());
@@ -191,7 +190,7 @@
 inline void Thread::TransitionToSuspendedAndRunCheckpoints(ThreadState new_state) {
   DCHECK_NE(new_state, ThreadState::kRunnable);
   while (true) {
-    StateAndFlags old_state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+    StateAndFlags old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
     DCHECK_EQ(old_state_and_flags.GetState(), ThreadState::kRunnable);
     if (UNLIKELY(old_state_and_flags.IsFlagSet(ThreadFlag::kCheckpointRequest))) {
       RunCheckpointFunction();
@@ -204,8 +203,7 @@
     // Change the state but keep the current flags (kCheckpointRequest is clear).
     DCHECK(!old_state_and_flags.IsFlagSet(ThreadFlag::kCheckpointRequest));
     DCHECK(!old_state_and_flags.IsFlagSet(ThreadFlag::kEmptyCheckpointRequest));
-    StateAndFlags new_state_and_flags = old_state_and_flags;
-    new_state_and_flags.SetState(new_state);
+    StateAndFlags new_state_and_flags = old_state_and_flags.WithState(new_state);
 
     // CAS the value, ensuring that prior memory operations are visible to any thread
     // that observes that we are suspended.
@@ -220,7 +218,7 @@
 
 inline void Thread::PassActiveSuspendBarriers() {
   while (true) {
-    StateAndFlags state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+    StateAndFlags state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
     if (LIKELY(!state_and_flags.IsFlagSet(ThreadFlag::kCheckpointRequest) &&
                !state_and_flags.IsFlagSet(ThreadFlag::kEmptyCheckpointRequest) &&
                !state_and_flags.IsFlagSet(ThreadFlag::kActiveSuspendBarrier))) {
@@ -253,7 +251,7 @@
 }
 
 inline ThreadState Thread::TransitionFromSuspendedToRunnable() {
-  StateAndFlags old_state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+  StateAndFlags old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
   ThreadState old_state = old_state_and_flags.GetState();
   DCHECK_NE(old_state, ThreadState::kRunnable);
   while (true) {
@@ -261,11 +259,12 @@
     // Optimize for the return from native code case - this is the fast path.
     // Atomically change from suspended to runnable if no suspend request pending.
     constexpr uint32_t kCheckedFlags =
-        SuspendOrCheckpointRequestFlags() | enum_cast<uint32_t>(ThreadFlag::kActiveSuspendBarrier);
+        SuspendOrCheckpointRequestFlags() |
+        enum_cast<uint32_t>(ThreadFlag::kActiveSuspendBarrier) |
+        FlipFunctionFlags();
     if (LIKELY(!old_state_and_flags.IsAnyOfFlagsSet(kCheckedFlags))) {
       // CAS the value with a memory barrier.
-      StateAndFlags new_state_and_flags = old_state_and_flags;
-      new_state_and_flags.SetState(ThreadState::kRunnable);
+      StateAndFlags new_state_and_flags = old_state_and_flags.WithState(ThreadState::kRunnable);
       if (LIKELY(tls32_.state_and_flags.CompareAndSetWeakAcquire(old_state_and_flags.GetValue(),
                                                                  new_state_and_flags.GetValue()))) {
         // Mark the acquisition of a share of the mutator lock.
@@ -276,15 +275,13 @@
       PassActiveSuspendBarriers(this);
     } else if (UNLIKELY(old_state_and_flags.IsFlagSet(ThreadFlag::kCheckpointRequest) ||
                         old_state_and_flags.IsFlagSet(ThreadFlag::kEmptyCheckpointRequest))) {
-      // Impossible
-      StateAndFlags flags = old_state_and_flags;
+      // Checkpoint flags should not be set while in suspended state.
       static_assert(static_cast<std::underlying_type_t<ThreadState>>(ThreadState::kRunnable) == 0u);
-      flags.SetState(ThreadState::kRunnable);  // Note: Keeping unused bits.
-      LOG(FATAL) << "Transitioning to runnable with checkpoint flag, "
-                 << " flags=" << flags.GetValue()  // State set to kRunnable = 0.
+      LOG(FATAL) << "Transitioning to Runnable with checkpoint flag,"
+                 // Note: Keeping unused flags. If they are set, it points to memory corruption.
+                 << " flags=" << old_state_and_flags.WithState(ThreadState::kRunnable).GetValue()
                  << " state=" << old_state_and_flags.GetState();
-    } else {
-      DCHECK(old_state_and_flags.IsFlagSet(ThreadFlag::kSuspendRequest));
+    } else if (old_state_and_flags.IsFlagSet(ThreadFlag::kSuspendRequest)) {
       // Wait while our suspend count is non-zero.
 
       // We pass null to the MutexLock as we may be in a situation where the
@@ -299,26 +296,44 @@
       MutexLock mu(thread_to_pass, *Locks::thread_suspend_count_lock_);
       ScopedTransitioningToRunnable scoped_transitioning_to_runnable(this);
       // Reload state and flags after locking the mutex.
-      old_state_and_flags.SetValue(tls32_.state_and_flags.load(std::memory_order_relaxed));
+      old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
       DCHECK_EQ(old_state, old_state_and_flags.GetState());
       while (old_state_and_flags.IsFlagSet(ThreadFlag::kSuspendRequest)) {
         // Re-check when Thread::resume_cond_ is notified.
         Thread::resume_cond_->Wait(thread_to_pass);
         // Reload state and flags after waiting.
-        old_state_and_flags.SetValue(tls32_.state_and_flags.load(std::memory_order_relaxed));
+        old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
         DCHECK_EQ(old_state, old_state_and_flags.GetState());
       }
       DCHECK_EQ(GetSuspendCount(), 0);
+    } else if (UNLIKELY(old_state_and_flags.IsFlagSet(ThreadFlag::kRunningFlipFunction)) ||
+               UNLIKELY(old_state_and_flags.IsFlagSet(ThreadFlag::kWaitingForFlipFunction))) {
+      // The thread should be suspended while another thread is running the flip function.
+      static_assert(static_cast<std::underlying_type_t<ThreadState>>(ThreadState::kRunnable) == 0u);
+      LOG(FATAL) << "Transitioning to Runnable while another thread is running the flip function,"
+                 // Note: Keeping unused flags. If they are set, it points to memory corruption.
+                 << " flags=" << old_state_and_flags.WithState(ThreadState::kRunnable).GetValue()
+                 << " state=" << old_state_and_flags.GetState();
+    } else {
+      DCHECK(old_state_and_flags.IsFlagSet(ThreadFlag::kPendingFlipFunction));
+      // CAS the value with a memory barrier.
+      // Do not set `ThreadFlag::kRunningFlipFunction` as no other thread can run
+      // the flip function for a thread that is not suspended.
+      StateAndFlags new_state_and_flags = old_state_and_flags.WithState(ThreadState::kRunnable)
+          .WithoutFlag(ThreadFlag::kPendingFlipFunction);
+      if (LIKELY(tls32_.state_and_flags.CompareAndSetWeakAcquire(old_state_and_flags.GetValue(),
+                                                                 new_state_and_flags.GetValue()))) {
+        // Mark the acquisition of a share of the mutator lock.
+        GetMutatorLock()->TransitionFromSuspendedToRunnable(this);
+        // Run the flip function.
+        RunFlipFunction(this, /*notify=*/ false);
+        break;
+      }
     }
     // Reload state and flags.
-    old_state_and_flags.SetValue(tls32_.state_and_flags.load(std::memory_order_relaxed));
+    old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
     DCHECK_EQ(old_state, old_state_and_flags.GetState());
   }
-  // Run the flip function, if set.
-  Closure* flip_func = GetFlipFunction();
-  if (flip_func != nullptr) {
-    flip_func->Run(this);
-  }
   return static_cast<ThreadState>(old_state);
 }
 
diff --git a/runtime/thread.cc b/runtime/thread.cc
index 1c2b0cc..184d2c1 100644
--- a/runtime/thread.cc
+++ b/runtime/thread.cc
@@ -1553,6 +1553,25 @@
 }
 
 void Thread::RunCheckpointFunction() {
+  // If this thread is suspended and another thread is running the checkpoint on its behalf,
+  // we may have a pending flip function that we need to run for the sake of those checkpoints
+  // that need to walk the stack. We should not see the flip function flags when the thread
+  // is running the checkpoint on its own.
+  StateAndFlags state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
+  if (UNLIKELY(state_and_flags.IsAnyOfFlagsSet(FlipFunctionFlags()))) {
+    DCHECK(IsSuspended());
+    Thread* self = Thread::Current();
+    DCHECK(self != this);
+    if (state_and_flags.IsFlagSet(ThreadFlag::kPendingFlipFunction)) {
+      EnsureFlipFunctionStarted(self);
+      state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
+      DCHECK(!state_and_flags.IsFlagSet(ThreadFlag::kPendingFlipFunction));
+    }
+    if (state_and_flags.IsFlagSet(ThreadFlag::kRunningFlipFunction)) {
+      WaitForFlipFunction(self);
+    }
+  }
+
   // Grab the suspend_count lock, get the next checkpoint and update all the checkpoint fields. If
   // there are no more checkpoints we will also clear the kCheckpointRequest flag.
   Closure* checkpoint;
@@ -1576,13 +1595,15 @@
 }
 
 void Thread::RunEmptyCheckpoint() {
+  // Note: Empty checkpoint does not access the thread's stack,
+  // so we do not need to check for the flip function.
   DCHECK_EQ(Thread::Current(), this);
   AtomicClearFlag(ThreadFlag::kEmptyCheckpointRequest);
   Runtime::Current()->GetThreadList()->EmptyCheckpointBarrier()->Pass(this);
 }
 
 bool Thread::RequestCheckpoint(Closure* function) {
-  StateAndFlags old_state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+  StateAndFlags old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
   if (old_state_and_flags.GetState() != ThreadState::kRunnable) {
     return false;  // Fail, thread is suspended and so can't run a checkpoint.
   }
@@ -1607,7 +1628,7 @@
 }
 
 bool Thread::RequestEmptyCheckpoint() {
-  StateAndFlags old_state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+  StateAndFlags old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
   if (old_state_and_flags.GetState() != ThreadState::kRunnable) {
     // If it's not runnable, we don't need to do anything because it won't be in the middle of a
     // heap access (eg. the read barrier).
@@ -1755,23 +1776,89 @@
   }
 }
 
-Closure* Thread::GetFlipFunction() {
-  Atomic<Closure*>* atomic_func = reinterpret_cast<Atomic<Closure*>*>(&tlsPtr_.flip_function);
-  Closure* func;
-  do {
-    func = atomic_func->load(std::memory_order_relaxed);
-    if (func == nullptr) {
-      return nullptr;
-    }
-  } while (!atomic_func->CompareAndSetWeakSequentiallyConsistent(func, nullptr));
-  DCHECK(func != nullptr);
-  return func;
+void Thread::SetFlipFunction(Closure* function) {
+  // This is called with all threads suspended, except for the calling thread.
+  DCHECK(IsSuspended() || Thread::Current() == this);
+  DCHECK(function != nullptr);
+  DCHECK(tlsPtr_.flip_function == nullptr);
+  tlsPtr_.flip_function = function;
+  DCHECK(!GetStateAndFlags(std::memory_order_relaxed).IsAnyOfFlagsSet(FlipFunctionFlags()));
+  AtomicSetFlag(ThreadFlag::kPendingFlipFunction, std::memory_order_release);
 }
 
-void Thread::SetFlipFunction(Closure* function) {
-  CHECK(function != nullptr);
-  Atomic<Closure*>* atomic_func = reinterpret_cast<Atomic<Closure*>*>(&tlsPtr_.flip_function);
-  atomic_func->store(function, std::memory_order_seq_cst);
+void Thread::EnsureFlipFunctionStarted(Thread* self) {
+  while (true) {
+    StateAndFlags old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
+    if (!old_state_and_flags.IsFlagSet(ThreadFlag::kPendingFlipFunction)) {
+      return;
+    }
+    DCHECK(!old_state_and_flags.IsFlagSet(ThreadFlag::kRunningFlipFunction));
+    StateAndFlags new_state_and_flags =
+        old_state_and_flags.WithFlag(ThreadFlag::kRunningFlipFunction)
+                           .WithoutFlag(ThreadFlag::kPendingFlipFunction);
+    if (tls32_.state_and_flags.CompareAndSetWeakAcquire(old_state_and_flags.GetValue(),
+                                                        new_state_and_flags.GetValue())) {
+      RunFlipFunction(self, /*notify=*/ true);
+      DCHECK(!GetStateAndFlags(std::memory_order_relaxed).IsAnyOfFlagsSet(FlipFunctionFlags()));
+      return;
+    }
+  }
+}
+
+void Thread::RunFlipFunction(Thread* self, bool notify) {
+  // This function is called for suspended threads and by the thread running
+  // `ThreadList::FlipThreadRoots()` after we've successfully set the flag
+  // `ThreadFlag::kRunningFlipFunction`. This flag is not set if the thread is
+  // running the flip function right after transitioning to Runnable as
+  // no other thread may run checkpoints on a thread that's actually Runnable.
+  DCHECK_EQ(notify, ReadFlag(ThreadFlag::kRunningFlipFunction));
+
+  Closure* flip_function = tlsPtr_.flip_function;
+  tlsPtr_.flip_function = nullptr;
+  DCHECK(flip_function != nullptr);
+  flip_function->Run(this);
+
+  if (notify) {
+    // Clear the `ThreadFlag::kRunningFlipFunction` and `ThreadFlag::kWaitingForFlipFunction`.
+    // Check if the latter was actually set, indicating that there is at least one waiting thread.
+    constexpr uint32_t kFlagsToClear = enum_cast<uint32_t>(ThreadFlag::kRunningFlipFunction) |
+                                       enum_cast<uint32_t>(ThreadFlag::kWaitingForFlipFunction);
+    StateAndFlags old_state_and_flags(
+        tls32_.state_and_flags.fetch_and(~kFlagsToClear, std::memory_order_release));
+    if (old_state_and_flags.IsFlagSet(ThreadFlag::kWaitingForFlipFunction)) {
+      // Notify all threads that are waiting for completion (at least one).
+      // TODO: Should we create a separate mutex and condition variable instead
+      // of piggy-backing on the `thread_suspend_count_lock_` and `resume_cond_`?
+      MutexLock mu(self, *Locks::thread_suspend_count_lock_);
+      resume_cond_->Broadcast(self);
+    }
+  }
+}
+
+void Thread::WaitForFlipFunction(Thread* self) {
+  // Another thread is running the flip function. Wait for it to complete.
+  // Check the flag while holding the mutex so that we do not miss the broadcast.
+  // Repeat the check after waiting to guard against spurious wakeups (and because
+  // we share the `thread_suspend_count_lock_` and `resume_cond_` with other code).
+  MutexLock mu(self, *Locks::thread_suspend_count_lock_);
+  while (true) {
+    StateAndFlags old_state_and_flags = GetStateAndFlags(std::memory_order_acquire);
+    DCHECK(!old_state_and_flags.IsFlagSet(ThreadFlag::kPendingFlipFunction));
+    if (!old_state_and_flags.IsFlagSet(ThreadFlag::kRunningFlipFunction)) {
+      DCHECK(!old_state_and_flags.IsAnyOfFlagsSet(FlipFunctionFlags()));
+      break;
+    }
+    if (!old_state_and_flags.IsFlagSet(ThreadFlag::kWaitingForFlipFunction)) {
+      // Mark that there is a waiting thread.
+      StateAndFlags new_state_and_flags =
+          old_state_and_flags.WithFlag(ThreadFlag::kWaitingForFlipFunction);
+      if (!tls32_.state_and_flags.CompareAndSetWeakRelaxed(old_state_and_flags.GetValue(),
+                                                           new_state_and_flags.GetValue())) {
+        continue;  // Retry.
+      }
+    }
+    resume_cond_->Wait(self);
+  }
 }
 
 void Thread::FullSuspendCheck() {
@@ -1809,27 +1896,12 @@
   return "";
 }
 
-
 void Thread::DumpState(std::ostream& os, const Thread* thread, pid_t tid) {
   std::string group_name;
   int priority;
   bool is_daemon = false;
   Thread* self = Thread::Current();
 
-  // If flip_function is not null, it means we have run a checkpoint
-  // before the thread wakes up to execute the flip function and the
-  // thread roots haven't been forwarded.  So the following access to
-  // the roots (opeer or methods in the frames) would be bad. Run it
-  // here. TODO: clean up.
-  if (thread != nullptr) {
-    ScopedObjectAccessUnchecked soa(self);
-    Thread* this_thread = const_cast<Thread*>(thread);
-    Closure* flip_func = this_thread->GetFlipFunction();
-    if (flip_func != nullptr) {
-      flip_func->Run(this_thread);
-    }
-  }
-
   // Don't do this if we are aborting since the GC may have all the threads suspended. This will
   // cause ScopedObjectAccessUnchecked to deadlock.
   if (gAborting == 0 && self != nullptr && thread != nullptr && thread->tlsPtr_.opeer != nullptr) {
@@ -1882,8 +1954,7 @@
 
   if (thread != nullptr) {
     auto suspend_log_fn = [&]() REQUIRES(Locks::thread_suspend_count_lock_) {
-      StateAndFlags state_and_flags(
-          thread->tls32_.state_and_flags.load(std::memory_order_relaxed));
+      StateAndFlags state_and_flags = thread->GetStateAndFlags(std::memory_order_relaxed);
       static_assert(
           static_cast<std::underlying_type_t<ThreadState>>(ThreadState::kRunnable) == 0u);
       state_and_flags.SetState(ThreadState::kRunnable);  // Clear state bits.
@@ -2149,19 +2220,6 @@
 }
 
 void Thread::DumpJavaStack(std::ostream& os, bool check_suspended, bool dump_locks) const {
-  // If flip_function is not null, it means we have run a checkpoint
-  // before the thread wakes up to execute the flip function and the
-  // thread roots haven't been forwarded.  So the following access to
-  // the roots (locks or methods in the frames) would be bad. Run it
-  // here. TODO: clean up.
-  {
-    Thread* this_thread = const_cast<Thread*>(this);
-    Closure* flip_func = this_thread->GetFlipFunction();
-    if (flip_func != nullptr) {
-      flip_func->Run(this_thread);
-    }
-  }
-
   // Dumping the Java stack involves the verifier for locks. The verifier operates under the
   // assumption that there is no exception pending on entry. Thus, stash any pending exception.
   // Thread::Current() instead of this in case a thread is dumping the stack of another suspended
@@ -2315,9 +2373,8 @@
 
   static_assert((sizeof(Thread) % 4) == 0U,
                 "art::Thread has a size which is not a multiple of 4.");
-  DCHECK_EQ(tls32_.state_and_flags.load(std::memory_order_relaxed), 0u);
-  StateAndFlags state_and_flags(0u);
-  state_and_flags.SetState(ThreadState::kNative);
+  DCHECK_EQ(GetStateAndFlags(std::memory_order_relaxed).GetValue(), 0u);
+  StateAndFlags state_and_flags = StateAndFlags(0u).WithState(ThreadState::kNative);
   tls32_.state_and_flags.store(state_and_flags.GetValue(), std::memory_order_relaxed);
   tls32_.interrupted.store(false, std::memory_order_relaxed);
   // Initialize with no permit; if the java Thread was unparked before being
@@ -3045,20 +3102,6 @@
     return nullptr;
   }
 
-  // If flip_function is not null, it means we have run a checkpoint
-  // before the thread wakes up to execute the flip function and the
-  // thread roots haven't been forwarded.  So the following access to
-  // the roots (locks or methods in the frames) would be bad. Run it
-  // here. TODO: clean up.
-  // Note: copied from DumpJavaStack.
-  {
-    Thread* this_thread = const_cast<Thread*>(this);
-    Closure* flip_func = this_thread->GetFlipFunction();
-    if (flip_func != nullptr) {
-      flip_func->Run(this_thread);
-    }
-  }
-
   class CollectFramesAndLocksStackVisitor : public MonitorObjectsStackVisitor {
    public:
     CollectFramesAndLocksStackVisitor(const ScopedObjectAccessAlreadyRunnable& soaa_in,
@@ -4482,7 +4525,7 @@
 
 std::string Thread::StateAndFlagsAsHexString() const {
   std::stringstream result_stream;
-  result_stream << std::hex << tls32_.state_and_flags.load(std::memory_order_relaxed);
+  result_stream << std::hex << GetStateAndFlags(std::memory_order_relaxed).GetValue();
   return result_stream.str();
 }
 
diff --git a/runtime/thread.h b/runtime/thread.h
index 868fdb4..7d76956 100644
--- a/runtime/thread.h
+++ b/runtime/thread.h
@@ -133,9 +133,27 @@
   // Register that at least 1 suspend barrier needs to be passed.
   kActiveSuspendBarrier = 1u << 3,
 
+  // Marks that a "flip function" needs to be executed on this thread.
+  kPendingFlipFunction = 1u << 4,
+
+  // Marks that the "flip function" is being executed by another thread.
+  //
+  // This is used to guard against multiple threads trying to run the
+  // "flip function" for the same thread while the thread is suspended.
+  //
+  // This is not needed when the thread is running the flip function
+  // on its own after transitioning to Runnable.
+  kRunningFlipFunction = 1u << 5,
+
+  // Marks that a thread is waiting for the "flip function" to complete.
+  //
+  // This is used to check if we need to broadcast the completion of the
+  // "flip function" to other threads. See also `kRunningFlipFunction`.
+  kWaitingForFlipFunction = 1u << 6,
+
   // Request that compiled JNI stubs do not transition to Native or Runnable with
   // inlined code, but take a slow path for monitoring method entry and exit events.
-  kMonitorJniEntryExit = 1u << 4,
+  kMonitorJniEntryExit = 1u << 7,
 
   // Indicates the last flag. Used for checking that the flags do not overlap thread state.
   kLastFlag = kMonitorJniEntryExit
@@ -251,8 +269,7 @@
       REQUIRES_SHARED(Locks::mutator_lock_);
 
   ThreadState GetState() const {
-    StateAndFlags state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
-    return state_and_flags.GetState();
+    return GetStateAndFlags(std::memory_order_relaxed).GetState();
   }
 
   ThreadState SetState(ThreadState new_state);
@@ -267,7 +284,7 @@
   }
 
   bool IsSuspended() const {
-    StateAndFlags state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+    StateAndFlags state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
     return state_and_flags.GetState() != ThreadState::kRunnable &&
            state_and_flags.IsFlagSet(ThreadFlag::kSuspendRequest);
   }
@@ -326,8 +343,18 @@
   bool RequestEmptyCheckpoint()
       REQUIRES(Locks::thread_suspend_count_lock_);
 
+  // Set the flip function. This is done with all threads suspended, except for the calling thread.
   void SetFlipFunction(Closure* function);
-  Closure* GetFlipFunction();
+
+  // Ensure that the thread flip function has started running. If no other thread is
+  // executing it, the calling thread shall run the flip function and then notify other
+  // threads that have tried to do that concurrently. After this function returns,
+  // `ThreadFlag::kPendingFlipFunction` is cleared but another thread may still be
+  // running the flip function, as indicated by `ThreadFlag::kRunningFlipFunction`.
+  void EnsureFlipFunctionStarted(Thread* self) REQUIRES_SHARED(Locks::mutator_lock_);
+
+  // Wait for the flip function to complete if still running on another thread.
+  void WaitForFlipFunction(Thread* self) REQUIRES_SHARED(Locks::mutator_lock_);
 
   gc::accounting::AtomicStack<mirror::Object>* GetThreadLocalMarkStack() {
     CHECK(kUseReadBarrier);
@@ -1117,16 +1144,15 @@
       REQUIRES(Locks::thread_suspend_count_lock_);
 
   bool ReadFlag(ThreadFlag flag) const {
-    StateAndFlags state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
-    return state_and_flags.IsFlagSet(flag);
+    return GetStateAndFlags(std::memory_order_relaxed).IsFlagSet(flag);
   }
 
-  void AtomicSetFlag(ThreadFlag flag) {
-    tls32_.state_and_flags.fetch_or(enum_cast<uint32_t>(flag), std::memory_order_seq_cst);
+  void AtomicSetFlag(ThreadFlag flag, std::memory_order order = std::memory_order_seq_cst) {
+    tls32_.state_and_flags.fetch_or(enum_cast<uint32_t>(flag), order);
   }
 
-  void AtomicClearFlag(ThreadFlag flag) {
-    tls32_.state_and_flags.fetch_and(~enum_cast<uint32_t>(flag), std::memory_order_seq_cst);
+  void AtomicClearFlag(ThreadFlag flag, std::memory_order order = std::memory_order_seq_cst) {
+    tls32_.state_and_flags.fetch_and(~enum_cast<uint32_t>(flag), order);
   }
 
   void ResetQuickAllocEntryPointsForThread();
@@ -1331,6 +1357,12 @@
            enum_cast<uint32_t>(ThreadFlag::kEmptyCheckpointRequest);
   }
 
+  static constexpr uint32_t FlipFunctionFlags() {
+    return enum_cast<uint32_t>(ThreadFlag::kPendingFlipFunction) |
+           enum_cast<uint32_t>(ThreadFlag::kRunningFlipFunction) |
+           enum_cast<uint32_t>(ThreadFlag::kWaitingForFlipFunction);
+  }
+
   static constexpr uint32_t StoredThreadStateValue(ThreadState state) {
     return StateAndFlags::EncodeState(state);
   }
@@ -1365,8 +1397,10 @@
   // Avoid use, callers should use SetState.
   // Used only by `Thread` destructor and stack trace collection in semi-space GC (currently
   // disabled by `kStoreStackTraces = false`).
-  ThreadState SetStateUnsafe(ThreadState new_state) {
-    StateAndFlags old_state_and_flags(tls32_.state_and_flags.load(std::memory_order_relaxed));
+  // NO_THREAD_SAFETY_ANALYSIS: This function is "Unsafe" and can be called in
+  // different states, so clang cannot perform the thread safety analysis.
+  ThreadState SetStateUnsafe(ThreadState new_state) NO_THREAD_SAFETY_ANALYSIS {
+    StateAndFlags old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
     ThreadState old_state = old_state_and_flags.GetState();
     if (old_state == new_state) {
       // Nothing to do.
@@ -1387,7 +1421,7 @@
           break;
         }
         // Reload state and flags.
-        old_state_and_flags.SetValue(tls32_.state_and_flags.load(std::memory_order_relaxed));
+        old_state_and_flags = GetStateAndFlags(std::memory_order_relaxed);
         DCHECK_EQ(old_state, old_state_and_flags.GetState());
       }
     }
@@ -1441,7 +1475,8 @@
   void TearDownAlternateSignalStack();
 
   ALWAYS_INLINE void TransitionToSuspendedAndRunCheckpoints(ThreadState new_state)
-      REQUIRES(!Locks::thread_suspend_count_lock_, !Roles::uninterruptible_);
+      REQUIRES(!Locks::thread_suspend_count_lock_, !Roles::uninterruptible_)
+      REQUIRES_SHARED(Locks::mutator_lock_);
 
   ALWAYS_INLINE void PassActiveSuspendBarriers()
       REQUIRES(!Locks::thread_suspend_count_lock_, !Roles::uninterruptible_);
@@ -1470,7 +1505,9 @@
   // Runs a single checkpoint function. If there are no more pending checkpoint functions it will
   // clear the kCheckpointRequest flag. The caller is responsible for calling this in a loop until
   // the kCheckpointRequest flag is cleared.
-  void RunCheckpointFunction() REQUIRES(!Locks::thread_suspend_count_lock_);
+  void RunCheckpointFunction()
+      REQUIRES(!Locks::thread_suspend_count_lock_)
+      REQUIRES_SHARED(Locks::mutator_lock_);
   void RunEmptyCheckpoint();
 
   bool PassActiveSuspendBarriers(Thread* self)
@@ -1514,6 +1551,18 @@
       value_ |= enum_cast<uint32_t>(flag);
     }
 
+    StateAndFlags WithFlag(ThreadFlag flag) const {
+      StateAndFlags result = *this;
+      result.SetFlag(flag);
+      return result;
+    }
+
+    StateAndFlags WithoutFlag(ThreadFlag flag) const {
+      StateAndFlags result = *this;
+      result.ClearFlag(flag);
+      return result;
+    }
+
     void ClearFlag(ThreadFlag flag) {
       value_ &= ~enum_cast<uint32_t>(flag);
     }
@@ -1529,6 +1578,12 @@
       value_ = ThreadStateField::Update(state, value_);
     }
 
+    StateAndFlags WithState(ThreadState state) const {
+      StateAndFlags result = *this;
+      result.SetState(state);
+      return result;
+    }
+
     static constexpr uint32_t EncodeState(ThreadState state) {
       ValidateThreadState(state);
       return ThreadStateField::Encode(state);
@@ -1554,9 +1609,17 @@
   };
   static_assert(sizeof(StateAndFlags) == sizeof(uint32_t), "Unexpected StateAndFlags size");
 
+  StateAndFlags GetStateAndFlags(std::memory_order order) const {
+    return StateAndFlags(tls32_.state_and_flags.load(order));
+  }
+
   // Format state and flags as a hex string. For diagnostic output.
   std::string StateAndFlagsAsHexString() const;
 
+  // Run the flip function and, if requested, notify other threads that may have tried
+  // to do that concurrently.
+  void RunFlipFunction(Thread* self, bool notify) REQUIRES_SHARED(Locks::mutator_lock_);
+
   static void ThreadExitCallback(void* arg);
 
   // Maximum number of suspend barriers.
diff --git a/runtime/thread_list.cc b/runtime/thread_list.cc
index 23e527c..4e3b40b 100644
--- a/runtime/thread_list.cc
+++ b/runtime/thread_list.cc
@@ -544,10 +544,9 @@
     MutexLock mu(self, *Locks::thread_list_lock_);
     MutexLock mu2(self, *Locks::thread_suspend_count_lock_);
     --suspend_all_count_;
-    for (const auto& thread : list_) {
-      // Set the flip function for all threads because Thread::DumpState/DumpJavaStack() (invoked by
-      // a checkpoint) may cause the flip function to be run for a runnable/suspended thread before
-      // a runnable thread runs it for itself or we run it for a suspended thread below.
+    for (Thread* thread : list_) {
+      // Set the flip function for all threads because once we start resuming any threads,
+      // they may need to run the flip function on behalf of other threads, even this one.
       thread->SetFlipFunction(thread_flip_visitor);
       if (thread == self) {
         continue;
@@ -572,21 +571,17 @@
 
   collector->GetHeap()->ThreadFlipEnd(self);
 
-  // Run the closure on the other threads.
+  // Try to run the closure on the other threads.
   {
     TimingLogger::ScopedTiming split3("FlipOtherThreads", collector->GetTimings());
     ReaderMutexLock mu(self, *Locks::mutator_lock_);
-    for (const auto& thread : other_threads) {
-      Closure* flip_func = thread->GetFlipFunction();
-      if (flip_func != nullptr) {
-        flip_func->Run(thread);
-      }
+    for (Thread* thread : other_threads) {
+      thread->EnsureFlipFunctionStarted(self);
+      DCHECK(!thread->ReadFlag(ThreadFlag::kPendingFlipFunction));
     }
-    // Run it for self.
-    Closure* flip_func = self->GetFlipFunction();
-    if (flip_func != nullptr) {
-      flip_func->Run(self);
-    }
+    // Try to run the flip function for self.
+    self->EnsureFlipFunctionStarted(self);
+    DCHECK(!self->ReadFlag(ThreadFlag::kPendingFlipFunction));
   }
 
   // Resume other threads.