netfilter: Pass socket pointer down through okfn().

On the output paths in particular, we sometimes have to deal with two
socket contexts.  The first, and usually skb->sk, is the local socket
that generated the frame.

The second is potentially the socket used to control a tunnel, such
as one that encapsulates using UDP.

We do not want to disassociate skb->sk when encapsulating in order
to fix this, because that would break socket memory accounting.

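The most extreme case where this can cause huge problems is an
AF_PACKET socket transmitting over a vxlan device.  We hit code
paths doing checks that assume they are dealing with an IPv4
socket, but are actually operating upon the AF_PACKET one.

So instead, pass the socket pointer explicitly down through the
netfilter okfn() callbacks, so that each output path can be handed
the socket it should actually operate on rather than assuming it is
skb->sk.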

Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/xfrm/xfrm_output.c b/net/xfrm/xfrm_output.c
index 7c53285..fbcedbe 100644
--- a/net/xfrm/xfrm_output.c
+++ b/net/xfrm/xfrm_output.c
@@ -19,7 +19,7 @@
 #include <net/dst.h>
 #include <net/xfrm.h>
 
-static int xfrm_output2(struct sk_buff *skb);
+static int xfrm_output2(struct sock *sk, struct sk_buff *skb);
 
 static int xfrm_skb_check_space(struct sk_buff *skb)
 {
@@ -130,7 +130,7 @@
 			return dst_output(skb);
 
 		err = nf_hook(skb_dst(skb)->ops->family,
-			      NF_INET_POST_ROUTING, skb,
+			      NF_INET_POST_ROUTING, skb->sk, skb,
 			      NULL, skb_dst(skb)->dev, xfrm_output2);
 		if (unlikely(err != 1))
 			goto out;
@@ -144,12 +144,12 @@
 }
 EXPORT_SYMBOL_GPL(xfrm_output_resume);
 
-static int xfrm_output2(struct sk_buff *skb)
+static int xfrm_output2(struct sock *sk, struct sk_buff *skb)
 {
 	return xfrm_output_resume(skb, 1);
 }
 
-static int xfrm_output_gso(struct sk_buff *skb)
+static int xfrm_output_gso(struct sock *sk, struct sk_buff *skb)
 {
 	struct sk_buff *segs;
 
@@ -165,7 +165,7 @@
 		int err;
 
 		segs->next = NULL;
-		err = xfrm_output2(segs);
+		err = xfrm_output2(sk, segs);
 
 		if (unlikely(err)) {
 			kfree_skb_list(nskb);
@@ -178,13 +178,13 @@
 	return 0;
 }
 
-int xfrm_output(struct sk_buff *skb)
+int xfrm_output(struct sock *sk, struct sk_buff *skb)
 {
 	struct net *net = dev_net(skb_dst(skb)->dev);
 	int err;
 
 	if (skb_is_gso(skb))
-		return xfrm_output_gso(skb);
+		return xfrm_output_gso(sk, skb);
 
 	if (skb->ip_summed == CHECKSUM_PARTIAL) {
 		err = skb_checksum_help(skb);
@@ -195,7 +195,7 @@
 		}
 	}
 
-	return xfrm_output2(skb);
+	return xfrm_output2(sk, skb);
 }
 EXPORT_SYMBOL_GPL(xfrm_output);