ip6_gre: add ip6 erspan collect_md mode

Similar to ip6 gretap and ip4 gretap, this patch allows the
ip6 erspan tunnel to operate in collect metadata mode, so the
bpf_skb_[gs]et_tunnel_key() helpers can make use of it
right away.

Signed-off-by: William Tu <u9012063@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv6/ip6_gre.c b/net/ipv6/ip6_gre.c
index 1510ce9..4562579 100644
--- a/net/ipv6/ip6_gre.c
+++ b/net/ipv6/ip6_gre.c
@@ -524,8 +524,37 @@ static int ip6erspan_rcv(struct sk_buff *skb, int gre_hdr_len,
 					   false, false) < 0)
 			return PACKET_REJECT;
 
-		tunnel->parms.index = ntohl(index);
-		ip6_tnl_rcv(tunnel, skb, tpi, NULL, log_ecn_error);
+		if (tunnel->parms.collect_md) {
+			struct metadata_dst *tun_dst;
+			struct ip_tunnel_info *info;
+			struct erspan_metadata *md;
+			__be64 tun_id;
+			__be16 flags;
+
+			tpi->flags |= TUNNEL_KEY;
+			flags = tpi->flags;
+			tun_id = key32_to_tunnel_id(tpi->key);
+
+			tun_dst = ipv6_tun_rx_dst(skb, flags, tun_id,
+						  sizeof(*md));
+			if (!tun_dst)
+				return PACKET_REJECT;
+
+			info = &tun_dst->u.tun_info;
+			md = ip_tunnel_info_opts(info);
+			if (!md)
+				return PACKET_REJECT;
+
+			md->index = index;
+			info->key.tun_flags |= TUNNEL_ERSPAN_OPT;
+			info->options_len = sizeof(*md);
+
+			ip6_tnl_rcv(tunnel, skb, tpi, tun_dst, log_ecn_error);
+
+		} else {
+			tunnel->parms.index = ntohl(index);
+			ip6_tnl_rcv(tunnel, skb, tpi, NULL, log_ecn_error);
+		}
 
 		return PACKET_RCVD;
 	}
@@ -857,42 +886,73 @@ static netdev_tx_t ip6erspan_tunnel_xmit(struct sk_buff *skb,
 	if (gre_handle_offloads(skb, false))
 		goto tx_err;
 
-	switch (skb->protocol) {
-	case htons(ETH_P_IP):
-		memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt));
-		prepare_ip6gre_xmit_ipv4(skb, dev, &fl6,
-					 &dsfield, &encap_limit);
-		break;
-	case htons(ETH_P_IPV6):
-		if (ipv6_addr_equal(&t->parms.raddr, &ipv6h->saddr))
-			goto tx_err;
-		if (prepare_ip6gre_xmit_ipv6(skb, dev, &fl6,
-					     &dsfield, &encap_limit))
-			goto tx_err;
-		break;
-	default:
-		memcpy(&fl6, &t->fl.u.ip6, sizeof(fl6));
-		break;
-	}
-
 	if (skb->len > dev->mtu + dev->hard_header_len) {
 		pskb_trim(skb, dev->mtu + dev->hard_header_len);
 		truncate = true;
 	}
 
-	erspan_build_header(skb, t->parms.o_key, t->parms.index,
-			    truncate, false);
 	t->parms.o_flags &= ~TUNNEL_KEY;
-
 	IPCB(skb)->flags = 0;
-	fl6.daddr = t->parms.raddr;
+
+	/* For collect_md mode, derive fl6 from the tunnel key,
+	 * for native mode, call prepare_ip6gre_xmit_{ipv4,ipv6}.
+	 */
+	if (t->parms.collect_md) {
+		struct ip_tunnel_info *tun_info;
+		const struct ip_tunnel_key *key;
+		struct erspan_metadata *md;
+
+		tun_info = skb_tunnel_info(skb);
+		if (unlikely(!tun_info ||
+			     !(tun_info->mode & IP_TUNNEL_INFO_TX) ||
+			     ip_tunnel_info_af(tun_info) != AF_INET6))
+			return -EINVAL;
+
+		key = &tun_info->key;
+		memset(&fl6, 0, sizeof(fl6));
+		fl6.flowi6_proto = IPPROTO_GRE;
+		fl6.daddr = key->u.ipv6.dst;
+		fl6.flowlabel = key->label;
+		fl6.flowi6_uid = sock_net_uid(dev_net(dev), NULL);
+
+		dsfield = key->tos;
+		md = ip_tunnel_info_opts(tun_info);
+		if (!md)
+			goto tx_err;
+
+		erspan_build_header(skb, tunnel_id_to_key32(key->tun_id),
+				    ntohl(md->index), truncate, false);
+
+	} else {
+		switch (skb->protocol) {
+		case htons(ETH_P_IP):
+			memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt));
+			prepare_ip6gre_xmit_ipv4(skb, dev, &fl6,
+						 &dsfield, &encap_limit);
+			break;
+		case htons(ETH_P_IPV6):
+			if (ipv6_addr_equal(&t->parms.raddr, &ipv6h->saddr))
+				goto tx_err;
+			if (prepare_ip6gre_xmit_ipv6(skb, dev, &fl6,
+						     &dsfield, &encap_limit))
+				goto tx_err;
+			break;
+		default:
+			memcpy(&fl6, &t->fl.u.ip6, sizeof(fl6));
+			break;
+		}
+
+		erspan_build_header(skb, t->parms.o_key, t->parms.index,
+				    truncate, false);
+		fl6.daddr = t->parms.raddr;
+	}
 
 	/* Push GRE header. */
 	gre_build_header(skb, 8, TUNNEL_SEQ,
 			 htons(ETH_P_ERSPAN), 0, htonl(t->o_seqno++));
 
 	/* TooBig packet may have updated dst->dev's mtu */
-	if (dst && dst_mtu(dst) > dst->dev->mtu)
+	if (!t->parms.collect_md && dst && dst_mtu(dst) > dst->dev->mtu)
 		dst->ops->update_pmtu(dst, NULL, skb, dst->dev->mtu);
 
 	err = ip6_tnl_xmit(skb, dev, dsfield, &fl6, encap_limit, &mtu,