net: bridge: vlan: add per-vlan state

The first per-vlan option added is state; it is needed for EVPN and for
per-vlan STP. The state allows controlling forwarding on a per-vlan
basis. The vlan state is considered only if the port state is forwarding,
in order to avoid conflicts and be consistent. br_allowed_egress is
called only when the state is forwarding, but the ingress case is a bit
more complicated because we may have the transition
port:BR_STATE_FORWARDING -> vlan:BR_STATE_LEARNING, which should still
allow the bridge to learn from the packet after vlan filtering; the
packet will be dropped after that. Also, to optimize the pvid state check,
we keep a copy in the vlan group to avoid one lookup. The state members
are modified with *_ONCE() to annotate the lockless access.

Signed-off-by: Nikolay Aleksandrov <nikolay@cumulusnetworks.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/bridge/br_vlan.c b/net/bridge/br_vlan.c
index d747756..6b5deca 100644
--- a/net/bridge/br_vlan.c
+++ b/net/bridge/br_vlan.c
@@ -34,13 +34,15 @@ static struct net_bridge_vlan *br_vlan_lookup(struct rhashtable *tbl, u16 vid)
 	return rhashtable_lookup_fast(tbl, &vid, br_vlan_rht_params);
 }
 
-static bool __vlan_add_pvid(struct net_bridge_vlan_group *vg, u16 vid)
+static bool __vlan_add_pvid(struct net_bridge_vlan_group *vg,
+			    const struct net_bridge_vlan *v)
 {
-	if (vg->pvid == vid)
+	if (vg->pvid == v->vid)
 		return false;
 
 	smp_wmb();
-	vg->pvid = vid;
+	br_vlan_set_pvid_state(vg, v->state);
+	vg->pvid = v->vid;
 
 	return true;
 }
@@ -69,7 +71,7 @@ static bool __vlan_add_flags(struct net_bridge_vlan *v, u16 flags)
 		vg = nbp_vlan_group(v->port);
 
 	if (flags & BRIDGE_VLAN_INFO_PVID)
-		ret = __vlan_add_pvid(vg, v->vid);
+		ret = __vlan_add_pvid(vg, v);
 	else
 		ret = __vlan_delete_pvid(vg, v->vid);
 
@@ -293,6 +295,9 @@ static int __vlan_add(struct net_bridge_vlan *v, u16 flags,
 		vg->num_vlans++;
 	}
 
+	/* set the state before publishing */
+	v->state = BR_STATE_FORWARDING;
+
 	err = rhashtable_lookup_insert_fast(&vg->vlan_hash, &v->vnode,
 					    br_vlan_rht_params);
 	if (err)
@@ -466,7 +471,8 @@ struct sk_buff *br_handle_vlan(struct net_bridge *br,
 /* Called under RCU */
 static bool __allowed_ingress(const struct net_bridge *br,
 			      struct net_bridge_vlan_group *vg,
-			      struct sk_buff *skb, u16 *vid)
+			      struct sk_buff *skb, u16 *vid,
+			      u8 *state)
 {
 	struct br_vlan_stats *stats;
 	struct net_bridge_vlan *v;
@@ -532,13 +538,25 @@ static bool __allowed_ingress(const struct net_bridge *br,
 			skb->vlan_tci |= pvid;
 
 		/* if stats are disabled we can avoid the lookup */
-		if (!br_opt_get(br, BROPT_VLAN_STATS_ENABLED))
-			return true;
+		if (!br_opt_get(br, BROPT_VLAN_STATS_ENABLED)) {
+			if (*state == BR_STATE_FORWARDING) {
+				*state = br_vlan_get_pvid_state(vg);
+				return br_vlan_state_allowed(*state, true);
+			} else {
+				return true;
+			}
+		}
 	}
 	v = br_vlan_find(vg, *vid);
 	if (!v || !br_vlan_should_use(v))
 		goto drop;
 
+	if (*state == BR_STATE_FORWARDING) {
+		*state = br_vlan_get_state(v);
+		if (!br_vlan_state_allowed(*state, true))
+			goto drop;
+	}
+
 	if (br_opt_get(br, BROPT_VLAN_STATS_ENABLED)) {
 		stats = this_cpu_ptr(v->stats);
 		u64_stats_update_begin(&stats->syncp);
@@ -556,7 +574,7 @@ static bool __allowed_ingress(const struct net_bridge *br,
 
 bool br_allowed_ingress(const struct net_bridge *br,
 			struct net_bridge_vlan_group *vg, struct sk_buff *skb,
-			u16 *vid)
+			u16 *vid, u8 *state)
 {
 	/* If VLAN filtering is disabled on the bridge, all packets are
 	 * permitted.
@@ -566,7 +584,7 @@ bool br_allowed_ingress(const struct net_bridge *br,
 		return true;
 	}
 
-	return __allowed_ingress(br, vg, skb, vid);
+	return __allowed_ingress(br, vg, skb, vid, state);
 }
 
 /* Called under RCU. */
@@ -582,7 +600,8 @@ bool br_allowed_egress(struct net_bridge_vlan_group *vg,
 
 	br_vlan_get_tag(skb, &vid);
 	v = br_vlan_find(vg, vid);
-	if (v && br_vlan_should_use(v))
+	if (v && br_vlan_should_use(v) &&
+	    br_vlan_state_allowed(br_vlan_get_state(v), false))
 		return true;
 
 	return false;
@@ -593,6 +612,7 @@ bool br_should_learn(struct net_bridge_port *p, struct sk_buff *skb, u16 *vid)
 {
 	struct net_bridge_vlan_group *vg;
 	struct net_bridge *br = p->br;
+	struct net_bridge_vlan *v;
 
 	/* If filtering was disabled at input, let it pass. */
 	if (!br_opt_get(br, BROPT_VLAN_ENABLED))
@@ -607,13 +627,15 @@ bool br_should_learn(struct net_bridge_port *p, struct sk_buff *skb, u16 *vid)
 
 	if (!*vid) {
 		*vid = br_get_pvid(vg);
-		if (!*vid)
+		if (!*vid ||
+		    !br_vlan_state_allowed(br_vlan_get_pvid_state(vg), true))
 			return false;
 
 		return true;
 	}
 
-	if (br_vlan_find(vg, *vid))
+	v = br_vlan_find(vg, *vid);
+	if (v && br_vlan_state_allowed(br_vlan_get_state(v), true))
 		return true;
 
 	return false;
@@ -1816,6 +1838,7 @@ static const struct nla_policy br_vlan_db_policy[BRIDGE_VLANDB_ENTRY_MAX + 1] =
 	[BRIDGE_VLANDB_ENTRY_INFO]	= { .type = NLA_EXACT_LEN,
 					    .len = sizeof(struct bridge_vlan_info) },
 	[BRIDGE_VLANDB_ENTRY_RANGE]	= { .type = NLA_U16 },
+	[BRIDGE_VLANDB_ENTRY_STATE]	= { .type = NLA_U8 },
 };
 
 static int br_vlan_rtm_process_one(struct net_device *dev,