ipv4: add a sock pointer to dst->output() path.

In the dst->output() path for ipv4, the code assumes the skb it has to
transmit is attached to an inet socket, specifically via
ip_mc_output() : The sk_mc_loop() test triggers a WARN_ON() when the
provider of the packet is an AF_PACKET socket.

The dst->output() method gets an additional 'struct sock *sk'
parameter. This needs a cascade of changes so that this parameter can
be propagated from vxlan to final consumer.

Fixes: 8f646c922d55 ("vxlan: keep original skb ownership")
Reported-by: lucien xin <lucien.xin@gmail.com>
Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/drivers/net/vxlan.c b/drivers/net/vxlan.c
index c55e316..82355d5 100644
--- a/drivers/net/vxlan.c
+++ b/drivers/net/vxlan.c
@@ -1755,8 +1755,8 @@
 	if (err)
 		return err;
 
-	return iptunnel_xmit(rt, skb, src, dst, IPPROTO_UDP, tos, ttl, df,
-			     false);
+	return iptunnel_xmit(vs->sock->sk, rt, skb, src, dst, IPPROTO_UDP,
+			     tos, ttl, df, false);
 }
 EXPORT_SYMBOL_GPL(vxlan_xmit_skb);
 
diff --git a/include/net/dst.h b/include/net/dst.h
index 46ed958..71c60f4 100644
--- a/include/net/dst.h
+++ b/include/net/dst.h
@@ -45,7 +45,7 @@
 	void			*__pad1;
 #endif
 	int			(*input)(struct sk_buff *);
-	int			(*output)(struct sk_buff *);
+	int			(*output)(struct sock *sk, struct sk_buff *skb);
 
 	unsigned short		flags;
 #define DST_HOST		0x0001
@@ -367,7 +367,11 @@
 	return child;
 }
 
-int dst_discard(struct sk_buff *skb);
+int dst_discard_sk(struct sock *sk, struct sk_buff *skb);
+static inline int dst_discard(struct sk_buff *skb)
+{
+	return dst_discard_sk(skb->sk, skb);
+}
 void *dst_alloc(struct dst_ops *ops, struct net_device *dev, int initial_ref,
 		int initial_obsolete, unsigned short flags);
 void __dst_free(struct dst_entry *dst);
@@ -449,9 +453,13 @@
 }
 
 /* Output packet to network from transport.  */
+static inline int dst_output_sk(struct sock *sk, struct sk_buff *skb)
+{
+	return skb_dst(skb)->output(sk, skb);
+}
 static inline int dst_output(struct sk_buff *skb)
 {
-	return skb_dst(skb)->output(skb);
+	return dst_output_sk(skb->sk, skb);
 }
 
 /* Input packet from network to transport.  */
diff --git a/include/net/ip.h b/include/net/ip.h
index 77e73d2..3ec2b0f 100644
--- a/include/net/ip.h
+++ b/include/net/ip.h
@@ -104,13 +104,18 @@
 	   struct net_device *orig_dev);
 int ip_local_deliver(struct sk_buff *skb);
 int ip_mr_input(struct sk_buff *skb);
-int ip_output(struct sk_buff *skb);
-int ip_mc_output(struct sk_buff *skb);
+int ip_output(struct sock *sk, struct sk_buff *skb);
+int ip_mc_output(struct sock *sk, struct sk_buff *skb);
 int ip_fragment(struct sk_buff *skb, int (*output)(struct sk_buff *));
 int ip_do_nat(struct sk_buff *skb);
 void ip_send_check(struct iphdr *ip);
 int __ip_local_out(struct sk_buff *skb);
-int ip_local_out(struct sk_buff *skb);
+int ip_local_out_sk(struct sock *sk, struct sk_buff *skb);
+static inline int ip_local_out(struct sk_buff *skb)
+{
+	return ip_local_out_sk(skb->sk, skb);
+}
+
 int ip_queue_xmit(struct sock *sk, struct sk_buff *skb, struct flowi *fl);
 void ip_init(void);
 int ip_append_data(struct sock *sk, struct flowi4 *fl4,
diff --git a/include/net/ip_tunnels.h b/include/net/ip_tunnels.h
index e77c104..a4daf9e 100644
--- a/include/net/ip_tunnels.h
+++ b/include/net/ip_tunnels.h
@@ -153,7 +153,7 @@
 }
 
 int iptunnel_pull_header(struct sk_buff *skb, int hdr_len, __be16 inner_proto);
-int iptunnel_xmit(struct rtable *rt, struct sk_buff *skb,
+int iptunnel_xmit(struct sock *sk, struct rtable *rt, struct sk_buff *skb,
 		  __be32 src, __be32 dst, __u8 proto,
 		  __u8 tos, __u8 ttl, __be16 df, bool xnet);
 
diff --git a/include/net/ipv6.h b/include/net/ipv6.h
index 4f541f1..d640925 100644
--- a/include/net/ipv6.h
+++ b/include/net/ipv6.h
@@ -731,7 +731,7 @@
  *	skb processing functions
  */
 
-int ip6_output(struct sk_buff *skb);
+int ip6_output(struct sock *sk, struct sk_buff *skb);
 int ip6_forward(struct sk_buff *skb);
 int ip6_input(struct sk_buff *skb);
 int ip6_mc_input(struct sk_buff *skb);
diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index 32682ae..116e9c7 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -333,7 +333,7 @@
 						const xfrm_address_t *saddr);
 	int			(*tmpl_sort)(struct xfrm_tmpl **dst, struct xfrm_tmpl **src, int n);
 	int			(*state_sort)(struct xfrm_state **dst, struct xfrm_state **src, int n);
-	int			(*output)(struct sk_buff *skb);
+	int			(*output)(struct sock *sk, struct sk_buff *skb);
 	int			(*output_finish)(struct sk_buff *skb);
 	int			(*extract_input)(struct xfrm_state *x,
 						 struct sk_buff *skb);
@@ -1540,7 +1540,7 @@
 
 int xfrm4_extract_output(struct xfrm_state *x, struct sk_buff *skb);
 int xfrm4_prepare_output(struct xfrm_state *x, struct sk_buff *skb);
-int xfrm4_output(struct sk_buff *skb);
+int xfrm4_output(struct sock *sk, struct sk_buff *skb);
 int xfrm4_output_finish(struct sk_buff *skb);
 int xfrm4_rcv_cb(struct sk_buff *skb, u8 protocol, int err);
 int xfrm4_protocol_register(struct xfrm4_protocol *handler, unsigned char protocol);
@@ -1565,7 +1565,7 @@
 __be32 xfrm6_tunnel_spi_lookup(struct net *net, const xfrm_address_t *saddr);
 int xfrm6_extract_output(struct xfrm_state *x, struct sk_buff *skb);
 int xfrm6_prepare_output(struct xfrm_state *x, struct sk_buff *skb);
-int xfrm6_output(struct sk_buff *skb);
+int xfrm6_output(struct sock *sk, struct sk_buff *skb);
 int xfrm6_output_finish(struct sk_buff *skb);
 int xfrm6_find_1stfragopt(struct xfrm_state *x, struct sk_buff *skb,
 			  u8 **prevhdr);
diff --git a/net/core/dst.c b/net/core/dst.c
index ca4231e..80d6286 100644
--- a/net/core/dst.c
+++ b/net/core/dst.c
@@ -142,12 +142,12 @@
 	mutex_unlock(&dst_gc_mutex);
 }
 
-int dst_discard(struct sk_buff *skb)
+int dst_discard_sk(struct sock *sk, struct sk_buff *skb)
 {
 	kfree_skb(skb);
 	return 0;
 }
-EXPORT_SYMBOL(dst_discard);
+EXPORT_SYMBOL(dst_discard_sk);
 
 const u32 dst_default_metrics[RTAX_MAX + 1] = {
 	/* This initializer is needed to force linker to place this variable
@@ -184,7 +184,7 @@
 	dst->xfrm = NULL;
 #endif
 	dst->input = dst_discard;
-	dst->output = dst_discard;
+	dst->output = dst_discard_sk;
 	dst->error = 0;
 	dst->obsolete = initial_obsolete;
 	dst->header_len = 0;
@@ -209,8 +209,10 @@
 	/* The first case (dev==NULL) is required, when
 	   protocol module is unloaded.
 	 */
-	if (dst->dev == NULL || !(dst->dev->flags&IFF_UP))
-		dst->input = dst->output = dst_discard;
+	if (dst->dev == NULL || !(dst->dev->flags&IFF_UP)) {
+		dst->input = dst_discard;
+		dst->output = dst_discard_sk;
+	}
 	dst->obsolete = DST_OBSOLETE_DEAD;
 }
 
@@ -361,7 +363,8 @@
 		return;
 
 	if (!unregister) {
-		dst->input = dst->output = dst_discard;
+		dst->input = dst_discard;
+		dst->output = dst_discard_sk;
 	} else {
 		dst->dev = dev_net(dst->dev)->loopback_dev;
 		dev_hold(dst->dev);
diff --git a/net/decnet/dn_route.c b/net/decnet/dn_route.c
index ce0cbbf..daccc4a 100644
--- a/net/decnet/dn_route.c
+++ b/net/decnet/dn_route.c
@@ -752,7 +752,7 @@
 	return n->output(n, skb);
 }
 
-static int dn_output(struct sk_buff *skb)
+static int dn_output(struct sock *sk, struct sk_buff *skb)
 {
 	struct dst_entry *dst = skb_dst(skb);
 	struct dn_route *rt = (struct dn_route *)dst;
@@ -838,6 +838,18 @@
  * Used to catch bugs. This should never normally get
  * called.
  */
+static int dn_rt_bug_sk(struct sock *sk, struct sk_buff *skb)
+{
+	struct dn_skb_cb *cb = DN_SKB_CB(skb);
+
+	net_dbg_ratelimited("dn_rt_bug: skb from:%04x to:%04x\n",
+			    le16_to_cpu(cb->src), le16_to_cpu(cb->dst));
+
+	kfree_skb(skb);
+
+	return NET_RX_DROP;
+}
+
 static int dn_rt_bug(struct sk_buff *skb)
 {
 	struct dn_skb_cb *cb = DN_SKB_CB(skb);
@@ -1463,7 +1475,7 @@
 
 	rt->n = neigh;
 	rt->dst.lastuse = jiffies;
-	rt->dst.output = dn_rt_bug;
+	rt->dst.output = dn_rt_bug_sk;
 	switch (res.type) {
 	case RTN_UNICAST:
 		rt->dst.input = dn_forward;
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index 7ad68b8..1cbeba5 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -101,17 +101,17 @@
 		       skb_dst(skb)->dev, dst_output);
 }
 
-int ip_local_out(struct sk_buff *skb)
+int ip_local_out_sk(struct sock *sk, struct sk_buff *skb)
 {
 	int err;
 
 	err = __ip_local_out(skb);
 	if (likely(err == 1))
-		err = dst_output(skb);
+		err = dst_output_sk(sk, skb);
 
 	return err;
 }
-EXPORT_SYMBOL_GPL(ip_local_out);
+EXPORT_SYMBOL_GPL(ip_local_out_sk);
 
 static inline int ip_select_ttl(struct inet_sock *inet, struct dst_entry *dst)
 {
@@ -226,9 +226,8 @@
 		return ip_finish_output2(skb);
 }
 
-int ip_mc_output(struct sk_buff *skb)
+int ip_mc_output(struct sock *sk, struct sk_buff *skb)
 {
-	struct sock *sk = skb->sk;
 	struct rtable *rt = skb_rtable(skb);
 	struct net_device *dev = rt->dst.dev;
 
@@ -287,7 +286,7 @@
 			    !(IPCB(skb)->flags & IPSKB_REROUTED));
 }
 
-int ip_output(struct sk_buff *skb)
+int ip_output(struct sock *sk, struct sk_buff *skb)
 {
 	struct net_device *dev = skb_dst(skb)->dev;
 
diff --git a/net/ipv4/ip_tunnel.c b/net/ipv4/ip_tunnel.c
index e77381d..484d0ce 100644
--- a/net/ipv4/ip_tunnel.c
+++ b/net/ipv4/ip_tunnel.c
@@ -670,7 +670,7 @@
 		return;
 	}
 
-	err = iptunnel_xmit(rt, skb, fl4.saddr, fl4.daddr, protocol,
+	err = iptunnel_xmit(skb->sk, rt, skb, fl4.saddr, fl4.daddr, protocol,
 			    tos, ttl, df, !net_eq(tunnel->net, dev_net(dev)));
 	iptunnel_xmit_stats(err, &dev->stats, dev->tstats);
 
diff --git a/net/ipv4/ip_tunnel_core.c b/net/ipv4/ip_tunnel_core.c
index e0c2b1d..bcf206c 100644
--- a/net/ipv4/ip_tunnel_core.c
+++ b/net/ipv4/ip_tunnel_core.c
@@ -46,7 +46,7 @@
 #include <net/netns/generic.h>
 #include <net/rtnetlink.h>
 
-int iptunnel_xmit(struct rtable *rt, struct sk_buff *skb,
+int iptunnel_xmit(struct sock *sk, struct rtable *rt, struct sk_buff *skb,
 		  __be32 src, __be32 dst, __u8 proto,
 		  __u8 tos, __u8 ttl, __be16 df, bool xnet)
 {
@@ -76,7 +76,7 @@
 	iph->ttl	=	ttl;
 	__ip_select_ident(iph, &rt->dst, (skb_shinfo(skb)->gso_segs ?: 1) - 1);
 
-	err = ip_local_out(skb);
+	err = ip_local_out_sk(sk, skb);
 	if (unlikely(net_xmit_eval(err)))
 		pkt_len = 0;
 	return pkt_len;
diff --git a/net/ipv4/route.c b/net/ipv4/route.c
index 20a59c3..1485aaf 100644
--- a/net/ipv4/route.c
+++ b/net/ipv4/route.c
@@ -1129,7 +1129,7 @@
 		dst_set_expires(&rt->dst, 0);
 }
 
-static int ip_rt_bug(struct sk_buff *skb)
+static int ip_rt_bug(struct sock *sk, struct sk_buff *skb)
 {
 	pr_debug("%s: %pI4 -> %pI4, %s\n",
 		 __func__, &ip_hdr(skb)->saddr, &ip_hdr(skb)->daddr,
@@ -2218,7 +2218,7 @@
 
 		new->__use = 1;
 		new->input = dst_discard;
-		new->output = dst_discard;
+		new->output = dst_discard_sk;
 
 		new->dev = ort->dst.dev;
 		if (new->dev)
diff --git a/net/ipv4/xfrm4_output.c b/net/ipv4/xfrm4_output.c
index baa0f63..40e701f 100644
--- a/net/ipv4/xfrm4_output.c
+++ b/net/ipv4/xfrm4_output.c
@@ -86,7 +86,7 @@
 	return xfrm_output(skb);
 }
 
-int xfrm4_output(struct sk_buff *skb)
+int xfrm4_output(struct sock *sk, struct sk_buff *skb)
 {
 	struct dst_entry *dst = skb_dst(skb);
 	struct xfrm_state *x = dst->xfrm;
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index 3284d61..40e7581 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -132,7 +132,7 @@
 		return ip6_finish_output2(skb);
 }
 
-int ip6_output(struct sk_buff *skb)
+int ip6_output(struct sock *sk, struct sk_buff *skb)
 {
 	struct net_device *dev = skb_dst(skb)->dev;
 	struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb));
diff --git a/net/ipv6/route.c b/net/ipv6/route.c
index 5ea462e..4011617 100644
--- a/net/ipv6/route.c
+++ b/net/ipv6/route.c
@@ -84,9 +84,9 @@
 static int		 ip6_dst_gc(struct dst_ops *ops);
 
 static int		ip6_pkt_discard(struct sk_buff *skb);
-static int		ip6_pkt_discard_out(struct sk_buff *skb);
+static int		ip6_pkt_discard_out(struct sock *sk, struct sk_buff *skb);
 static int		ip6_pkt_prohibit(struct sk_buff *skb);
-static int		ip6_pkt_prohibit_out(struct sk_buff *skb);
+static int		ip6_pkt_prohibit_out(struct sock *sk, struct sk_buff *skb);
 static void		ip6_link_failure(struct sk_buff *skb);
 static void		ip6_rt_update_pmtu(struct dst_entry *dst, struct sock *sk,
 					   struct sk_buff *skb, u32 mtu);
@@ -290,7 +290,7 @@
 		.obsolete	= DST_OBSOLETE_FORCE_CHK,
 		.error		= -EINVAL,
 		.input		= dst_discard,
-		.output		= dst_discard,
+		.output		= dst_discard_sk,
 	},
 	.rt6i_flags	= (RTF_REJECT | RTF_NONEXTHOP),
 	.rt6i_protocol  = RTPROT_KERNEL,
@@ -1058,7 +1058,7 @@
 
 		new->__use = 1;
 		new->input = dst_discard;
-		new->output = dst_discard;
+		new->output = dst_discard_sk;
 
 		if (dst_metrics_read_only(&ort->dst))
 			new->_metrics = ort->dst._metrics;
@@ -1577,7 +1577,7 @@
 		switch (cfg->fc_type) {
 		case RTN_BLACKHOLE:
 			rt->dst.error = -EINVAL;
-			rt->dst.output = dst_discard;
+			rt->dst.output = dst_discard_sk;
 			rt->dst.input = dst_discard;
 			break;
 		case RTN_PROHIBIT:
@@ -2129,7 +2129,7 @@
 	return ip6_pkt_drop(skb, ICMPV6_NOROUTE, IPSTATS_MIB_INNOROUTES);
 }
 
-static int ip6_pkt_discard_out(struct sk_buff *skb)
+static int ip6_pkt_discard_out(struct sock *sk, struct sk_buff *skb)
 {
 	skb->dev = skb_dst(skb)->dev;
 	return ip6_pkt_drop(skb, ICMPV6_NOROUTE, IPSTATS_MIB_OUTNOROUTES);
@@ -2140,7 +2140,7 @@
 	return ip6_pkt_drop(skb, ICMPV6_ADM_PROHIBITED, IPSTATS_MIB_INNOROUTES);
 }
 
-static int ip6_pkt_prohibit_out(struct sk_buff *skb)
+static int ip6_pkt_prohibit_out(struct sock *sk, struct sk_buff *skb)
 {
 	skb->dev = skb_dst(skb)->dev;
 	return ip6_pkt_drop(skb, ICMPV6_ADM_PROHIBITED, IPSTATS_MIB_OUTNOROUTES);
diff --git a/net/ipv6/sit.c b/net/ipv6/sit.c
index 1693c8d..8da8268 100644
--- a/net/ipv6/sit.c
+++ b/net/ipv6/sit.c
@@ -974,8 +974,9 @@
 		goto out;
 	}
 
-	err = iptunnel_xmit(rt, skb, fl4.saddr, fl4.daddr, IPPROTO_IPV6, tos,
-			    ttl, df, !net_eq(tunnel->net, dev_net(dev)));
+	err = iptunnel_xmit(skb->sk, rt, skb, fl4.saddr, fl4.daddr,
+			    IPPROTO_IPV6, tos, ttl, df,
+			    !net_eq(tunnel->net, dev_net(dev)));
 	iptunnel_xmit_stats(err, &dev->stats, dev->tstats);
 	return NETDEV_TX_OK;
 
diff --git a/net/ipv6/xfrm6_output.c b/net/ipv6/xfrm6_output.c
index 6cd625e..19ef329 100644
--- a/net/ipv6/xfrm6_output.c
+++ b/net/ipv6/xfrm6_output.c
@@ -163,7 +163,7 @@
 	return x->outer_mode->afinfo->output_finish(skb);
 }
 
-int xfrm6_output(struct sk_buff *skb)
+int xfrm6_output(struct sock *sk, struct sk_buff *skb)
 {
 	return NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING, skb, NULL,
 		       skb_dst(skb)->dev, __xfrm6_output);
diff --git a/net/openvswitch/vport-gre.c b/net/openvswitch/vport-gre.c
index a3d6951..ebb6e244 100644
--- a/net/openvswitch/vport-gre.c
+++ b/net/openvswitch/vport-gre.c
@@ -174,7 +174,7 @@
 
 	skb->local_df = 1;
 
-	return iptunnel_xmit(rt, skb, fl.saddr,
+	return iptunnel_xmit(skb->sk, rt, skb, fl.saddr,
 			     OVS_CB(skb)->tun_key->ipv4_dst, IPPROTO_GRE,
 			     OVS_CB(skb)->tun_key->ipv4_tos,
 			     OVS_CB(skb)->tun_key->ipv4_ttl, df, false);
diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c
index f02f511..c08fbd1 100644
--- a/net/xfrm/xfrm_policy.c
+++ b/net/xfrm/xfrm_policy.c
@@ -1842,7 +1842,7 @@
 	xfrm_pol_put(pol);
 }
 
-static int xdst_queue_output(struct sk_buff *skb)
+static int xdst_queue_output(struct sock *sk, struct sk_buff *skb)
 {
 	unsigned long sched_next;
 	struct dst_entry *dst = skb_dst(skb);