Btrfs: avoid possible use-after-free in clear_extent_bit()

clear_extent_bit()
{
    next_node = rb_next(&state->rb_node);
    ...
    clear_state_bit(state);  <-- this may free next_node
    if (next_node) {
        state = rb_entry(next_node);
        ...
    }
}

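clear_state_bit() calls merge_state(), which may free the node that
follows the passed-in extent_state, so clear_extent_bit() may end up
dereferencing freed memory.

Fix this by having clear_state_bit() return the next extent_state,
looked up after merge_state() has run (or before the state itself is
freed), instead of caching rb_next() before the call. The
need_resched() check moves to where the next state is consumed.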

Signed-off-by: Li Zefan <lizf@cn.fujitsu.com>
diff --git a/fs/btrfs/extent_io.c b/fs/btrfs/extent_io.c
index 05951bd..11eeb81 100644
--- a/fs/btrfs/extent_io.c
+++ b/fs/btrfs/extent_io.c
@@ -402,6 +402,15 @@
 	return 0;
 }
 
+/* Return the extent_state after @state in tree order, or NULL if none. */
+static struct extent_state *next_state(struct extent_state *state)
+{
+	struct rb_node *node = rb_next(&state->rb_node);
+
+	/* rb_next() yields NULL when @state is the last node in the tree. */
+	return node ? rb_entry(node, struct extent_state, rb_node) : NULL;
+}
+
 /*
  * utility function to clear some bits in an extent state struct.
  * it will optionally wake up any one waiting on this state (wake == 1)
@@ -409,10 +418,11 @@
  * If no bits are set on the state struct after clearing things, the
  * struct is freed and removed from the tree
  */
-static void clear_state_bit(struct extent_io_tree *tree,
-			    struct extent_state *state,
-			    int *bits, int wake)
+static struct extent_state *clear_state_bit(struct extent_io_tree *tree,
+					    struct extent_state *state,
+					    int *bits, int wake)
 {
+	struct extent_state *next;
 	int bits_to_clear = *bits & ~EXTENT_CTLBITS;
 
 	if ((bits_to_clear & EXTENT_DIRTY) && (state->state & EXTENT_DIRTY)) {
@@ -425,6 +435,7 @@
 	if (wake)
 		wake_up(&state->wq);
 	if (state->state == 0) {
+		next = next_state(state);
 		if (state->tree) {
 			rb_erase(&state->rb_node, &tree->state);
 			state->tree = NULL;
@@ -434,7 +445,9 @@
 		}
 	} else {
 		merge_state(tree, state);
+		next = next_state(state);
 	}
+	return next;
 }
 
 static struct extent_state *
@@ -473,7 +486,6 @@
 	struct extent_state *state;
 	struct extent_state *cached;
 	struct extent_state *prealloc = NULL;
-	struct rb_node *next_node;
 	struct rb_node *node;
 	u64 last_end;
 	int err;
@@ -525,14 +537,11 @@
 	WARN_ON(state->end < start);
 	last_end = state->end;
 
-	if (state->end < end && !need_resched())
-		next_node = rb_next(&state->rb_node);
-	else
-		next_node = NULL;
-
 	/* the state doesn't have the wanted bits, go ahead */
-	if (!(state->state & bits))
+	if (!(state->state & bits)) {
+		state = next_state(state);
 		goto next;
+	}
 
 	/*
 	 *     | ---- desired range ---- |
@@ -590,16 +599,13 @@
 		goto out;
 	}
 
-	clear_state_bit(tree, state, &bits, wake);
+	state = clear_state_bit(tree, state, &bits, wake);
 next:
 	if (last_end == (u64)-1)
 		goto out;
 	start = last_end + 1;
-	if (start <= end && next_node) {
-		state = rb_entry(next_node, struct extent_state,
-				 rb_node);
+	if (start <= end && state && !need_resched())
 		goto hit_next;
-	}
 	goto search_again;
 
 out: