io_uring: off timeouts based only on completions

Offset timeouts currently wait not for sqe->off non-timeout CQEs, but
rather for sqe->off + the number of prior inflight requests. Make them
wait for exactly sqe->off non-timeout completions.

Reported-by: Jens Axboe <axboe@kernel.dk>
Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index ffd0ec7..9f11feb 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -394,7 +394,8 @@ struct io_timeout {
 	struct file			*file;
 	u64				addr;
 	int				flags;
-	u32				count;
+	u32				off;
+	u32				target_seq;
 };
 
 struct io_rw {
@@ -1124,8 +1125,10 @@ static void io_flush_timeouts(struct io_ring_ctx *ctx)
 
 		if (req->flags & REQ_F_TIMEOUT_NOSEQ)
 			break;
-		if (__req_need_defer(req))
+		if (req->timeout.target_seq != ctx->cached_cq_tail
+					- atomic_read(&ctx->cq_timeouts))
 			break;
+
 		list_del_init(&req->list);
 		io_kill_timeout(req);
 	}
@@ -4660,20 +4663,8 @@ static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer)
 	 * We could be racing with timeout deletion. If the list is empty,
 	 * then timeout lookup already found it and will be handling it.
 	 */
-	if (!list_empty(&req->list)) {
-		struct io_kiocb *prev;
-
-		/*
-		 * Adjust the reqs sequence before the current one because it
-		 * will consume a slot in the cq_ring and the cq_tail
-		 * pointer will be increased, otherwise other timeout reqs may
-		 * return in advance without waiting for enough wait_nr.
-		 */
-		prev = req;
-		list_for_each_entry_continue_reverse(prev, &ctx->timeout_list, list)
-			prev->sequence++;
+	if (!list_empty(&req->list))
 		list_del_init(&req->list);
-	}
 
 	io_cqring_fill_event(req, -ETIME);
 	io_commit_cqring(ctx);
@@ -4765,7 +4756,7 @@ static int io_timeout_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 	if (flags & ~IORING_TIMEOUT_ABS)
 		return -EINVAL;
 
-	req->timeout.count = off;
+	req->timeout.off = off;
 
 	if (!req->io && io_alloc_async_ctx(req))
 		return -ENOMEM;
@@ -4789,13 +4780,10 @@ static int io_timeout_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 static int io_timeout(struct io_kiocb *req)
 {
 	struct io_ring_ctx *ctx = req->ctx;
-	struct io_timeout_data *data;
+	struct io_timeout_data *data = &req->io->timeout;
 	struct list_head *entry;
-	unsigned span = 0;
-	u32 count = req->timeout.count;
-	u32 seq = req->sequence;
+	u32 tail, off = req->timeout.off;
 
-	data = &req->io->timeout;
 	spin_lock_irq(&ctx->completion_lock);
 
 	/*
@@ -4803,13 +4791,14 @@ static int io_timeout(struct io_kiocb *req)
 	 * timeout event to be satisfied. If it isn't set, then this is
 	 * a pure timeout request, sequence isn't used.
 	 */
-	if (!count) {
+	if (!off) {
 		req->flags |= REQ_F_TIMEOUT_NOSEQ;
 		entry = ctx->timeout_list.prev;
 		goto add;
 	}
 
-	req->sequence = seq + count;
+	tail = ctx->cached_cq_tail - atomic_read(&ctx->cq_timeouts);
+	req->timeout.target_seq = tail + off;
 
 	/*
 	 * Insertion sort, ensuring the first entry in the list is always
@@ -4817,39 +4806,13 @@ static int io_timeout(struct io_kiocb *req)
 	 */
 	list_for_each_prev(entry, &ctx->timeout_list) {
 		struct io_kiocb *nxt = list_entry(entry, struct io_kiocb, list);
-		unsigned nxt_seq;
-		long long tmp, tmp_nxt;
-		u32 nxt_offset = nxt->timeout.count;
 
 		if (nxt->flags & REQ_F_TIMEOUT_NOSEQ)
 			continue;
-
-		/*
-		 * Since seq + count can overflow, use type long
-		 * long to store it.
-		 */
-		tmp = (long long)seq + count;
-		nxt_seq = nxt->sequence - nxt_offset;
-		tmp_nxt = (long long)nxt_seq + nxt_offset;
-
-		/*
-		 * cached_sq_head may overflow, and it will never overflow twice
-		 * once there is some timeout req still be valid.
-		 */
-		if (seq < nxt_seq)
-			tmp += UINT_MAX;
-
-		if (tmp > tmp_nxt)
+		/* nxt.seq is behind @tail, otherwise would've been completed */
+		if (off >= nxt->timeout.target_seq - tail)
 			break;
-
-		/*
-		 * Sequence of reqs after the insert one and itself should
-		 * be adjusted because each timeout req consumes a slot.
-		 */
-		span++;
-		nxt->sequence++;
 	}
-	req->sequence -= span;
 add:
 	list_add(&req->list, entry);
 	data->timer.function = io_timeout_fn;