io_uring: provide means of removing buffers

We have IORING_OP_PROVIDE_BUFFERS, but the only way to remove buffers
is to trigger IO on them. The usual way of shrinking a buffer pool
would be to just not replenish the buffers when IO completes, and
instead just free them. But it may be nice to have a way to manually
remove a number of buffers from a given group, and
IORING_OP_REMOVE_BUFFERS provides that functionality.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 455d53f..f131105 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -827,6 +827,7 @@ static const struct io_op_def io_op_defs[] = {
 		.unbound_nonreg_file	= 1,
 	},
 	[IORING_OP_PROVIDE_BUFFERS] = {},
+	[IORING_OP_REMOVE_BUFFERS] = {},
 };
 
 static void io_wq_submit_work(struct io_wq_work **workptr);
@@ -2995,6 +2996,75 @@ static int io_openat(struct io_kiocb *req, bool force_nonblock)
 	return io_openat2(req, force_nonblock);
 }
 
+/* REMOVE_BUFFERS prep: buffer count comes in sqe->fd, group in buf_group. */
+static int io_remove_buffers_prep(struct io_kiocb *req,
+				  const struct io_uring_sqe *sqe)
+{
+	struct io_provide_buf *pbuf = &req->pbuf;
+	u64 count;
+
+	if (sqe->ioprio || sqe->rw_flags || sqe->addr || sqe->len || sqe->off)
+		return -EINVAL;
+	/* reject a zero count and anything above USHRT_MAX */
+	count = READ_ONCE(sqe->fd);
+	if (count == 0 || count > USHRT_MAX)
+		return -EINVAL;
+	memset(pbuf, 0, sizeof(*pbuf));
+	pbuf->nbufs = count;
+	pbuf->bgid = READ_ONCE(sqe->buf_group);
+	return 0;
+}
+
+/*
+ * Free up to @nbufs entries of group @bgid, whose list head is @buf. If the
+ * list empties first, also free the head, drop @bgid from the idr, count it.
+ */
+static int __io_remove_buffers(struct io_ring_ctx *ctx, struct io_buffer *buf,
+			       int bgid, unsigned nbufs)
+{
+	struct io_buffer *kbuf;
+	unsigned removed = 0;
+
+	/* shouldn't happen */
+	if (nbufs == 0)
+		return 0;
+	/* @buf doubles as the list head; peel entries off the front */
+	while (!list_empty(&buf->list)) {
+		kbuf = list_first_entry(&buf->list, struct io_buffer, list);
+		list_del(&kbuf->list);
+		kfree(kbuf);
+		if (++removed == nbufs)
+			return removed;
+	}
+	kfree(buf);
+	idr_remove(&ctx->io_buffer_idr, bgid);
+	return removed + 1;	/* the head counts as one buffer */
+}
+
+/* Remove buffers from group req->pbuf.bgid; completes with count or -ENOENT. */
+static int io_remove_buffers(struct io_kiocb *req, bool force_nonblock)
+{
+	struct io_provide_buf *p = &req->pbuf;
+	struct io_ring_ctx *ctx = req->ctx;
+	struct io_buffer *head;
+	int ret = -ENOENT;
+
+	io_ring_submit_lock(ctx, !force_nonblock);
+	lockdep_assert_held(&ctx->uring_lock);
+
+	head = idr_find(&ctx->io_buffer_idr, p->bgid);
+	if (head)
+		ret = __io_remove_buffers(ctx, head, p->bgid, p->nbufs);
+
+	/* release what io_ring_submit_lock() took, under the same condition */
+	io_ring_submit_unlock(ctx, !force_nonblock);
+	if (ret < 0)
+		req_set_fail_links(req);
+	io_cqring_add_event(req, ret);
+	io_put_req(req);
+	return 0;
+}
+
 static int io_provide_buffers_prep(struct io_kiocb *req,
 				   const struct io_uring_sqe *sqe)
 {
@@ -3070,15 +3140,7 @@ static int io_provide_buffers(struct io_kiocb *req, bool force_nonblock)
 		ret = idr_alloc(&ctx->io_buffer_idr, head, p->bgid, p->bgid + 1,
 					GFP_KERNEL);
 		if (ret < 0) {
-			while (!list_empty(&head->list)) {
-				struct io_buffer *buf;
-
-				buf = list_first_entry(&head->list,
-							struct io_buffer, list);
-				list_del(&buf->list);
-				kfree(buf);
-			}
-			kfree(head);
+			__io_remove_buffers(ctx, head, p->bgid, -1U);
 			goto out;
 		}
 	}
@@ -4825,6 +4887,9 @@ static int io_req_defer_prep(struct io_kiocb *req,
 	case IORING_OP_PROVIDE_BUFFERS:
 		ret = io_provide_buffers_prep(req, sqe);
 		break;
+	case IORING_OP_REMOVE_BUFFERS:
+		ret = io_remove_buffers_prep(req, sqe);
+		break;
 	default:
 		printk_once(KERN_WARNING "io_uring: unhandled opcode %d\n",
 				req->opcode);
@@ -5120,6 +5185,14 @@ static int io_issue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 		}
 		ret = io_provide_buffers(req, force_nonblock);
 		break;
+	case IORING_OP_REMOVE_BUFFERS:
+		if (sqe) {
+			ret = io_remove_buffers_prep(req, sqe);
+			if (ret)
+				break;
+		}
+		ret = io_remove_buffers(req, force_nonblock);
+		break;
 	default:
 		ret = -EINVAL;
 		break;
@@ -6998,16 +7071,7 @@ static int __io_destroy_buffers(int id, void *p, void *data)
 	struct io_ring_ctx *ctx = data;
 	struct io_buffer *buf = p;
 
-	/* the head kbuf is the list itself */
-	while (!list_empty(&buf->list)) {
-		struct io_buffer *nxt;
-
-		nxt = list_first_entry(&buf->list, struct io_buffer, list);
-		list_del(&nxt->list);
-		kfree(nxt);
-	}
-	kfree(buf);
-	idr_remove(&ctx->io_buffer_idr, id);
+	__io_remove_buffers(ctx, buf, id, -1U);
 	return 0;
 }