netlink: extend policy range validation

Using a pointer to a struct indicating the min/max values,
extend the ability to do range validation for arbitrary
values. Small values in the s16 range can be kept in the
policy directly.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/lib/nlattr.c b/lib/nlattr.c
index 7f7ebd8..a8beb17 100644
--- a/lib/nlattr.c
+++ b/lib/nlattr.c
@@ -111,17 +111,34 @@ static int nla_validate_array(const struct nlattr *head, int len, int maxtype,
 	return 0;
 }
 
-static int nla_validate_int_range(const struct nla_policy *pt,
-				  const struct nlattr *nla,
-				  struct netlink_ext_ack *extack)
+static int nla_validate_int_range_unsigned(const struct nla_policy *pt,
+					   const struct nlattr *nla,
+					   struct netlink_ext_ack *extack)
 {
-	bool validate_min, validate_max;
-	s64 value;
+	struct netlink_range_validation _range = {
+		.min = 0,
+		.max = U64_MAX,
+	}, *range = &_range;
+	u64 value;
 
-	validate_min = pt->validation_type == NLA_VALIDATE_RANGE ||
-		       pt->validation_type == NLA_VALIDATE_MIN;
-	validate_max = pt->validation_type == NLA_VALIDATE_RANGE ||
-		       pt->validation_type == NLA_VALIDATE_MAX;
+	WARN_ON_ONCE(pt->validation_type != NLA_VALIDATE_RANGE_PTR &&
+		     (pt->min < 0 || pt->max < 0));
+
+	switch (pt->validation_type) {
+	case NLA_VALIDATE_RANGE:
+		range->min = pt->min;
+		range->max = pt->max;
+		break;
+	case NLA_VALIDATE_RANGE_PTR:
+		range = pt->range;
+		break;
+	case NLA_VALIDATE_MIN:
+		range->min = pt->min;
+		break;
+	case NLA_VALIDATE_MAX:
+		range->max = pt->max;
+		break;
+	}
 
 	switch (pt->type) {
 	case NLA_U8:
@@ -133,6 +150,49 @@ static int nla_validate_int_range(const struct nla_policy *pt,
 	case NLA_U32:
 		value = nla_get_u32(nla);
 		break;
+	case NLA_U64:
+		value = nla_get_u64(nla);
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	if (value < range->min || value > range->max) {
+		NL_SET_ERR_MSG_ATTR(extack, nla,
+				    "integer out of range");
+		return -ERANGE;
+	}
+
+	return 0;
+}
+
+static int nla_validate_int_range_signed(const struct nla_policy *pt,
+					 const struct nlattr *nla,
+					 struct netlink_ext_ack *extack)
+{
+	struct netlink_range_validation_signed _range = {
+		.min = S64_MIN,
+		.max = S64_MAX,
+	}, *range = &_range;
+	s64 value;
+
+	switch (pt->validation_type) {
+	case NLA_VALIDATE_RANGE:
+		range->min = pt->min;
+		range->max = pt->max;
+		break;
+	case NLA_VALIDATE_RANGE_PTR:
+		range = pt->range_signed;
+		break;
+	case NLA_VALIDATE_MIN:
+		range->min = pt->min;
+		break;
+	case NLA_VALIDATE_MAX:
+		range->max = pt->max;
+		break;
+	}
+
+	switch (pt->type) {
 	case NLA_S8:
 		value = nla_get_s8(nla);
 		break;
@@ -145,22 +205,11 @@ static int nla_validate_int_range(const struct nla_policy *pt,
 	case NLA_S64:
 		value = nla_get_s64(nla);
 		break;
-	case NLA_U64:
-		/* treat this one specially, since it may not fit into s64 */
-		if ((validate_min && nla_get_u64(nla) < pt->min) ||
-		    (validate_max && nla_get_u64(nla) > pt->max)) {
-			NL_SET_ERR_MSG_ATTR(extack, nla,
-					    "integer out of range");
-			return -ERANGE;
-		}
-		return 0;
 	default:
-		WARN_ON(1);
 		return -EINVAL;
 	}
 
-	if ((validate_min && value < pt->min) ||
-	    (validate_max && value > pt->max)) {
+	if (value < range->min || value > range->max) {
 		NL_SET_ERR_MSG_ATTR(extack, nla,
 				    "integer out of range");
 		return -ERANGE;
@@ -169,6 +218,27 @@ static int nla_validate_int_range(const struct nla_policy *pt,
 	return 0;
 }
 
+static int nla_validate_int_range(const struct nla_policy *pt,
+				  const struct nlattr *nla,
+				  struct netlink_ext_ack *extack)
+{
+	switch (pt->type) { /* pick the helper matching the attribute's signedness */
+	case NLA_U8: /* unsigned types: the helper compares the value as u64 */
+	case NLA_U16:
+	case NLA_U32:
+	case NLA_U64:
+		return nla_validate_int_range_unsigned(pt, nla, extack);
+	case NLA_S8: /* signed types: the helper compares the value as s64 */
+	case NLA_S16:
+	case NLA_S32:
+	case NLA_S64:
+		return nla_validate_int_range_signed(pt, nla, extack);
+	default: /* pt->type is none of the eight integer types above */
+		WARN_ON(1); /* warn, then reject the attribute with -EINVAL */
+		return -EINVAL;
+	}
+}
+
 static int validate_nla(const struct nlattr *nla, int maxtype,
 			const struct nla_policy *policy, unsigned int validate,
 			struct netlink_ext_ack *extack, unsigned int depth)
@@ -348,6 +418,7 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
 	case NLA_VALIDATE_NONE:
 		/* nothing to do */
 		break;
+	case NLA_VALIDATE_RANGE_PTR:
 	case NLA_VALIDATE_RANGE:
 	case NLA_VALIDATE_MIN:
 	case NLA_VALIDATE_MAX: