libceph, crush: per-pool crush_choose_arg_map for crush_do_rule()

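Decode the choose_args section of the CRUSH map (after skipping the
class_map, class_name and class_bucket maps that precede it) into an
rbtree of crush_choose_arg_maps and pass the one matching the pool id
to crush_do_rule().  If there is no crush_choose_arg_map for a given
pool, a NULL pointer is passed to preserve the existing crush_do_rule()
behavior.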

Reflects ceph.git commits 55fb91d64071552ea1bc65ab4ea84d3c8b73ab4b,
                          dbe36e08be00c6519a8c89718dd47b0219c20516.

Signed-off-by: Ilya Dryomov <idryomov@gmail.com>
diff --git a/net/ceph/osdmap.c b/net/ceph/osdmap.c
index 9da0ee6..f630d10 100644
--- a/net/ceph/osdmap.c
+++ b/net/ceph/osdmap.c
@@ -138,6 +138,247 @@ static int crush_decode_straw2_bucket(void **p, void *end,
 	return -EINVAL;
 }
 
+/*
+ * Allocate a zeroed crush_choose_arg_map.  Its rb node is cleared so
+ * that free_choose_arg_map() can warn if it is still in a tree.
+ */
+static struct crush_choose_arg_map *alloc_choose_arg_map(void)
+{
+	struct crush_choose_arg_map *arg_map;
+
+	arg_map = kzalloc(sizeof(*arg_map), GFP_NOIO);
+	if (!arg_map)
+		return NULL;
+
+	RB_CLEAR_NODE(&arg_map->node);
+	return arg_map;
+}
+
+/*
+ * Free @arg_map along with every per-bucket weight set and ids array.
+ * NULL is a no-op.  @arg_map must already be erased from its rbtree.
+ *
+ * Expects arg_map->args to hold arg_map->size slots, each either zeroed
+ * or consistent (weight_set holding weight_set_size entries).
+ */
+static void free_choose_arg_map(struct crush_choose_arg_map *arg_map)
+{
+	int i, j;
+
+	if (!arg_map)
+		return;
+
+	WARN_ON(!RB_EMPTY_NODE(&arg_map->node));
+
+	for (i = 0; i < arg_map->size; i++) {
+		struct crush_choose_arg *arg = &arg_map->args[i];
+
+		for (j = 0; j < arg->weight_set_size; j++)
+			kfree(arg->weight_set[j].weights);
+		kfree(arg->weight_set);
+		kfree(arg->ids);
+	}
+	kfree(arg_map->args);
+	kfree(arg_map);
+}
+
+DEFINE_RB_FUNCS(choose_arg_map, struct crush_choose_arg_map, choose_args_index,
+		node);
+
+/*
+ * Erase and free every crush_choose_arg_map in @c->choose_args,
+ * leaving the tree empty.
+ */
+void clear_choose_args(struct crush_map *c)
+{
+	while (!RB_EMPTY_ROOT(&c->choose_args)) {
+		struct crush_choose_arg_map *arg_map =
+		    rb_entry(rb_first(&c->choose_args),
+			     struct crush_choose_arg_map, node);
+
+		erase_choose_arg_map(&c->choose_args, arg_map);
+		free_choose_arg_map(arg_map);
+	}
+}
+
+/*
+ * Decode a u32 length followed by that many u32 values into a new buffer.
+ *
+ * Returns the buffer (NULL for an empty array) and stores its length in
+ * *@plen.  On failure returns ERR_PTR(-EINVAL) for short or corrupt input
+ * or ERR_PTR(-ENOMEM), leaving *@plen untouched.
+ */
+static u32 *decode_array_32_alloc(void **p, void *end, u32 *plen)
+{
+	u32 *a = NULL;
+	u32 len, i;
+
+	ceph_decode_32_safe(p, end, len, e_inval);
+	if (len) {
+		/*
+		 * Check the length against the remaining input before
+		 * allocating, so a corrupt length can't cause a huge
+		 * allocation.  Divide rather than multiply so the check
+		 * can't overflow where size_t is 32 bits.
+		 */
+		if (len > (end - *p) / sizeof(u32))
+			goto e_inval;
+
+		a = kmalloc_array(len, sizeof(u32), GFP_NOIO);
+		if (!a)
+			return ERR_PTR(-ENOMEM);
+
+		for (i = 0; i < len; i++)
+			a[i] = ceph_decode_32(p);
+	}
+
+	*plen = len;
+	return a;
+
+e_inval:
+	return ERR_PTR(-EINVAL);
+}
+
+/*
+ * Decode one bucket's choose_arg: a u32 count of weight sets, each a
+ * u32 array, followed by a u32 array of ids.
+ *
+ * Assumes @arg is zero-initialized.  On error @arg is left consistent
+ * for free_choose_arg_map(): weight_set_size is only set once weight_set
+ * is allocated, and weight_set is zeroed so that entries never decoded
+ * have NULL weights.
+ */
+static int decode_choose_arg(void **p, void *end, struct crush_choose_arg *arg)
+{
+	u32 weight_set_size;
+	int ret;
+
+	ceph_decode_32_safe(p, end, weight_set_size, e_inval);
+	if (weight_set_size) {
+		u32 i;
+
+		/* each weight set starts with at least a u32 length */
+		if (weight_set_size > (end - *p) / sizeof(u32))
+			goto e_inval;
+
+		arg->weight_set = kcalloc(weight_set_size,
+					  sizeof(*arg->weight_set), GFP_NOIO);
+		if (!arg->weight_set)
+			return -ENOMEM;
+		arg->weight_set_size = weight_set_size;
+
+		for (i = 0; i < weight_set_size; i++) {
+			struct crush_weight_set *w = &arg->weight_set[i];
+
+			w->weights = decode_array_32_alloc(p, end, &w->size);
+			if (IS_ERR(w->weights)) {
+				ret = PTR_ERR(w->weights);
+				w->weights = NULL;
+				return ret;
+			}
+		}
+	}
+
+	arg->ids = decode_array_32_alloc(p, end, &arg->ids_size);
+	if (IS_ERR(arg->ids)) {
+		ret = PTR_ERR(arg->ids);
+		arg->ids = NULL;
+		return ret;
+	}
+
+	return 0;
+
+e_inval:
+	return -EINVAL;
+}
+
+/*
+ * Decode the choose_args section of a CRUSH map into @c->choose_args,
+ * an rbtree keyed by choose_args_index.  Each map gets c->max_buckets
+ * crush_choose_arg slots indexed by bucket index.  Must be called after
+ * c->buckets has been decoded.
+ *
+ * On error, maps inserted by earlier iterations remain in
+ * @c->choose_args; the caller must release them (clear_choose_args()).
+ */
+static int decode_choose_args(void **p, void *end, struct crush_map *c)
+{
+	struct crush_choose_arg_map *arg_map = NULL;
+	u32 num_choose_arg_maps, num_buckets;
+	int ret;
+
+	ceph_decode_32_safe(p, end, num_choose_arg_maps, e_inval);
+	while (num_choose_arg_maps--) {
+		arg_map = alloc_choose_arg_map();
+		if (!arg_map) {
+			ret = -ENOMEM;
+			goto fail;
+		}
+
+		ceph_decode_64_safe(p, end, arg_map->choose_args_index,
+				    e_inval);
+		/*
+		 * insert_choose_arg_map() BUG()s on a duplicate key, so
+		 * reject an index that was already decoded.
+		 */
+		if (lookup_choose_arg_map(&c->choose_args,
+					  arg_map->choose_args_index))
+			goto e_inval;
+
+		/*
+		 * Set size only after args is allocated, since
+		 * free_choose_arg_map() walks arg_map->size slots of args.
+		 */
+		arg_map->args = kcalloc(c->max_buckets, sizeof(*arg_map->args),
+					GFP_NOIO);
+		if (!arg_map->args) {
+			ret = -ENOMEM;
+			goto fail;
+		}
+		arg_map->size = c->max_buckets;
+
+		ceph_decode_32_safe(p, end, num_buckets, e_inval);
+		while (num_buckets--) {
+			struct crush_choose_arg *arg;
+			struct crush_bucket *b;
+			u32 bucket_index;
+
+			ceph_decode_32_safe(p, end, bucket_index, e_inval);
+			if (bucket_index >= arg_map->size)
+				goto e_inval;
+
+			arg = &arg_map->args[bucket_index];
+			/* decode_choose_arg() needs a zeroed slot; no repeats */
+			if (arg->weight_set || arg->ids)
+				goto e_inval;
+
+			ret = decode_choose_arg(p, end, arg);
+			if (ret)
+				goto fail;
+
+			/*
+			 * ids presumably parallels the bucket's items, so
+			 * require a matching length (NOTE(review): mirrors
+			 * the assert in Ceph's userspace decoder).
+			 */
+			b = c->buckets[bucket_index];
+			if (arg->ids_size && (!b || arg->ids_size != b->size))
+				goto e_inval;
+		}
+
+		insert_choose_arg_map(&c->choose_args, arg_map);
+		arg_map = NULL;	/* now owned by c->choose_args */
+	}
+
+	return 0;
+
+e_inval:
+	ret = -EINVAL;
+fail:
+	free_choose_arg_map(arg_map);
+	return ret;
+}
+
 static void crush_finalize(struct crush_map *c)
 {
 	__s32 b;
@@ -179,6 +350,8 @@ static struct crush_map *crush_decode(void *pbyval, void *end)
 	if (c == NULL)
 		return ERR_PTR(-ENOMEM);
 
+	c->choose_args = RB_ROOT;
+
         /* set tunables to default values */
         c->choose_local_tries = 2;
         c->choose_local_fallback_tries = 5;
@@ -372,6 +545,21 @@ static struct crush_map *crush_decode(void *pbyval, void *end)
 	dout("crush decode tunable chooseleaf_stable = %d\n",
 	     c->chooseleaf_stable);
 
+	if (*p != end) {
+		/* class_map */
+		ceph_decode_skip_map(p, end, 32, 32, bad);
+		/* class_name */
+		ceph_decode_skip_map(p, end, 32, string, bad);
+		/* class_bucket */
+		ceph_decode_skip_map_of_map(p, end, 32, 32, 32, bad);
+	}
+
+	if (*p != end) {
+		err = decode_choose_args(p, end, c);
+		if (err)
+			goto bad;
+	}
+
 done:
 	crush_finalize(c);
 	dout("crush_decode success\n");
@@ -2103,15 +2291,31 @@ static u32 raw_pg_to_pps(struct ceph_pg_pool_info *pi,
 
+/*
+ * Run CRUSH rule @ruleno for input @x, filling @result (room for
+ * @result_max entries, which must not exceed CEPH_PG_MAX_SIZE), and
+ * return whatever crush_do_rule() returns.
+ *
+ * If map->crush has a crush_choose_arg_map indexed by
+ * @choose_args_index, its args are passed to crush_do_rule(); otherwise
+ * NULL is passed, preserving the behavior from before choose_args.
+ * crush_do_rule() is handed map->crush_workspace and is called with
+ * map->crush_workspace_mutex held.
+ */
 static int do_crush(struct ceph_osdmap *map, int ruleno, int x,
 		    int *result, int result_max,
-		    const __u32 *weight, int weight_max)
+		    const __u32 *weight, int weight_max,
+		    u64 choose_args_index)
 {
+	struct crush_choose_arg_map *arg_map =
+		lookup_choose_arg_map(&map->crush->choose_args,
+				      choose_args_index);
+	struct crush_choose_arg *args = arg_map ? arg_map->args : NULL;
 	int r;
 
 	BUG_ON(result_max > CEPH_PG_MAX_SIZE);
 
 	mutex_lock(&map->crush_workspace_mutex);
 	r = crush_do_rule(map->crush, ruleno, x, result, result_max,
-			  weight, weight_max, map->crush_workspace, NULL);
+			  weight, weight_max, map->crush_workspace, args);
 	mutex_unlock(&map->crush_workspace_mutex);
 
 	return r;
@@ -2181,7 +2375,7 @@ static void pg_to_raw_osds(struct ceph_osdmap *osdmap,
 	}
 
 	len = do_crush(osdmap, ruleno, pps, raw->osds, pi->size,
-		       osdmap->osd_weight, osdmap->max_osd);
+		       osdmap->osd_weight, osdmap->max_osd, pi->id);
 	if (len < 0) {
 		pr_err("error %d from crush rule %d: pool %lld ruleset %d type %d size %d\n",
 		       len, ruleno, pi->id, pi->crush_ruleset, pi->type,