io_uring: fix exiting io_req_task_work_add leaks

If a caller enters io_req_task_work_add() without seeing PF_EXITING, it
sets a ->task_state bit and tries task_work_add(), which may already fail
by that point. If that happens, the function tries to cancel the request.

However, in the meantime other io_req_task_work_add() callers may arrive.
They see the bit already set and leave their requests in the list, where
those requests will never be executed.

Instead of propagating an error, clear the bit first and then fall back
all requests that can be spliced from the list. The callback functions
have to be able to deal with PF_EXITING, so poll and apoll were modified
by changing io_poll_rewait().

Fixes: 7cbf1722d5fc ("io_uring: provide FIFO ordering for task_work")
Reported-by: Jens Axboe <axboe@kernel.dk>
Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Link: https://lore.kernel.org/r/060002f19f1fdbd130ba24aef818ea4d3080819b.1625142209.git.asml.silence@gmail.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 5b840bb..8818560 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -1952,17 +1952,13 @@ static void tctx_task_work(struct callback_head *cb)
 	ctx_flush_and_put(ctx);
 }
 
-static int io_req_task_work_add(struct io_kiocb *req)
+static void io_req_task_work_add(struct io_kiocb *req)
 {
 	struct task_struct *tsk = req->task;
 	struct io_uring_task *tctx = tsk->io_uring;
 	enum task_work_notify_mode notify;
-	struct io_wq_work_node *node, *prev;
+	struct io_wq_work_node *node;
 	unsigned long flags;
-	int ret = 0;
-
-	if (unlikely(tsk->flags & PF_EXITING))
-		return -ESRCH;
 
 	WARN_ON_ONCE(!tctx);
 
@@ -1973,7 +1969,9 @@ static int io_req_task_work_add(struct io_kiocb *req)
 	/* task_work already pending, we're done */
 	if (test_bit(0, &tctx->task_state) ||
 	    test_and_set_bit(0, &tctx->task_state))
-		return 0;
+		return;
+	if (unlikely(tsk->flags & PF_EXITING))
+		goto fail;
 
 	/*
 	 * SQPOLL kernel thread doesn't need notification, just a wakeup. For
@@ -1982,36 +1980,24 @@ static int io_req_task_work_add(struct io_kiocb *req)
 	 * will do the job.
 	 */
 	notify = (req->ctx->flags & IORING_SETUP_SQPOLL) ? TWA_NONE : TWA_SIGNAL;
-
 	if (!task_work_add(tsk, &tctx->task_work, notify)) {
 		wake_up_process(tsk);
-		return 0;
+		return;
 	}
-
-	/*
-	 * Slow path - we failed, find and delete work. if the work is not
-	 * in the list, it got run and we're fine.
-	 */
-	spin_lock_irqsave(&tctx->task_lock, flags);
-	wq_list_for_each(node, prev, &tctx->task_list) {
-		if (&req->io_task_work.node == node) {
-			wq_list_del(&tctx->task_list, node, prev);
-			ret = 1;
-			break;
-		}
-	}
-	spin_unlock_irqrestore(&tctx->task_lock, flags);
+fail:
 	clear_bit(0, &tctx->task_state);
-	return ret;
-}
+	spin_lock_irqsave(&tctx->task_lock, flags);
+	node = tctx->task_list.first;
+	INIT_WQ_LIST(&tctx->task_list);
+	spin_unlock_irqrestore(&tctx->task_lock, flags);
 
-static void io_req_task_work_add_fallback(struct io_kiocb *req,
-					  io_req_tw_func_t cb)
-{
-	req->io_task_work.func = cb;
-	if (llist_add(&req->io_task_work.fallback_node,
-		      &req->ctx->fallback_llist))
-		schedule_delayed_work(&req->ctx->fallback_work, 1);
+	while (node) {
+		req = container_of(node, struct io_kiocb, io_task_work.node);
+		node = node->next;
+		if (llist_add(&req->io_task_work.fallback_node,
+			      &req->ctx->fallback_llist))
+			schedule_delayed_work(&req->ctx->fallback_work, 1);
+	}
 }
 
 static void io_req_task_cancel(struct io_kiocb *req)
@@ -2041,17 +2027,13 @@ static void io_req_task_queue_fail(struct io_kiocb *req, int ret)
 {
 	req->result = ret;
 	req->io_task_work.func = io_req_task_cancel;
-
-	if (unlikely(io_req_task_work_add(req)))
-		io_req_task_work_add_fallback(req, io_req_task_cancel);
+	io_req_task_work_add(req);
 }
 
 static void io_req_task_queue(struct io_kiocb *req)
 {
 	req->io_task_work.func = io_req_task_submit;
-
-	if (unlikely(io_req_task_work_add(req)))
-		io_req_task_queue_fail(req, -ECANCELED);
+	io_req_task_work_add(req);
 }
 
 static inline void io_queue_next(struct io_kiocb *req)
@@ -2165,8 +2147,7 @@ static inline void io_put_req(struct io_kiocb *req)
 static void io_free_req_deferred(struct io_kiocb *req)
 {
 	req->io_task_work.func = io_free_req;
-	if (unlikely(io_req_task_work_add(req)))
-		io_req_task_work_add_fallback(req, io_free_req);
+	io_req_task_work_add(req);
 }
 
 static inline void io_put_req_deferred(struct io_kiocb *req, int refs)
@@ -4823,8 +4804,6 @@ struct io_poll_table {
 static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
 			   __poll_t mask, io_req_tw_func_t func)
 {
-	int ret;
-
 	/* for instances that support it check for an event match first: */
 	if (mask && !(mask & poll->events))
 		return 0;
@@ -4842,11 +4821,7 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
 	 * of executing it. We can't safely execute it anyway, as we may not
 	 * have the needed state needed for it anyway.
 	 */
-	ret = io_req_task_work_add(req);
-	if (unlikely(ret)) {
-		WRITE_ONCE(poll->canceled, true);
-		io_req_task_work_add_fallback(req, func);
-	}
+	io_req_task_work_add(req);
 	return 1;
 }
 
@@ -4855,6 +4830,9 @@ static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 
+	if (unlikely(req->task->flags & PF_EXITING))
+		WRITE_ONCE(poll->canceled, true);
+
 	if (!req->result && !READ_ONCE(poll->canceled)) {
 		struct poll_table_struct pt = { ._key = poll->events };