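libceph: introduce and switch to decode_pg_mapping()

Factor the loop shared by the pg_temp and primary_temp decoders out
into a generic decode_pg_mapping().  It decodes each pgid, removes any
existing mapping for it and, if a per-type callback is given, inserts
the mapping that the callback decodes.  The callbacks
(__decode_pg_temp(), __decode_primary_temp()) now return a new
ceph_pg_mapping, NULL for an incremental removal, or an ERR_PTR().

Mappings are now allocated with the new alloc_pg_mapping() helper
(kmalloc() with GFP_NOIO instead of kzalloc() with GFP_NOFS), and a
pgid that is unexpectedly already mapped while decoding a full map now
triggers WARN_ON() instead of BUG_ON().

Signed-off-by: Ilya Dryomov <idryomov@gmail.com>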
diff --git a/net/ceph/osdmap.c b/net/ceph/osdmap.c
index 06baf6b..41b380a 100644
--- a/net/ceph/osdmap.c
+++ b/net/ceph/osdmap.c
@@ -434,6 +434,25 @@ int ceph_spg_compare(const struct ceph_spg *lhs, const struct ceph_spg *rhs)
 	return 0;
 }
 
+/* kmalloc()ed (not zeroed) with @payload_len extra bytes past the struct */
+static struct ceph_pg_mapping *alloc_pg_mapping(size_t payload_len)
+{
+	struct ceph_pg_mapping *pg = kmalloc(sizeof(*pg) + payload_len,
+					     GFP_NOIO);
+
+	if (pg)
+		RB_CLEAR_NODE(&pg->node);	/* not in any tree yet */
+	return pg;
+}
+
+/* warns unless @pg's rb node is still in the RB_CLEAR_NODE() state */
+static void free_pg_mapping(struct ceph_pg_mapping *pg)
+{
+	WARN_ON(!RB_EMPTY_NODE(&pg->node));
+
+	kfree(pg);
+}
+
 /*
  * rbtree of pg_mapping for handling pg_temp (explicit mapping of pgid
  * to a set of osds) and primary_temp (explicit primary setting)
@@ -1017,47 +1036,57 @@ static int decode_new_pools(void **p, void *end, struct ceph_osdmap *map)
 	return __decode_pools(p, end, map, true);
 }
 
-static int __decode_pg_temp(void **p, void *end, struct ceph_osdmap *map,
-			    bool incremental)
+/*
+ * Per-type value decoder for decode_pg_mapping().  Returns a newly
+ * allocated mapping (pgid is filled in by the caller), NULL if there
+ * is nothing to insert, or an ERR_PTR() on error.
+ */
+typedef struct ceph_pg_mapping *(*decode_mapping_fn_t)(void **, void *, bool);
+
+/*
+ * Decode a u32 count followed by that many (pgid, value) entries into
+ * @mapping_root.  For each pgid, any existing mapping is removed; then,
+ * if @fn is set, @fn decodes the value and the mapping it returns (if
+ * not NULL) is inserted.  @fn may be NULL only for an incremental map,
+ * making every entry a pure removal with no value to decode.  For a
+ * full map, finding an existing mapping for a pgid triggers a warning.
+ *
+ * Returns 0 or a negative error code; entries handled before an error
+ * are left in @mapping_root.
+ */
+static int decode_pg_mapping(void **p, void *end, struct rb_root *mapping_root,
+			     decode_mapping_fn_t fn, bool incremental)
 {
 	u32 n;
 
+	WARN_ON(!incremental && !fn);
+
 	ceph_decode_32_safe(p, end, n, e_inval);
 	while (n--) {
+		struct ceph_pg_mapping *pg;
 		struct ceph_pg pgid;
-		u32 len, i;
 		int ret;
 
 		ret = ceph_decode_pgid(p, end, &pgid);
 		if (ret)
 			return ret;
 
-		ceph_decode_32_safe(p, end, len, e_inval);
+		ret = __remove_pg_mapping(mapping_root, &pgid);
+		WARN_ON(!incremental && ret != -ENOENT);
 
-		ret = __remove_pg_mapping(&map->pg_temp, &pgid);
-		BUG_ON(!incremental && ret != -ENOENT);
+		if (fn) {
+			pg = fn(p, end, incremental);
+			if (IS_ERR(pg))
+				return PTR_ERR(pg);
 
-		if (!incremental || len > 0) {
-			struct ceph_pg_mapping *pg;
-
-			ceph_decode_need(p, end, len*sizeof(u32), e_inval);
-
-			if (len > (UINT_MAX - sizeof(*pg)) / sizeof(u32))
-				return -EINVAL;
-
-			pg = kzalloc(sizeof(*pg) + len*sizeof(u32), GFP_NOFS);
-			if (!pg)
-				return -ENOMEM;
-
-			pg->pgid = pgid;
-			pg->pg_temp.len = len;
-			for (i = 0; i < len; i++)
-				pg->pg_temp.osds[i] = ceph_decode_32(p);
-
-			ret = __insert_pg_mapping(pg, &map->pg_temp);
-			if (ret) {
-				kfree(pg);
-				return ret;
+			if (pg) {
+				pg->pgid = pgid; /* struct */
+				ret = __insert_pg_mapping(pg, mapping_root);
+				if (ret) {
+					/* not inserted, so @pg is still ours */
+					free_pg_mapping(pg);
+					return ret;
+				}
 			}
 		}
 	}
@@ -1068,69 +1097,77 @@ static int __decode_pg_temp(void **p, void *end, struct ceph_osdmap *map,
 	return -EINVAL;
 }
 
+/* pg_temp value: u32 osd count, then that many u32 osd ids */
+static struct ceph_pg_mapping *__decode_pg_temp(void **p, void *end,
+						bool incremental)
+{
+	struct ceph_pg_mapping *mapping;
+	u32 nr_osds, k;
+
+	ceph_decode_32_safe(p, end, nr_osds, e_inval);
+	if (incremental && !nr_osds)
+		return NULL;	/* incremental, empty list: removal only */
+	if (nr_osds > (SIZE_MAX - sizeof(*mapping)) / sizeof(u32))
+		goto e_inval;	/* osd array size would overflow */
+
+	ceph_decode_need(p, end, nr_osds * sizeof(u32), e_inval);
+	mapping = alloc_pg_mapping(nr_osds * sizeof(u32));
+	if (!mapping)
+		return ERR_PTR(-ENOMEM);
+	mapping->pg_temp.len = nr_osds;
+	for (k = 0; k < nr_osds; k++)
+		mapping->pg_temp.osds[k] = ceph_decode_32(p);
+	return mapping;
+
+e_inval:
+	return ERR_PTR(-EINVAL);
+}
+
 static int decode_pg_temp(void **p, void *end, struct ceph_osdmap *map)
 {
-	return __decode_pg_temp(p, end, map, false);
+	return decode_pg_mapping(p, end, &map->pg_temp,
+				 __decode_pg_temp, false);
 }
 
 static int decode_new_pg_temp(void **p, void *end, struct ceph_osdmap *map)
 {
-	return __decode_pg_temp(p, end, map, true);
+	return decode_pg_mapping(p, end, &map->pg_temp,
+				 __decode_pg_temp, true);
 }
 
-static int __decode_primary_temp(void **p, void *end, struct ceph_osdmap *map,
-				 bool incremental)
+/* primary_temp value: a single u32 osd id */
+static struct ceph_pg_mapping *__decode_primary_temp(void **p, void *end,
+						     bool incremental)
 {
-	u32 n;
+	struct ceph_pg_mapping *mapping;
+	u32 primary;
 
-	ceph_decode_32_safe(p, end, n, e_inval);
-	while (n--) {
-		struct ceph_pg pgid;
-		u32 osd;
-		int ret;
+	ceph_decode_32_safe(p, end, primary, e_inval);
+	if (incremental && primary == (u32)-1)
+		return NULL;	/* incremental, osd -1: removal only */
 
-		ret = ceph_decode_pgid(p, end, &pgid);
-		if (ret)
-			return ret;
+	mapping = alloc_pg_mapping(0);	/* no variable-length payload */
+	if (!mapping)
+		return ERR_PTR(-ENOMEM);
 
-		ceph_decode_32_safe(p, end, osd, e_inval);
-
-		ret = __remove_pg_mapping(&map->primary_temp, &pgid);
-		BUG_ON(!incremental && ret != -ENOENT);
-
-		if (!incremental || osd != (u32)-1) {
-			struct ceph_pg_mapping *pg;
-
-			pg = kzalloc(sizeof(*pg), GFP_NOFS);
-			if (!pg)
-				return -ENOMEM;
-
-			pg->pgid = pgid;
-			pg->primary_temp.osd = osd;
-
-			ret = __insert_pg_mapping(pg, &map->primary_temp);
-			if (ret) {
-				kfree(pg);
-				return ret;
-			}
-		}
-	}
-
-	return 0;
+	mapping->primary_temp.osd = primary;
+	return mapping;
 
 e_inval:
-	return -EINVAL;
+	return ERR_PTR(-EINVAL);
 }
 
 static int decode_primary_temp(void **p, void *end, struct ceph_osdmap *map)
 {
-	return __decode_primary_temp(p, end, map, false);
+	return decode_pg_mapping(p, end, &map->primary_temp,
+				 __decode_primary_temp, false);
 }
 
 static int decode_new_primary_temp(void **p, void *end,
 				   struct ceph_osdmap *map)
 {
-	return __decode_primary_temp(p, end, map, true);
+	return decode_pg_mapping(p, end, &map->primary_temp,
+				 __decode_primary_temp, true);
 }
 
 u32 ceph_get_primary_affinity(struct ceph_osdmap *map, int osd)