io_uring: don't rely on weak ->files references

Grab actual references to the files_struct. To avoid circular reference
issues due to this, we add a per-task note that keeps track of what
io_uring contexts a task has used. When the task execs or exits its
assigned files, we cancel requests based on this tracking.

With that, we can grab proper references to the files table, and no
longer need to rely on stashing away ring_fd and ring_file to check
if the ring_fd may have been closed.

Cc: stable@vger.kernel.org # v5.5+
Reviewed-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
new file mode 100644
index 0000000..c09135a
--- /dev/null
+++ b/include/linux/io_uring.h
@@ -0,0 +1,76 @@
+/* SPDX-License-Identifier: GPL-2.0-or-later */
+#ifndef _LINUX_IO_URING_H
+#define _LINUX_IO_URING_H
+
+#include <linux/sched.h>
+#include <linux/xarray.h>
+#include <linux/percpu-refcount.h>
+
+/*
+ * Per-task io_uring state; the helpers below reach it through the
+ * ->io_uring pointer of a task.
+ *
+ * Only ->xa is consulted in this header: while it is empty the inline
+ * wrappers skip the out-of-line cancel calls. The notes on the other
+ * members are assumptions to be confirmed against their users.
+ */
+struct io_uring_task {
+	/* submission side */
+	struct xarray		xa;
+	struct wait_queue_head	wait;
+	/* NOTE(review): looks like the most recently used ring file - confirm */
+	struct file		*last;
+	/* NOTE(review): presumably a count of issued requests - confirm */
+	atomic_long_t		req_issue;
+
+	/* completion side; starts on a cacheline boundary on SMP builds */
+	bool			in_idle ____cacheline_aligned_in_smp;
+	/* NOTE(review): presumably a count of completed requests - confirm */
+	atomic_long_t		req_complete;
+};
+
+#ifdef CONFIG_IO_URING
+void __io_uring_task_cancel(void);
+void __io_uring_files_cancel(struct files_struct *files);
+void __io_uring_free(struct task_struct *tsk);
+
+/*
+ * Invoke __io_uring_task_cancel() for current, unless current has no
+ * io_uring state or that state's ->xa is empty.
+ */
+static inline void io_uring_task_cancel(void)
+{
+	if (!current->io_uring || xa_empty(&current->io_uring->xa))
+		return;
+
+	__io_uring_task_cancel();
+}
+
+/*
+ * Pass @files to __io_uring_files_cancel(), unless current has no
+ * io_uring state or that state's ->xa is empty.
+ */
+static inline void io_uring_files_cancel(struct files_struct *files)
+{
+	if (!current->io_uring || xa_empty(&current->io_uring->xa))
+		return;
+
+	__io_uring_files_cancel(files);
+}
+
+/* Invoke __io_uring_free() for @tsk when it has io_uring state attached. */
+static inline void io_uring_free(struct task_struct *tsk)
+{
+	if (!tsk->io_uring)
+		return;
+
+	__io_uring_free(tsk);
+}
+#else /* !CONFIG_IO_URING */
+/* Without CONFIG_IO_URING the hooks are empty inline stubs. */
+static inline void io_uring_task_cancel(void) { }
+static inline void io_uring_files_cancel(struct files_struct *files) { }
+static inline void io_uring_free(struct task_struct *tsk) { }
+#endif /* CONFIG_IO_URING */
+
+#endif /* _LINUX_IO_URING_H */