radix-tree: Add radix_tree_iter_delete

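Factor the deletion code out into __radix_tree_delete() and provide an
iterator-based wrapper around it, radix_tree_iter_delete().  To support
this, radix_tree_shrink() and delete_node() now report whether they
freed any node; if deleting the entry frees the iterator's current
node, advance the iterator so it does not read from freed memory.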

Signed-off-by: Matthew Wilcox <mawilcox@microsoft.com>
diff --git a/lib/radix-tree.c b/lib/radix-tree.c
index 40f3091..7bf7d4e 100644
--- a/lib/radix-tree.c
+++ b/lib/radix-tree.c
@@ -581,10 +581,14 @@ static int radix_tree_extend(struct radix_tree_root *root,
  *	radix_tree_shrink    -    shrink radix tree to minimum height
  *	@root		radix tree root
+ *
+ *	Returns true if at least one node was freed.
  */
-static inline void radix_tree_shrink(struct radix_tree_root *root,
+static inline bool radix_tree_shrink(struct radix_tree_root *root,
 				     radix_tree_update_node_t update_node,
 				     void *private)
 {
+	bool shrunk = false;
+
 	for (;;) {
 		struct radix_tree_node *node = root->rnode;
 		struct radix_tree_node *child;
@@ -645,20 +647,35 @@ static inline void radix_tree_shrink(struct radix_tree_root *root,
 
 		WARN_ON_ONCE(!list_empty(&node->private_list));
 		radix_tree_node_free(node);
+		shrunk = true;
 	}
+
+	return shrunk;
 }
 
-static void delete_node(struct radix_tree_root *root,
+/*
+ * Starting at @node, free each node whose count is zero and move on to
+ * its parent.  Stop at the first node with a non-zero count, calling
+ * radix_tree_shrink() first if that node is the top of the tree.
+ *
+ * Return: true if any node was freed, either here or by
+ * radix_tree_shrink(), so callers know that pointers into the tree may
+ * now be stale.
+ */
+static bool delete_node(struct radix_tree_root *root,
 			struct radix_tree_node *node,
 			radix_tree_update_node_t update_node, void *private)
 {
+	bool deleted = false;
+
 	do {
 		struct radix_tree_node *parent;
 
 		if (node->count) {
 			if (node == entry_to_node(root->rnode))
-				radix_tree_shrink(root, update_node, private);
-			return;
+				deleted |= radix_tree_shrink(root, update_node,
+								private);
+			return deleted;
 		}
 
 		parent = node->parent;
@@ -672,9 +689,12 @@ static void delete_node(struct radix_tree_root *root,
 
 		WARN_ON_ONCE(!list_empty(&node->private_list));
 		radix_tree_node_free(node);
+		deleted = true;
 
 		node = parent;
 	} while (node);
+
+	return deleted;
 }
 
 /**
@@ -1859,25 +1870,73 @@ void __radix_tree_delete_node(struct radix_tree_root *root,
 	delete_node(root, node, update_node, private);
 }
 
+/**
+ * __radix_tree_delete - delete the entry in @slot
+ * @root: radix tree root
+ * @node: node containing @slot, or %NULL if @slot is the root slot
+ * @slot: slot holding the entry to delete
+ *
+ * Clear every tag on @slot and empty it, then let delete_node() free
+ * @node and any of its ancestors that are left empty.
+ *
+ * Return: %true if any tree node was freed.
+ */
+static bool __radix_tree_delete(struct radix_tree_root *root,
+				struct radix_tree_node *node, void **slot)
+{
+	/* The root slot (@node == NULL) has no offset within a node. */
+	unsigned offset = node ? get_slot_offset(node, slot) : 0;
+	int tag;
+
+	for (tag = 0; tag < RADIX_TREE_MAX_TAGS; tag++)
+		node_tag_clear(root, node, tag, offset);
+
+	/*
+	 * Emptying an occupied slot always changes the entry count, and
+	 * that change is accounted against @node (the root slot needs
+	 * none).  So pass false for replace_slot()'s type-switch warning,
+	 * as passing true would presumably WARN on every deletion.
+	 */
+	replace_slot(root, node, slot, NULL, false);
+	return node && delete_node(root, node, NULL, NULL);
+}
+
 /**
- *	radix_tree_delete_item    -    delete an item from a radix tree
- *	@root:		radix tree root
- *	@index:		index key
- *	@item:		expected item
+ * radix_tree_iter_delete - delete the entry at this iterator position
+ * @root: radix tree root
+ * @iter: iterator state
+ * @slot: the slot at the iterator's current position
  *
- *	Remove @item at @index from the radix tree rooted at @root.
+ * Delete the entry at the position currently pointed to by the iterator.
+ * This may result in the current node being freed; if it is, the iterator
+ * is advanced so that it will not reference the freed memory.  This
+ * function may be called without any locking if there are no other threads
+ * which can access this tree.
+ */
+void radix_tree_iter_delete(struct radix_tree_root *root,
+				struct radix_tree_iter *iter, void **slot)
+{
+	if (__radix_tree_delete(root, iter->node, slot))
+		iter->index = iter->next_index;
+}
+
+/**
+ * radix_tree_delete_item - delete an item from a radix tree
+ * @root: radix tree root
+ * @index: index key
+ * @item: expected item
  *
- *	Returns the address of the deleted item, or NULL if it was not present
- *	or the entry at the given @index was not @item.
+ * Remove @item at @index from the radix tree rooted at @root.
+ *
+ * Return: the deleted entry, or %NULL if it was not present
+ * or the entry at the given @index was not @item.
  */
 void *radix_tree_delete_item(struct radix_tree_root *root,
 			     unsigned long index, void *item)
 {
 	struct radix_tree_node *node;
-	unsigned int offset;
 	void **slot;
 	void *entry;
-	int tag;
 
 	entry = __radix_tree_lookup(root, index, &node, &slot);
 	if (!entry)
@@ -1886,32 +1927,21 @@ void *radix_tree_delete_item(struct radix_tree_root *root,
 	if (item && entry != item)
 		return NULL;
 
-	if (!node) {
-		root_tag_clear_all(root);
-		root->rnode = NULL;
-		return entry;
-	}
-
-	offset = get_slot_offset(node, slot);
-
-	/* Clear all tags associated with the item to be deleted.  */
-	for (tag = 0; tag < RADIX_TREE_MAX_TAGS; tag++)
-		node_tag_clear(root, node, tag, offset);
-
-	__radix_tree_replace(root, node, slot, NULL, NULL, NULL);
+	/* No iterator to fix up, so ignore whether a node was freed. */
+	__radix_tree_delete(root, node, slot);
 
 	return entry;
 }
 EXPORT_SYMBOL(radix_tree_delete_item);
 
 /**
- *	radix_tree_delete    -    delete an item from a radix tree
- *	@root:		radix tree root
- *	@index:		index key
+ * radix_tree_delete - delete an entry from a radix tree
+ * @root: radix tree root
+ * @index: index key
  *
- *	Remove the item at @index from the radix tree rooted at @root.
+ * Remove the entry at @index from the radix tree rooted at @root.
  *
- *	Returns the address of the deleted item, or NULL if it was not present.
+ * Return: the deleted entry, or %NULL if it was not present.
  */
 void *radix_tree_delete(struct radix_tree_root *root, unsigned long index)
 {