Merge branch 'tip/locking/urgent'

Pull in dependencies.
diff --git a/drivers/gpu/drm/drm_modeset_lock.c b/drivers/gpu/drm/drm_modeset_lock.c
index fcfe1a0..bf8a6e8 100644
--- a/drivers/gpu/drm/drm_modeset_lock.c
+++ b/drivers/gpu/drm/drm_modeset_lock.c
@@ -248,7 +248,7 @@ static inline int modeset_lock(struct drm_modeset_lock *lock,
 	if (ctx->trylock_only) {
 		lockdep_assert_held(&ctx->ww_ctx);
 
-		if (!ww_mutex_trylock(&lock->mutex))
+		if (!ww_mutex_trylock(&lock->mutex, NULL))
 			return -EBUSY;
 		else
 			return 0;
diff --git a/drivers/regulator/core.c b/drivers/regulator/core.c
index ca6caba..f4d441b 100644
--- a/drivers/regulator/core.c
+++ b/drivers/regulator/core.c
@@ -145,7 +145,7 @@ static inline int regulator_lock_nested(struct regulator_dev *rdev,
 
 	mutex_lock(&regulator_nesting_mutex);
 
-	if (ww_ctx || !ww_mutex_trylock(&rdev->mutex)) {
+	if (!ww_mutex_trylock(&rdev->mutex, ww_ctx)) {
 		if (rdev->mutex_owner == current)
 			rdev->ref_cnt++;
 		else
diff --git a/include/linux/debug_locks.h b/include/linux/debug_locks.h
index 3f49e65..dbb409d 100644
--- a/include/linux/debug_locks.h
+++ b/include/linux/debug_locks.h
@@ -47,8 +47,6 @@ extern int debug_locks_off(void);
 # define locking_selftest()	do { } while (0)
 #endif
 
-struct task_struct;
-
 #ifdef CONFIG_LOCKDEP
 extern void debug_show_all_locks(void);
 extern void debug_show_held_locks(struct task_struct *task);
diff --git a/include/linux/dma-resv.h b/include/linux/dma-resv.h
index e1ca208..39fefb8 100644
--- a/include/linux/dma-resv.h
+++ b/include/linux/dma-resv.h
@@ -173,7 +173,7 @@ static inline int dma_resv_lock_slow_interruptible(struct dma_resv *obj,
  */
 static inline bool __must_check dma_resv_trylock(struct dma_resv *obj)
 {
-	return ww_mutex_trylock(&obj->lock);
+	return ww_mutex_trylock(&obj->lock, NULL);
 }
 
 /**
diff --git a/include/linux/kernel.h b/include/linux/kernel.h
index 2776423..e8696e4 100644
--- a/include/linux/kernel.h
+++ b/include/linux/kernel.h
@@ -111,8 +111,8 @@ static __always_inline void might_resched(void)
 #endif /* CONFIG_PREEMPT_* */
 
 #ifdef CONFIG_DEBUG_ATOMIC_SLEEP
-extern void ___might_sleep(const char *file, int line, int preempt_offset);
-extern void __might_sleep(const char *file, int line, int preempt_offset);
+extern void __might_resched(const char *file, int line, unsigned int offsets);
+extern void __might_sleep(const char *file, int line);
 extern void __cant_sleep(const char *file, int line, int preempt_offset);
 extern void __cant_migrate(const char *file, int line);
 
@@ -129,7 +129,7 @@ extern void __cant_migrate(const char *file, int line);
  * supposed to.
  */
 # define might_sleep() \
-	do { __might_sleep(__FILE__, __LINE__, 0); might_resched(); } while (0)
+	do { __might_sleep(__FILE__, __LINE__); might_resched(); } while (0)
 /**
  * cant_sleep - annotation for functions that cannot sleep
  *
@@ -168,10 +168,9 @@ extern void __cant_migrate(const char *file, int line);
  */
 # define non_block_end() WARN_ON(current->non_block_count-- == 0)
 #else
-  static inline void ___might_sleep(const char *file, int line,
-				   int preempt_offset) { }
-  static inline void __might_sleep(const char *file, int line,
-				   int preempt_offset) { }
+  static inline void __might_resched(const char *file, int line,
+				     unsigned int offsets) { }
+static inline void __might_sleep(const char *file, int line) { }
 # define might_sleep() do { might_resched(); } while (0)
 # define cant_sleep() do { } while (0)
 # define cant_migrate()		do { } while (0)
diff --git a/include/linux/lockdep_types.h b/include/linux/lockdep_types.h
index 3e726ac..d224308 100644
--- a/include/linux/lockdep_types.h
+++ b/include/linux/lockdep_types.h
@@ -21,7 +21,7 @@ enum lockdep_wait_type {
 	LD_WAIT_SPIN,		/* spin loops, raw_spinlock_t etc.. */
 
 #ifdef CONFIG_PROVE_RAW_LOCK_NESTING
-	LD_WAIT_CONFIG,		/* CONFIG_PREEMPT_LOCK, spinlock_t etc.. */
+	LD_WAIT_CONFIG,		/* preemptible in PREEMPT_RT, spinlock_t etc.. */
 #else
 	LD_WAIT_CONFIG = LD_WAIT_SPIN,
 #endif
diff --git a/include/linux/preempt.h b/include/linux/preempt.h
index 4d244e2..031898b 100644
--- a/include/linux/preempt.h
+++ b/include/linux/preempt.h
@@ -122,9 +122,10 @@
  * The preempt_count offset after spin_lock()
  */
 #if !defined(CONFIG_PREEMPT_RT)
-#define PREEMPT_LOCK_OFFSET	PREEMPT_DISABLE_OFFSET
+#define PREEMPT_LOCK_OFFSET		PREEMPT_DISABLE_OFFSET
 #else
-#define PREEMPT_LOCK_OFFSET	0
+/* Locks on RT do not disable preemption */
+#define PREEMPT_LOCK_OFFSET		0
 #endif
 
 /*
diff --git a/include/linux/sched.h b/include/linux/sched.h
index e12b524..21b7cd0 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -2038,7 +2038,7 @@ static inline int _cond_resched(void) { return 0; }
 #endif /* !defined(CONFIG_PREEMPTION) || defined(CONFIG_PREEMPT_DYNAMIC) */
 
 #define cond_resched() ({			\
-	___might_sleep(__FILE__, __LINE__, 0);	\
+	__might_resched(__FILE__, __LINE__, 0);	\
 	_cond_resched();			\
 })
 
@@ -2046,19 +2046,38 @@ extern int __cond_resched_lock(spinlock_t *lock);
 extern int __cond_resched_rwlock_read(rwlock_t *lock);
 extern int __cond_resched_rwlock_write(rwlock_t *lock);
 
-#define cond_resched_lock(lock) ({				\
-	___might_sleep(__FILE__, __LINE__, PREEMPT_LOCK_OFFSET);\
-	__cond_resched_lock(lock);				\
+#define MIGHT_RESCHED_RCU_SHIFT		8
+#define MIGHT_RESCHED_PREEMPT_MASK	((1U << MIGHT_RESCHED_RCU_SHIFT) - 1)
+
+#ifndef CONFIG_PREEMPT_RT
+/*
+ * Non-RT kernels have an elevated preempt count due to the held lock,
+ * but are not allowed to be inside an RCU read side critical section
+ */
+# define PREEMPT_LOCK_RESCHED_OFFSETS	PREEMPT_LOCK_OFFSET
+#else
+/*
+ * spin/rw_lock() on RT implies rcu_read_lock(). The might_sleep() check in
+ * cond_resched*lock() has to take that into account because it checks for
+ * preempt_count() and rcu_preempt_depth().
+ */
+# define PREEMPT_LOCK_RESCHED_OFFSETS	\
+	(PREEMPT_LOCK_OFFSET + (1U << MIGHT_RESCHED_RCU_SHIFT))
+#endif
+
+#define cond_resched_lock(lock) ({						\
+	__might_resched(__FILE__, __LINE__, PREEMPT_LOCK_RESCHED_OFFSETS);	\
+	__cond_resched_lock(lock);						\
 })
 
-#define cond_resched_rwlock_read(lock) ({			\
-	__might_sleep(__FILE__, __LINE__, PREEMPT_LOCK_OFFSET);	\
-	__cond_resched_rwlock_read(lock);			\
+#define cond_resched_rwlock_read(lock) ({					\
+	__might_resched(__FILE__, __LINE__, PREEMPT_LOCK_RESCHED_OFFSETS);	\
+	__cond_resched_rwlock_read(lock);					\
 })
 
-#define cond_resched_rwlock_write(lock) ({			\
-	__might_sleep(__FILE__, __LINE__, PREEMPT_LOCK_OFFSET);	\
-	__cond_resched_rwlock_write(lock);			\
+#define cond_resched_rwlock_write(lock) ({					\
+	__might_resched(__FILE__, __LINE__, PREEMPT_LOCK_RESCHED_OFFSETS);	\
+	__cond_resched_rwlock_write(lock);					\
 })
 
 static inline void cond_resched_rcu(void)
diff --git a/include/linux/ww_mutex.h b/include/linux/ww_mutex.h
index 29db736..bb76308 100644
--- a/include/linux/ww_mutex.h
+++ b/include/linux/ww_mutex.h
@@ -28,12 +28,10 @@
 #ifndef CONFIG_PREEMPT_RT
 #define WW_MUTEX_BASE			mutex
 #define ww_mutex_base_init(l,n,k)	__mutex_init(l,n,k)
-#define ww_mutex_base_trylock(l)	mutex_trylock(l)
 #define ww_mutex_base_is_locked(b)	mutex_is_locked((b))
 #else
 #define WW_MUTEX_BASE			rt_mutex
 #define ww_mutex_base_init(l,n,k)	__rt_mutex_init(l,n,k)
-#define ww_mutex_base_trylock(l)	rt_mutex_trylock(l)
 #define ww_mutex_base_is_locked(b)	rt_mutex_base_is_locked(&(b)->rtmutex)
 #endif
 
@@ -339,17 +337,8 @@ ww_mutex_lock_slow_interruptible(struct ww_mutex *lock,
 
 extern void ww_mutex_unlock(struct ww_mutex *lock);
 
-/**
- * ww_mutex_trylock - tries to acquire the w/w mutex without acquire context
- * @lock: mutex to lock
- *
- * Trylocks a mutex without acquire context, so no deadlock detection is
- * possible. Returns 1 if the mutex has been acquired successfully, 0 otherwise.
- */
-static inline int __must_check ww_mutex_trylock(struct ww_mutex *lock)
-{
-	return ww_mutex_base_trylock(&lock->base);
-}
+extern int __must_check ww_mutex_trylock(struct ww_mutex *lock,
+					 struct ww_acquire_ctx *ctx);
 
 /***
  * ww_mutex_destroy - mark a w/w mutex unusable
diff --git a/kernel/locking/lockdep.c b/kernel/locking/lockdep.c
index bf1c00c..4e63129 100644
--- a/kernel/locking/lockdep.c
+++ b/kernel/locking/lockdep.c
@@ -4671,7 +4671,7 @@ print_lock_invalid_wait_context(struct task_struct *curr,
 /*
  * Verify the wait_type context.
  *
- * This check validates we takes locks in the right wait-type order; that is it
+ * This check validates we take locks in the right wait-type order; that is it
  * ensures that we do not take mutexes inside spinlocks and do not attempt to
  * acquire spinlocks inside raw_spinlocks and the sort.
  *
@@ -5366,7 +5366,7 @@ int __lock_is_held(const struct lockdep_map *lock, int read)
 		struct held_lock *hlock = curr->held_locks + i;
 
 		if (match_held_lock(hlock, lock)) {
-			if (read == -1 || hlock->read == read)
+			if (read == -1 || !!hlock->read == read)
 				return LOCK_STATE_HELD;
 
 			return LOCK_STATE_NOT_HELD;
diff --git a/kernel/locking/mutex.c b/kernel/locking/mutex.c
index d456579..2fede72 100644
--- a/kernel/locking/mutex.c
+++ b/kernel/locking/mutex.c
@@ -94,6 +94,9 @@ static inline unsigned long __owner_flags(unsigned long owner)
 	return owner & MUTEX_FLAGS;
 }
 
+/*
+ * Returns: __mutex_owner(lock) on failure or NULL on success.
+ */
 static inline struct task_struct *__mutex_trylock_common(struct mutex *lock, bool handoff)
 {
 	unsigned long owner, curr = (unsigned long)current;
@@ -736,6 +739,44 @@ __ww_mutex_lock(struct mutex *lock, unsigned int state, unsigned int subclass,
 	return __mutex_lock_common(lock, state, subclass, NULL, ip, ww_ctx, true);
 }
 
+/**
+ * ww_mutex_trylock - tries to acquire the w/w mutex with optional acquire context
+ * @ww: mutex to lock
+ * @ww_ctx: optional w/w acquire context
+ *
+ * Trylocks a mutex with the optional acquire context; no deadlock detection is
+ * possible. Returns 1 if the mutex has been acquired successfully, 0 otherwise.
+ *
+ * Unlike ww_mutex_lock, no deadlock handling is performed. However, if @ww_ctx is
+ * specified, -EALREADY handling may happen in calls to ww_mutex_trylock.
+ *
+ * A mutex acquired with this function must be released with ww_mutex_unlock.
+ */
+int ww_mutex_trylock(struct ww_mutex *ww, struct ww_acquire_ctx *ww_ctx)
+{
+	if (!ww_ctx)
+		return mutex_trylock(&ww->base);
+
+	MUTEX_WARN_ON(ww->base.magic != &ww->base);
+
+	/*
+	 * Reset the wounded flag after a kill. No other process can
+	 * race and wound us here, since they can't have a valid owner
+	 * pointer if we don't have any locks held.
+	 */
+	if (ww_ctx->acquired == 0)
+		ww_ctx->wounded = 0;
+
+	if (__mutex_trylock(&ww->base)) {
+		ww_mutex_set_context_fastpath(ww, ww_ctx);
+		mutex_acquire_nest(&ww->base.dep_map, 0, 1, &ww_ctx->dep_map, _RET_IP_);
+		return 1;
+	}
+
+	return 0;
+}
+EXPORT_SYMBOL(ww_mutex_trylock);
+
 #ifdef CONFIG_DEBUG_LOCK_ALLOC
 void __sched
 mutex_lock_nested(struct mutex *lock, unsigned int subclass)
diff --git a/kernel/locking/rtmutex.c b/kernel/locking/rtmutex.c
index 6bb116c..0c6a48d 100644
--- a/kernel/locking/rtmutex.c
+++ b/kernel/locking/rtmutex.c
@@ -446,17 +446,24 @@ static __always_inline void rt_mutex_adjust_prio(struct task_struct *p)
 }
 
 /* RT mutex specific wake_q wrappers */
+static __always_inline void rt_mutex_wake_q_add_task(struct rt_wake_q_head *wqh,
+						     struct task_struct *task,
+						     unsigned int wake_state)
+{
+	if (IS_ENABLED(CONFIG_PREEMPT_RT) && wake_state == TASK_RTLOCK_WAIT) {
+		if (IS_ENABLED(CONFIG_PROVE_LOCKING))
+			WARN_ON_ONCE(wqh->rtlock_task);
+		get_task_struct(task);
+		wqh->rtlock_task = task;
+	} else {
+		wake_q_add(&wqh->head, task);
+	}
+}
+
 static __always_inline void rt_mutex_wake_q_add(struct rt_wake_q_head *wqh,
 						struct rt_mutex_waiter *w)
 {
-	if (IS_ENABLED(CONFIG_PREEMPT_RT) && w->wake_state != TASK_NORMAL) {
-		if (IS_ENABLED(CONFIG_PROVE_LOCKING))
-			WARN_ON_ONCE(wqh->rtlock_task);
-		get_task_struct(w->task);
-		wqh->rtlock_task = w->task;
-	} else {
-		wake_q_add(&wqh->head, w->task);
-	}
+	rt_mutex_wake_q_add_task(wqh, w->task, w->wake_state);
 }
 
 static __always_inline void rt_mutex_wake_up_q(struct rt_wake_q_head *wqh)
diff --git a/kernel/locking/rwbase_rt.c b/kernel/locking/rwbase_rt.c
index 88191f6..15c8110 100644
--- a/kernel/locking/rwbase_rt.c
+++ b/kernel/locking/rwbase_rt.c
@@ -148,6 +148,7 @@ static void __sched __rwbase_read_unlock(struct rwbase_rt *rwb,
 {
 	struct rt_mutex_base *rtm = &rwb->rtmutex;
 	struct task_struct *owner;
+	DEFINE_RT_WAKE_Q(wqh);
 
 	raw_spin_lock_irq(&rtm->wait_lock);
 	/*
@@ -158,9 +159,12 @@ static void __sched __rwbase_read_unlock(struct rwbase_rt *rwb,
 	 */
 	owner = rt_mutex_owner(rtm);
 	if (owner)
-		wake_up_state(owner, state);
+		rt_mutex_wake_q_add_task(&wqh, owner, state);
 
+	/* Pairs with the preempt_enable in rt_mutex_wake_up_q() */
+	preempt_disable();
 	raw_spin_unlock_irq(&rtm->wait_lock);
+	rt_mutex_wake_up_q(&wqh);
 }
 
 static __always_inline void rwbase_read_unlock(struct rwbase_rt *rwb,
diff --git a/kernel/locking/spinlock_rt.c b/kernel/locking/spinlock_rt.c
index d2912e4..b2e553f 100644
--- a/kernel/locking/spinlock_rt.c
+++ b/kernel/locking/spinlock_rt.c
@@ -24,6 +24,17 @@
 #define RT_MUTEX_BUILD_SPINLOCKS
 #include "rtmutex.c"
 
+/*
+ * __might_resched() skips the state check as rtlocks are state
+ * preserving. Take RCU nesting into account as spin/read/write_lock() can
+ * legitimately nest into an RCU read side critical section.
+ */
+#define RTLOCK_RESCHED_OFFSETS						\
+	(rcu_preempt_depth() << MIGHT_RESCHED_RCU_SHIFT)
+
+#define rtlock_might_resched()						\
+	__might_resched(__FILE__, __LINE__, RTLOCK_RESCHED_OFFSETS)
+
 static __always_inline void rtlock_lock(struct rt_mutex_base *rtm)
 {
 	if (unlikely(!rt_mutex_cmpxchg_acquire(rtm, NULL, current)))
@@ -32,7 +43,7 @@ static __always_inline void rtlock_lock(struct rt_mutex_base *rtm)
 
 static __always_inline void __rt_spin_lock(spinlock_t *lock)
 {
-	___might_sleep(__FILE__, __LINE__, 0);
+	rtlock_might_resched();
 	rtlock_lock(&lock->lock);
 	rcu_read_lock();
 	migrate_disable();
@@ -210,7 +221,7 @@ EXPORT_SYMBOL(rt_write_trylock);
 
 void __sched rt_read_lock(rwlock_t *rwlock)
 {
-	___might_sleep(__FILE__, __LINE__, 0);
+	rtlock_might_resched();
 	rwlock_acquire_read(&rwlock->dep_map, 0, 0, _RET_IP_);
 	rwbase_read_lock(&rwlock->rwbase, TASK_RTLOCK_WAIT);
 	rcu_read_lock();
@@ -220,7 +231,7 @@ EXPORT_SYMBOL(rt_read_lock);
 
 void __sched rt_write_lock(rwlock_t *rwlock)
 {
-	___might_sleep(__FILE__, __LINE__, 0);
+	rtlock_might_resched();
 	rwlock_acquire(&rwlock->dep_map, 0, 0, _RET_IP_);
 	rwbase_write_lock(&rwlock->rwbase, TASK_RTLOCK_WAIT);
 	rcu_read_lock();
diff --git a/kernel/locking/test-ww_mutex.c b/kernel/locking/test-ww_mutex.c
index 3e82f44..3530041 100644
--- a/kernel/locking/test-ww_mutex.c
+++ b/kernel/locking/test-ww_mutex.c
@@ -16,6 +16,15 @@
 static DEFINE_WD_CLASS(ww_class);
 struct workqueue_struct *wq;
 
+#ifdef CONFIG_DEBUG_WW_MUTEX_SLOWPATH
+#define ww_acquire_init_noinject(a, b) do { \
+		ww_acquire_init((a), (b)); \
+		(a)->deadlock_inject_countdown = ~0U; \
+	} while (0)
+#else
+#define ww_acquire_init_noinject(a, b) ww_acquire_init((a), (b))
+#endif
+
 struct test_mutex {
 	struct work_struct work;
 	struct ww_mutex mutex;
@@ -36,7 +45,7 @@ static void test_mutex_work(struct work_struct *work)
 	wait_for_completion(&mtx->go);
 
 	if (mtx->flags & TEST_MTX_TRY) {
-		while (!ww_mutex_trylock(&mtx->mutex))
+		while (!ww_mutex_trylock(&mtx->mutex, NULL))
 			cond_resched();
 	} else {
 		ww_mutex_lock(&mtx->mutex, NULL);
@@ -109,19 +118,39 @@ static int test_mutex(void)
 	return 0;
 }
 
-static int test_aa(void)
+static int test_aa(bool trylock)
 {
 	struct ww_mutex mutex;
 	struct ww_acquire_ctx ctx;
 	int ret;
+	const char *from = trylock ? "trylock" : "lock";
 
 	ww_mutex_init(&mutex, &ww_class);
 	ww_acquire_init(&ctx, &ww_class);
 
-	ww_mutex_lock(&mutex, &ctx);
+	if (!trylock) {
+		ret = ww_mutex_lock(&mutex, &ctx);
+		if (ret) {
+			pr_err("%s: initial lock failed!\n", __func__);
+			goto out;
+		}
+	} else {
+		ret = !ww_mutex_trylock(&mutex, &ctx);
+		if (ret) {
+			pr_err("%s: initial trylock failed!\n", __func__);
+			goto out;
+		}
+	}
 
-	if (ww_mutex_trylock(&mutex))  {
-		pr_err("%s: trylocked itself!\n", __func__);
+	if (ww_mutex_trylock(&mutex, NULL))  {
+		pr_err("%s: trylocked itself without context from %s!\n", __func__, from);
+		ww_mutex_unlock(&mutex);
+		ret = -EINVAL;
+		goto out;
+	}
+
+	if (ww_mutex_trylock(&mutex, &ctx))  {
+		pr_err("%s: trylocked itself with context from %s!\n", __func__, from);
 		ww_mutex_unlock(&mutex);
 		ret = -EINVAL;
 		goto out;
@@ -129,17 +158,17 @@ static int test_aa(void)
 
 	ret = ww_mutex_lock(&mutex, &ctx);
 	if (ret != -EALREADY) {
-		pr_err("%s: missed deadlock for recursing, ret=%d\n",
-		       __func__, ret);
+		pr_err("%s: missed deadlock for recursing, ret=%d from %s\n",
+		       __func__, ret, from);
 		if (!ret)
 			ww_mutex_unlock(&mutex);
 		ret = -EINVAL;
 		goto out;
 	}
 
+	ww_mutex_unlock(&mutex);
 	ret = 0;
 out:
-	ww_mutex_unlock(&mutex);
 	ww_acquire_fini(&ctx);
 	return ret;
 }
@@ -150,7 +179,7 @@ struct test_abba {
 	struct ww_mutex b_mutex;
 	struct completion a_ready;
 	struct completion b_ready;
-	bool resolve;
+	bool resolve, trylock;
 	int result;
 };
 
@@ -160,8 +189,13 @@ static void test_abba_work(struct work_struct *work)
 	struct ww_acquire_ctx ctx;
 	int err;
 
-	ww_acquire_init(&ctx, &ww_class);
-	ww_mutex_lock(&abba->b_mutex, &ctx);
+	ww_acquire_init_noinject(&ctx, &ww_class);
+	if (!abba->trylock)
+		ww_mutex_lock(&abba->b_mutex, &ctx);
+	else
+		WARN_ON(!ww_mutex_trylock(&abba->b_mutex, &ctx));
+
+	WARN_ON(READ_ONCE(abba->b_mutex.ctx) != &ctx);
 
 	complete(&abba->b_ready);
 	wait_for_completion(&abba->a_ready);
@@ -181,7 +215,7 @@ static void test_abba_work(struct work_struct *work)
 	abba->result = err;
 }
 
-static int test_abba(bool resolve)
+static int test_abba(bool trylock, bool resolve)
 {
 	struct test_abba abba;
 	struct ww_acquire_ctx ctx;
@@ -192,12 +226,18 @@ static int test_abba(bool resolve)
 	INIT_WORK_ONSTACK(&abba.work, test_abba_work);
 	init_completion(&abba.a_ready);
 	init_completion(&abba.b_ready);
+	abba.trylock = trylock;
 	abba.resolve = resolve;
 
 	schedule_work(&abba.work);
 
-	ww_acquire_init(&ctx, &ww_class);
-	ww_mutex_lock(&abba.a_mutex, &ctx);
+	ww_acquire_init_noinject(&ctx, &ww_class);
+	if (!trylock)
+		ww_mutex_lock(&abba.a_mutex, &ctx);
+	else
+		WARN_ON(!ww_mutex_trylock(&abba.a_mutex, &ctx));
+
+	WARN_ON(READ_ONCE(abba.a_mutex.ctx) != &ctx);
 
 	complete(&abba.a_ready);
 	wait_for_completion(&abba.b_ready);
@@ -249,7 +289,7 @@ static void test_cycle_work(struct work_struct *work)
 	struct ww_acquire_ctx ctx;
 	int err, erra = 0;
 
-	ww_acquire_init(&ctx, &ww_class);
+	ww_acquire_init_noinject(&ctx, &ww_class);
 	ww_mutex_lock(&cycle->a_mutex, &ctx);
 
 	complete(cycle->a_signal);
@@ -581,7 +621,9 @@ static int stress(int nlocks, int nthreads, unsigned int flags)
 static int __init test_ww_mutex_init(void)
 {
 	int ncpus = num_online_cpus();
-	int ret;
+	int ret, i;
+
+	printk(KERN_INFO "Beginning ww mutex selftests\n");
 
 	wq = alloc_workqueue("test-ww_mutex", WQ_UNBOUND, 0);
 	if (!wq)
@@ -591,17 +633,19 @@ static int __init test_ww_mutex_init(void)
 	if (ret)
 		return ret;
 
-	ret = test_aa();
+	ret = test_aa(false);
 	if (ret)
 		return ret;
 
-	ret = test_abba(false);
+	ret = test_aa(true);
 	if (ret)
 		return ret;
 
-	ret = test_abba(true);
-	if (ret)
-		return ret;
+	for (i = 0; i < 4; i++) {
+		ret = test_abba(i & 1, i & 2);
+		if (ret)
+			return ret;
+	}
 
 	ret = test_cycle(ncpus);
 	if (ret)
@@ -619,6 +663,7 @@ static int __init test_ww_mutex_init(void)
 	if (ret)
 		return ret;
 
+	printk(KERN_INFO "All ww mutex selftests passed\n");
 	return 0;
 }
 
diff --git a/kernel/locking/ww_rt_mutex.c b/kernel/locking/ww_rt_mutex.c
index 3f1fff7..0e00205 100644
--- a/kernel/locking/ww_rt_mutex.c
+++ b/kernel/locking/ww_rt_mutex.c
@@ -9,6 +9,42 @@
 #define WW_RT
 #include "rtmutex.c"
 
+/**
+ * ww_mutex_trylock - tries to acquire the w/w mutex with optional acquire context
+ * @lock: mutex to lock
+ * @ww_ctx: optional w/w acquire context
+ *
+ * Variant built on top of rt_mutex (WW_RT). Without @ww_ctx this is a plain
+ * rt_mutex_trylock() on the base lock and its result is returned as is.
+ * With @ww_ctx, ww_mutex_set_context_fastpath() is called on success; no
+ * deadlock detection is performed. Returns 1 if acquired, 0 otherwise.
+ */
+int ww_mutex_trylock(struct ww_mutex *lock, struct ww_acquire_ctx *ww_ctx)
+{
+	struct rt_mutex *rtm = &lock->base;
+
+	if (!ww_ctx)
+		return rt_mutex_trylock(rtm);
+
+	/*
+	 * Reset the wounded flag after a kill. No other process can
+	 * race and wound us here, since they can't have a valid owner
+	 * pointer if we don't have any locks held.
+	 */
+	if (ww_ctx->acquired == 0)
+		ww_ctx->wounded = 0;
+
+	if (__rt_mutex_trylock(&rtm->rtmutex)) {
+		ww_mutex_set_context_fastpath(lock, ww_ctx);
+		/* The nest lock argument is a pointer to the context's lockdep map. */
+		mutex_acquire_nest(&rtm->dep_map, 0, 1, &ww_ctx->dep_map, _RET_IP_);
+		return 1;
+	}
+
+	return 0;
+}
+EXPORT_SYMBOL(ww_mutex_trylock);
+
 static int __sched
 __ww_rt_mutex_lock(struct ww_mutex *lock, struct ww_acquire_ctx *ww_ctx,
 		   unsigned int state, unsigned long ip)
diff --git a/kernel/rcu/update.c b/kernel/rcu/update.c
index c21b38c..690b0ce 100644
--- a/kernel/rcu/update.c
+++ b/kernel/rcu/update.c
@@ -247,7 +247,7 @@ struct lockdep_map rcu_lock_map = {
 	.name = "rcu_read_lock",
 	.key = &rcu_lock_key,
 	.wait_type_outer = LD_WAIT_FREE,
-	.wait_type_inner = LD_WAIT_CONFIG, /* XXX PREEMPT_RCU ? */
+	.wait_type_inner = LD_WAIT_CONFIG, /* PREEMPT_RT implies PREEMPT_RCU */
 };
 EXPORT_SYMBOL_GPL(rcu_lock_map);
 
@@ -256,7 +256,7 @@ struct lockdep_map rcu_bh_lock_map = {
 	.name = "rcu_read_lock_bh",
 	.key = &rcu_bh_lock_key,
 	.wait_type_outer = LD_WAIT_FREE,
-	.wait_type_inner = LD_WAIT_CONFIG, /* PREEMPT_LOCK also makes BH preemptible */
+	.wait_type_inner = LD_WAIT_CONFIG, /* PREEMPT_RT makes BH preemptible. */
 };
 EXPORT_SYMBOL_GPL(rcu_bh_lock_map);
 
diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index 1bba412..8d3fa07 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -9468,14 +9468,8 @@ void __init sched_init(void)
 }
 
 #ifdef CONFIG_DEBUG_ATOMIC_SLEEP
-static inline int preempt_count_equals(int preempt_offset)
-{
-	int nested = preempt_count() + rcu_preempt_depth();
 
-	return (nested == preempt_offset);
-}
-
-void __might_sleep(const char *file, int line, int preempt_offset)
+void __might_sleep(const char *file, int line)
 {
 	unsigned int state = get_current_state();
 	/*
@@ -9489,11 +9483,32 @@ void __might_sleep(const char *file, int line, int preempt_offset)
 			(void *)current->task_state_change,
 			(void *)current->task_state_change);
 
-	___might_sleep(file, line, preempt_offset);
+	__might_resched(file, line, 0);
 }
 EXPORT_SYMBOL(__might_sleep);
 
-void ___might_sleep(const char *file, int line, int preempt_offset)
+static void print_preempt_disable_ip(int preempt_offset, unsigned long ip)
+{
+	if (!IS_ENABLED(CONFIG_DEBUG_PREEMPT))
+		return;
+
+	if (preempt_count() == preempt_offset)
+		return;
+
+	pr_err("Preemption disabled at:");
+	print_ip_sym(KERN_ERR, ip);
+}
+
+static inline bool resched_offsets_ok(unsigned int offsets)
+{
+	unsigned int nested = preempt_count();
+
+	nested += rcu_preempt_depth() << MIGHT_RESCHED_RCU_SHIFT;
+
+	return nested == offsets;
+}
+
+void __might_resched(const char *file, int line, unsigned int offsets)
 {
 	/* Ratelimiting timestamp: */
 	static unsigned long prev_jiffy;
@@ -9503,7 +9518,7 @@ void ___might_sleep(const char *file, int line, int preempt_offset)
 	/* WARN_ON_ONCE() by default, no rate limit required: */
 	rcu_sleep_check();
 
-	if ((preempt_count_equals(preempt_offset) && !irqs_disabled() &&
+	if ((resched_offsets_ok(offsets) && !irqs_disabled() &&
 	     !is_idle_task(current) && !current->non_block_count) ||
 	    system_state == SYSTEM_BOOTING || system_state > SYSTEM_RUNNING ||
 	    oops_in_progress)
@@ -9516,29 +9531,33 @@ void ___might_sleep(const char *file, int line, int preempt_offset)
 	/* Save this before calling printk(), since that will clobber it: */
 	preempt_disable_ip = get_preempt_disable_ip(current);
 
-	printk(KERN_ERR
-		"BUG: sleeping function called from invalid context at %s:%d\n",
-			file, line);
-	printk(KERN_ERR
-		"in_atomic(): %d, irqs_disabled(): %d, non_block: %d, pid: %d, name: %s\n",
-			in_atomic(), irqs_disabled(), current->non_block_count,
-			current->pid, current->comm);
+	pr_err("BUG: sleeping function called from invalid context at %s:%d\n",
+	       file, line);
+	pr_err("in_atomic(): %d, irqs_disabled(): %d, non_block: %d, pid: %d, name: %s\n",
+	       in_atomic(), irqs_disabled(), current->non_block_count,
+	       current->pid, current->comm);
+	pr_err("preempt_count: %x, expected: %x\n", preempt_count(),
+	       offsets & MIGHT_RESCHED_PREEMPT_MASK);
+
+	if (IS_ENABLED(CONFIG_PREEMPT_RCU)) {
+		pr_err("RCU nest depth: %d, expected: %u\n",
+		       rcu_preempt_depth(), offsets >> MIGHT_RESCHED_RCU_SHIFT);
+	}
 
 	if (task_stack_end_corrupted(current))
-		printk(KERN_EMERG "Thread overran stack, or stack corrupted\n");
+		pr_emerg("Thread overran stack, or stack corrupted\n");
 
 	debug_show_held_locks(current);
 	if (irqs_disabled())
 		print_irqtrace_events(current);
-	if (IS_ENABLED(CONFIG_DEBUG_PREEMPT)
-	    && !preempt_count_equals(preempt_offset)) {
-		pr_err("Preemption disabled at:");
-		print_ip_sym(KERN_ERR, preempt_disable_ip);
-	}
+
+	print_preempt_disable_ip(offsets & MIGHT_RESCHED_PREEMPT_MASK,
+				 preempt_disable_ip);
+
 	dump_stack();
 	add_taint(TAINT_WARN, LOCKDEP_STILL_OK);
 }
-EXPORT_SYMBOL(___might_sleep);
+EXPORT_SYMBOL(__might_resched);
 
 void __cant_sleep(const char *file, int line, int preempt_offset)
 {
diff --git a/lib/locking-selftest.c b/lib/locking-selftest.c
index 161108e..71652e1 100644
--- a/lib/locking-selftest.c
+++ b/lib/locking-selftest.c
@@ -258,7 +258,7 @@ static void init_shared_classes(void)
 #define WWAF(x)			ww_acquire_fini(x)
 
 #define WWL(x, c)		ww_mutex_lock(x, c)
-#define WWT(x)			ww_mutex_trylock(x)
+#define WWT(x)			ww_mutex_trylock(x, NULL)
 #define WWL1(x)			ww_mutex_lock(x, NULL)
 #define WWU(x)			ww_mutex_unlock(x)
 
diff --git a/mm/memory.c b/mm/memory.c
index 25fc46e..1cd1792 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -5255,7 +5255,7 @@ void __might_fault(const char *file, int line)
 		return;
 	if (pagefault_disabled())
 		return;
-	__might_sleep(file, line, 0);
+	__might_sleep(file, line);
 #if defined(CONFIG_DEBUG_ATOMIC_SLEEP)
 	if (current->mm)
 		might_lock_read(&current->mm->mmap_lock);