vxlan: add support for underlay in non-default VRF

Creating a VXLAN device with its underlay in a non-default VRF makes
egress route lookups fail or return incorrect results, since they are
resolved in the default VRF, and makes ingress fail because the socket
listens in the default VRF.

This patch binds the underlying UDP tunnel socket to the l3mdev of the
lower device of the VXLAN device. This will listen in the proper VRF and
output traffic from said l3mdev, matching l3mdev routing rules and
looking up the correct routing table.

When the VXLAN device does not have a lower device, or the lower device
is in the default VRF, the socket will not be bound to any interface,
keeping the previous behaviour.

The underlay l3mdev is deduced from the VXLAN lower device
(IFLA_VXLAN_LINK).

+----------+                         +---------+
|          |                         |         |
| vrf-blue |                         | vrf-red |
|          |                         |         |
+----+-----+                         +----+----+
     |                                    |
     |                                    |
+----+-----+                         +----+----+
|          |                         |         |
| br-blue  |                         | br-red  |
|          |                         |         |
+----+-----+                         +---+-+---+
     |                                   | |
     |                             +-----+ +-----+
     |                             |             |
+----+-----+                +------+----+   +----+----+
|          |  lower device  |           |   |         |
|   eth0   | <- - - - - - - | vxlan-red |   | tap-red | (... more taps)
|          |                |           |   |         |
+----------+                +-----------+   +---------+

Signed-off-by: Alexis Bauvin <abauvin@scaleway.com>
Reviewed-by: Amine Kherbouche <akherbouche@scaleway.com>
Reviewed-by: David Ahern <dsahern@gmail.com>
Tested-by: Amine Kherbouche <akherbouche@scaleway.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/drivers/net/vxlan.c b/drivers/net/vxlan.c
index 9110662..901eef4 100644
--- a/drivers/net/vxlan.c
+++ b/drivers/net/vxlan.c
@@ -188,7 +188,7 @@ static inline struct vxlan_rdst *first_remote_rtnl(struct vxlan_fdb *fdb)
  * and enabled unshareable flags.
  */
 static struct vxlan_sock *vxlan_find_sock(struct net *net, sa_family_t family,
-					  __be16 port, u32 flags)
+					  __be16 port, u32 flags, int ifindex)
 {
 	struct vxlan_sock *vs;
 
@@ -197,7 +197,8 @@ static struct vxlan_sock *vxlan_find_sock(struct net *net, sa_family_t family,
 	hlist_for_each_entry_rcu(vs, vs_head(net, port), hlist) {
 		if (inet_sk(vs->sock->sk)->inet_sport == port &&
 		    vxlan_get_sk_family(vs) == family &&
-		    vs->flags == flags)
+		    vs->flags == flags &&
+		    vs->sock->sk->sk_bound_dev_if == ifindex)
 			return vs;
 	}
 	return NULL;
@@ -237,7 +238,7 @@ static struct vxlan_dev *vxlan_find_vni(struct net *net, int ifindex,
 {
 	struct vxlan_sock *vs;
 
-	vs = vxlan_find_sock(net, family, port, flags);
+	vs = vxlan_find_sock(net, family, port, flags, ifindex);
 	if (!vs)
 		return NULL;
 
@@ -2288,6 +2289,9 @@ static void vxlan_xmit_one(struct sk_buff *skb, struct net_device *dev,
 		struct rtable *rt;
 		__be16 df = 0;
 
+		if (!ifindex)
+			ifindex = sock4->sock->sk->sk_bound_dev_if;
+
 		rt = vxlan_get_route(vxlan, dev, sock4, skb, ifindex, tos,
 				     dst->sin.sin_addr.s_addr,
 				     &local_ip.sin.sin_addr.s_addr,
@@ -2337,6 +2341,9 @@ static void vxlan_xmit_one(struct sk_buff *skb, struct net_device *dev,
 	} else {
 		struct vxlan_sock *sock6 = rcu_dereference(vxlan->vn6_sock);
 
+		if (!ifindex)
+			ifindex = sock6->sock->sk->sk_bound_dev_if;
+
 		ndst = vxlan6_get_route(vxlan, dev, sock6, skb, ifindex, tos,
 					label, &dst->sin6.sin6_addr,
 					&local_ip.sin6.sin6_addr,
@@ -2951,7 +2958,7 @@ static const struct ethtool_ops vxlan_ethtool_ops = {
 };
 
 static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
-					__be16 port, u32 flags)
+					__be16 port, u32 flags, int ifindex)
 {
 	struct socket *sock;
 	struct udp_port_cfg udp_conf;
@@ -2969,6 +2976,7 @@ static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
 	}
 
 	udp_conf.local_udp_port = port;
+	udp_conf.bind_ifindex = ifindex;
 
 	/* Open UDP socket */
 	err = udp_sock_create(net, &udp_conf, &sock);
@@ -2980,7 +2988,8 @@ static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
 
 /* Create new listen socket if needed */
 static struct vxlan_sock *vxlan_socket_create(struct net *net, bool ipv6,
-					      __be16 port, u32 flags)
+					      __be16 port, u32 flags,
+					      int ifindex)
 {
 	struct vxlan_net *vn = net_generic(net, vxlan_net_id);
 	struct vxlan_sock *vs;
@@ -2995,7 +3004,7 @@ static struct vxlan_sock *vxlan_socket_create(struct net *net, bool ipv6,
 	for (h = 0; h < VNI_HASH_SIZE; ++h)
 		INIT_HLIST_HEAD(&vs->vni_list[h]);
 
-	sock = vxlan_create_sock(net, ipv6, port, flags);
+	sock = vxlan_create_sock(net, ipv6, port, flags, ifindex);
 	if (IS_ERR(sock)) {
 		kfree(vs);
 		return ERR_CAST(sock);
@@ -3033,11 +3042,17 @@ static int __vxlan_sock_add(struct vxlan_dev *vxlan, bool ipv6)
 	struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
 	struct vxlan_sock *vs = NULL;
 	struct vxlan_dev_node *node;
+	int l3mdev_index = 0;
+
+	if (vxlan->cfg.remote_ifindex)
+		l3mdev_index = l3mdev_master_upper_ifindex_by_index(
+			vxlan->net, vxlan->cfg.remote_ifindex);
 
 	if (!vxlan->cfg.no_share) {
 		spin_lock(&vn->sock_lock);
 		vs = vxlan_find_sock(vxlan->net, ipv6 ? AF_INET6 : AF_INET,
-				     vxlan->cfg.dst_port, vxlan->cfg.flags);
+				     vxlan->cfg.dst_port, vxlan->cfg.flags,
+				     l3mdev_index);
 		if (vs && !refcount_inc_not_zero(&vs->refcnt)) {
 			spin_unlock(&vn->sock_lock);
 			return -EBUSY;
@@ -3046,7 +3061,8 @@ static int __vxlan_sock_add(struct vxlan_dev *vxlan, bool ipv6)
 	}
 	if (!vs)
 		vs = vxlan_socket_create(vxlan->net, ipv6,
-					 vxlan->cfg.dst_port, vxlan->cfg.flags);
+					 vxlan->cfg.dst_port, vxlan->cfg.flags,
+					 l3mdev_index);
 	if (IS_ERR(vs))
 		return PTR_ERR(vs);
 #if IS_ENABLED(CONFIG_IPV6)