io_uring/io-wq: don't use static creds/mm assignments

We currently set up the io_wq with a static set of mm and creds. Even for
a single-use io-wq per io_uring, this is suboptimal as we may have
multiple enters of the ring. For sharing the io-wq backend, it doesn't
work at all.

Switch to passing in the creds and mm when the work item is set up. This
means that async work is no longer deferred to the io_uring mm and creds;
it is done with the current mm and creds.

Flag this behavior with IORING_FEAT_CUR_PERSONALITY, so applications know
they can rely on the current personality (mm and creds) being the same
for direct issue and async issue.

Reviewed-by: Stefan Metzmacher <metze@samba.org>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io-wq.c b/fs/io-wq.c
index 54e270a..7ccedc82 100644
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -56,7 +56,8 @@ struct io_worker {
 
 	struct rcu_head rcu;
 	struct mm_struct *mm;
-	const struct cred *creds;
+	const struct cred *cur_creds;
+	const struct cred *saved_creds;
 	struct files_struct *restore_files;
 };
 
@@ -109,8 +110,6 @@ struct io_wq {
 
 	struct task_struct *manager;
 	struct user_struct *user;
-	const struct cred *creds;
-	struct mm_struct *mm;
 	refcount_t refs;
 	struct completion done;
 
@@ -137,9 +136,9 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
 {
 	bool dropped_lock = false;
 
-	if (worker->creds) {
-		revert_creds(worker->creds);
-		worker->creds = NULL;
+	if (worker->saved_creds) {
+		revert_creds(worker->saved_creds);
+		worker->cur_creds = worker->saved_creds = NULL;
 	}
 
 	if (current->files != worker->restore_files) {
@@ -398,6 +397,43 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
 	return NULL;
 }
 
+static void io_wq_switch_mm(struct io_worker *worker, struct io_wq_work *work)
+{
+	if (worker->mm) {
+		unuse_mm(worker->mm);
+		mmput(worker->mm);
+		worker->mm = NULL;
+	}
+	if (!work->mm) {
+		set_fs(KERNEL_DS);
+		return;
+	}
+	if (mmget_not_zero(work->mm)) {
+		use_mm(work->mm);
+		if (!worker->mm)
+			set_fs(USER_DS);
+		worker->mm = work->mm;
+		/* hang on to this mm */
+		work->mm = NULL;
+		return;
+	}
+
+	/* failed grabbing mm, ensure work gets cancelled */
+	work->flags |= IO_WQ_WORK_CANCEL;
+}
+
+static void io_wq_switch_creds(struct io_worker *worker,
+			       struct io_wq_work *work)
+{
+	const struct cred *old_creds = override_creds(work->creds);
+
+	worker->cur_creds = work->creds;
+	if (worker->saved_creds)
+		put_cred(old_creds); /* creds set by previous switch */
+	else
+		worker->saved_creds = old_creds;
+}
+
 static void io_worker_handle_work(struct io_worker *worker)
 	__releases(wqe->lock)
 {
@@ -446,18 +482,10 @@ static void io_worker_handle_work(struct io_worker *worker)
 			current->files = work->files;
 			task_unlock(current);
 		}
-		if ((work->flags & IO_WQ_WORK_NEEDS_USER) && !worker->mm &&
-		    wq->mm) {
-			if (mmget_not_zero(wq->mm)) {
-				use_mm(wq->mm);
-				set_fs(USER_DS);
-				worker->mm = wq->mm;
-			} else {
-				work->flags |= IO_WQ_WORK_CANCEL;
-			}
-		}
-		if (!worker->creds)
-			worker->creds = override_creds(wq->creds);
+		if (work->mm != worker->mm)
+			io_wq_switch_mm(worker, work);
+		if (worker->cur_creds != work->creds)
+			io_wq_switch_creds(worker, work);
 		/*
 		 * OK to set IO_WQ_WORK_CANCEL even for uncancellable work,
 		 * the worker function will do the right thing.
@@ -1037,7 +1065,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
 
 	/* caller must already hold a reference to this */
 	wq->user = data->user;
-	wq->creds = data->creds;
 
 	for_each_node(node) {
 		struct io_wqe *wqe;
@@ -1064,9 +1091,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
 
 	init_completion(&wq->done);
 
-	/* caller must have already done mmgrab() on this mm */
-	wq->mm = data->mm;
-
 	wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager");
 	if (!IS_ERR(wq->manager)) {
 		wake_up_process(wq->manager);
diff --git a/fs/io-wq.h b/fs/io-wq.h
index 1cd039a..167316ad 100644
--- a/fs/io-wq.h
+++ b/fs/io-wq.h
@@ -7,7 +7,6 @@ enum {
 	IO_WQ_WORK_CANCEL	= 1,
 	IO_WQ_WORK_HAS_MM	= 2,
 	IO_WQ_WORK_HASHED	= 4,
-	IO_WQ_WORK_NEEDS_USER	= 8,
 	IO_WQ_WORK_NEEDS_FILES	= 16,
 	IO_WQ_WORK_UNBOUND	= 32,
 	IO_WQ_WORK_INTERNAL	= 64,
@@ -74,6 +73,8 @@ struct io_wq_work {
 	};
 	void (*func)(struct io_wq_work **);
 	struct files_struct *files;
+	struct mm_struct *mm;
+	const struct cred *creds;
 	unsigned flags;
 };
 
@@ -83,15 +84,15 @@ struct io_wq_work {
 		(work)->func = _func;			\
 		(work)->flags = 0;			\
 		(work)->files = NULL;			\
+		(work)->mm = NULL;			\
+		(work)->creds = NULL;			\
 	} while (0)					\
 
 typedef void (get_work_fn)(struct io_wq_work *);
 typedef void (put_work_fn)(struct io_wq_work *);
 
 struct io_wq_data {
-	struct mm_struct *mm;
 	struct user_struct *user;
-	const struct cred *creds;
 
 	get_work_fn *get_work;
 	put_work_fn *put_work;
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 1dd2030..0ea3691 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -875,6 +875,29 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
 	}
 }
 
+static inline void io_req_work_grab_env(struct io_kiocb *req,
+					const struct io_op_def *def)
+{
+	if (!req->work.mm && def->needs_mm) {
+		mmgrab(current->mm);
+		req->work.mm = current->mm;
+	}
+	if (!req->work.creds)
+		req->work.creds = get_current_cred();
+}
+
+static inline void io_req_work_drop_env(struct io_kiocb *req)
+{
+	if (req->work.mm) {
+		mmdrop(req->work.mm);
+		req->work.mm = NULL;
+	}
+	if (req->work.creds) {
+		put_cred(req->work.creds);
+		req->work.creds = NULL;
+	}
+}
+
 static inline bool io_prep_async_work(struct io_kiocb *req,
 				      struct io_kiocb **link)
 {
@@ -888,8 +911,8 @@ static inline bool io_prep_async_work(struct io_kiocb *req,
 		if (def->unbound_nonreg_file)
 			req->work.flags |= IO_WQ_WORK_UNBOUND;
 	}
-	if (def->needs_mm)
-		req->work.flags |= IO_WQ_WORK_NEEDS_USER;
+
+	io_req_work_grab_env(req, def);
 
 	*link = io_prep_linked_timeout(req);
 	return do_hashed;
@@ -1180,6 +1203,8 @@ static void __io_req_aux_free(struct io_kiocb *req)
 		else
 			fput(req->file);
 	}
+
+	io_req_work_drop_env(req);
 }
 
 static void __io_free_req(struct io_kiocb *req)
@@ -3963,6 +3988,8 @@ static int io_req_defer_prep(struct io_kiocb *req,
 {
 	ssize_t ret = 0;
 
+	io_req_work_grab_env(req, &io_op_defs[req->opcode]);
+
 	switch (req->opcode) {
 	case IORING_OP_NOP:
 		break;
@@ -5725,9 +5752,7 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
 		goto err;
 	}
 
-	data.mm = ctx->sqo_mm;
 	data.user = ctx->user;
-	data.creds = ctx->creds;
 	data.get_work = io_get_work;
 	data.put_work = io_put_work;
 
@@ -6535,7 +6560,8 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p)
 		goto err;
 
 	p->features = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_NODROP |
-			IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS;
+			IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS |
+			IORING_FEAT_CUR_PERSONALITY;
 	trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags);
 	return ret;
 err: