io_uring: always set REQ_F_FREE_SQE for an allocated sqe

Always mark requests that have an allocated sqe with REQ_F_FREE_SQE and
deallocate the sqe in __io_free_req(). This is easier to follow and doesn't
add edge cases.
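For reference, a minimal sketch of the ownership rule this makes uniform
(io_prep_async_sqe() is a hypothetical helper used only for illustration,
not part of this patch): every kmemdup()'d sqe copy sets REQ_F_FREE_SQE,
and __io_free_req() is then the single place that kfree()s it.

	/* Hypothetical helper, for illustration only */
	static int io_prep_async_sqe(struct io_kiocb *req)
	{
		struct io_uring_sqe *sqe_copy;

		sqe_copy = kmemdup(req->submit.sqe, sizeof(*sqe_copy),
				   GFP_KERNEL);
		if (!sqe_copy)
			return -ENOMEM;

		req->submit.sqe = sqe_copy;
		/* __io_free_req() now owns the copy and will kfree() it */
		req->flags |= REQ_F_FREE_SQE;
		return 0;
	}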

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 09fc295..6d52a4d 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -834,6 +834,8 @@ static void __io_free_req(struct io_kiocb *req)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 
+	if (req->flags & REQ_F_FREE_SQE)
+		kfree(req->submit.sqe);
 	if (req->file && !(req->flags & REQ_F_FIXED_FILE))
 		fput(req->file);
 	if (req->flags & REQ_F_INFLIGHT) {
@@ -929,16 +931,11 @@ static void io_fail_links(struct io_kiocb *req)
 	spin_lock_irqsave(&ctx->completion_lock, flags);
 
 	while (!list_empty(&req->link_list)) {
-		const struct io_uring_sqe *sqe_to_free = NULL;
-
 		link = list_first_entry(&req->link_list, struct io_kiocb, list);
 		list_del_init(&link->list);
 
 		trace_io_uring_fail_link(req, link);
 
-		if (link->flags & REQ_F_FREE_SQE)
-			sqe_to_free = link->submit.sqe;
-
 		if ((req->flags & REQ_F_LINK_TIMEOUT) &&
 		    link->submit.sqe->opcode == IORING_OP_LINK_TIMEOUT) {
 			io_link_cancel_timeout(link);
@@ -946,7 +943,6 @@ static void io_fail_links(struct io_kiocb *req)
 			io_cqring_fill_event(link, -ECANCELED);
 			__io_double_put_req(link);
 		}
-		kfree(sqe_to_free);
 		req->flags &= ~REQ_F_LINK_TIMEOUT;
 	}
 
@@ -1089,7 +1085,8 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
 			 * completions for those, only batch free for fixed
 			 * file and non-linked commands.
 			 */
-			if (((req->flags & (REQ_F_FIXED_FILE|REQ_F_LINK)) ==
+			if (((req->flags &
+				(REQ_F_FIXED_FILE|REQ_F_LINK|REQ_F_FREE_SQE)) ==
 			    REQ_F_FIXED_FILE) && !io_is_fallback_req(req)) {
 				reqs[to_free++] = req;
 				if (to_free == ARRAY_SIZE(reqs))
@@ -2582,6 +2579,7 @@ static int io_req_defer(struct io_kiocb *req)
 	}
 
 	memcpy(sqe_copy, sqe, sizeof(*sqe_copy));
+	req->flags |= REQ_F_FREE_SQE;
 	req->submit.sqe = sqe_copy;
 
 	trace_io_uring_defer(ctx, req, false);
@@ -2676,7 +2674,6 @@ static void io_wq_submit_work(struct io_wq_work **workptr)
 	struct io_wq_work *work = *workptr;
 	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
 	struct sqe_submit *s = &req->submit;
-	const struct io_uring_sqe *sqe = s->sqe;
 	struct io_kiocb *nxt = NULL;
 	int ret = 0;
 
@@ -2712,9 +2709,6 @@ static void io_wq_submit_work(struct io_wq_work **workptr)
 		io_put_req(req);
 	}
 
-	/* async context always use a copy of the sqe */
-	kfree(sqe);
-
 	/* if a dependent link is ready, pass it back */
 	if (!ret && nxt) {
 		struct io_kiocb *link;
@@ -2913,23 +2907,24 @@ static void __io_queue_sqe(struct io_kiocb *req)
 		struct io_uring_sqe *sqe_copy;
 
 		sqe_copy = kmemdup(s->sqe, sizeof(*sqe_copy), GFP_KERNEL);
-		if (sqe_copy) {
-			s->sqe = sqe_copy;
-			if (req->work.flags & IO_WQ_WORK_NEEDS_FILES) {
-				ret = io_grab_files(req);
-				if (ret) {
-					kfree(sqe_copy);
-					goto err;
-				}
-			}
+		if (!sqe_copy)
+			goto err;
 
-			/*
-			 * Queued up for async execution, worker will release
-			 * submit reference when the iocb is actually submitted.
-			 */
-			io_queue_async_work(req);
-			return;
+		s->sqe = sqe_copy;
+		req->flags |= REQ_F_FREE_SQE;
+
+		if (req->work.flags & IO_WQ_WORK_NEEDS_FILES) {
+			ret = io_grab_files(req);
+			if (ret)
+				goto err;
 		}
+
+		/*
+		 * Queued up for async execution, worker will release
+		 * submit reference when the iocb is actually submitted.
+		 */
+		io_queue_async_work(req);
+		return;
 	}
 
 err:
@@ -3024,7 +3019,6 @@ static void io_queue_link_head(struct io_kiocb *req, struct io_kiocb *shadow)
 static void io_submit_sqe(struct io_kiocb *req, struct io_submit_state *state,
 			  struct io_kiocb **link)
 {
-	struct io_uring_sqe *sqe_copy;
 	struct sqe_submit *s = &req->submit;
 	struct io_ring_ctx *ctx = req->ctx;
 	int ret;
@@ -3054,6 +3048,7 @@ static void io_submit_sqe(struct io_kiocb *req, struct io_submit_state *state,
 	 */
 	if (*link) {
 		struct io_kiocb *prev = *link;
+		struct io_uring_sqe *sqe_copy;
 
 		if (READ_ONCE(s->sqe->opcode) == IORING_OP_LINK_TIMEOUT) {
 			ret = io_timeout_setup(req);