IB/mlx5: Don't override existing ip_protocol

Two flow specifications can set the ip protocol field in
the flow table entry:

1) IB_FLOW_SPEC_TCP/UDP/GRE - set the ip protocol accordingly.
2) IB_FLOW_SPEC_IPV4/6 - carry a protocol field (ipv4 proto /
ipv6 next_hdr) for users who want to match specific L4 packets.

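Avoid overriding ip_protocol with zeros when the user supplies
the L4 specification first and the L3 specification afterwards.
A zero mask now leaves the existing entry untouched, and a
non-zero mask/value that conflicts with an already-set
ip_protocol is rejected with -EINVAL.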

Fixes: ca0d47538528b ('IB/mlx5: Add support in TOS and protocol to flow steering')
Signed-off-by: Maor Gottlieb <maorg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
diff --git a/drivers/infiniband/hw/mlx5/main.c b/drivers/infiniband/hw/mlx5/main.c
index ae00f99..b32b5c7 100644
--- a/drivers/infiniband/hw/mlx5/main.c
+++ b/drivers/infiniband/hw/mlx5/main.c
@@ -2333,10 +2333,29 @@ static u8 get_match_criteria_enable(u32 *match_criteria)
 	return match_criteria_enable;
 }
 
-static void set_proto(void *outer_c, void *outer_v, u8 mask, u8 val)
+static int set_proto(void *outer_c, void *outer_v, u8 mask, u8 val)
 {
-	MLX5_SET(fte_match_set_lyr_2_4, outer_c, ip_protocol, mask);
-	MLX5_SET(fte_match_set_lyr_2_4, outer_v, ip_protocol, val);
+	u8 entry_mask;
+	u8 entry_val;
+	int err = 0;
+
+	if (!mask)
+		goto out;
+
+	entry_mask = MLX5_GET(fte_match_set_lyr_2_4, outer_c,
+			      ip_protocol);
+	entry_val = MLX5_GET(fte_match_set_lyr_2_4, outer_v,
+			     ip_protocol);
+	if (!entry_mask) {
+		MLX5_SET(fte_match_set_lyr_2_4, outer_c, ip_protocol, mask);
+		MLX5_SET(fte_match_set_lyr_2_4, outer_v, ip_protocol, val);
+		goto out;
+	}
+	/* Don't override existing ip protocol */
+	if (mask != entry_mask || val != entry_val)
+		err = -EINVAL;
+out:
+	return err;
 }
 
 static void set_flow_label(void *misc_c, void *misc_v, u32 mask, u32 val,
@@ -2570,8 +2589,10 @@ static int parse_flow_attr(struct mlx5_core_dev *mdev, u32 *match_c,
 		set_tos(headers_c, headers_v,
 			ib_spec->ipv4.mask.tos, ib_spec->ipv4.val.tos);
 
-		set_proto(headers_c, headers_v,
-			  ib_spec->ipv4.mask.proto, ib_spec->ipv4.val.proto);
+		if (set_proto(headers_c, headers_v,
+			      ib_spec->ipv4.mask.proto,
+			      ib_spec->ipv4.val.proto))
+			return -EINVAL;
 		break;
 	case IB_FLOW_SPEC_IPV6:
 		if (FIELDS_NOT_SUPPORTED(ib_spec->ipv6.mask, LAST_IPV6_FIELD))
@@ -2610,9 +2631,10 @@ static int parse_flow_attr(struct mlx5_core_dev *mdev, u32 *match_c,
 			ib_spec->ipv6.mask.traffic_class,
 			ib_spec->ipv6.val.traffic_class);
 
-		set_proto(headers_c, headers_v,
-			  ib_spec->ipv6.mask.next_hdr,
-			  ib_spec->ipv6.val.next_hdr);
+		if (set_proto(headers_c, headers_v,
+			      ib_spec->ipv6.mask.next_hdr,
+			      ib_spec->ipv6.val.next_hdr))
+			return -EINVAL;
 
 		set_flow_label(misc_params_c, misc_params_v,
 			       ntohl(ib_spec->ipv6.mask.flow_label),
@@ -2633,10 +2655,8 @@ static int parse_flow_attr(struct mlx5_core_dev *mdev, u32 *match_c,
 					 LAST_TCP_UDP_FIELD))
 			return -EOPNOTSUPP;
 
-		MLX5_SET(fte_match_set_lyr_2_4, headers_c, ip_protocol,
-			 0xff);
-		MLX5_SET(fte_match_set_lyr_2_4, headers_v, ip_protocol,
-			 IPPROTO_TCP);
+		if (set_proto(headers_c, headers_v, 0xff, IPPROTO_TCP))
+			return -EINVAL;
 
 		MLX5_SET(fte_match_set_lyr_2_4, headers_c, tcp_sport,
 			 ntohs(ib_spec->tcp_udp.mask.src_port));
@@ -2653,10 +2673,8 @@ static int parse_flow_attr(struct mlx5_core_dev *mdev, u32 *match_c,
 					 LAST_TCP_UDP_FIELD))
 			return -EOPNOTSUPP;
 
-		MLX5_SET(fte_match_set_lyr_2_4, headers_c, ip_protocol,
-			 0xff);
-		MLX5_SET(fte_match_set_lyr_2_4, headers_v, ip_protocol,
-			 IPPROTO_UDP);
+		if (set_proto(headers_c, headers_v, 0xff, IPPROTO_UDP))
+			return -EINVAL;
 
 		MLX5_SET(fte_match_set_lyr_2_4, headers_c, udp_sport,
 			 ntohs(ib_spec->tcp_udp.mask.src_port));
@@ -2672,6 +2690,9 @@ static int parse_flow_attr(struct mlx5_core_dev *mdev, u32 *match_c,
 		if (ib_spec->gre.mask.c_ks_res0_ver)
 			return -EOPNOTSUPP;
 
+		if (set_proto(headers_c, headers_v, 0xff, IPPROTO_GRE))
+			return -EINVAL;
+
 		MLX5_SET(fte_match_set_lyr_2_4, headers_c, ip_protocol,
 			 0xff);
 		MLX5_SET(fte_match_set_lyr_2_4, headers_v, ip_protocol,