net/ipv6: Pass skb to route lookup

IPv6 does path selection for multipath routes deep in the lookup
functions. The next patch adds an L4 hash option and needs the skb
for the forward path. To get the skb to the relevant FIB lookup
functions, it has to be passed through the fib rules layer, so add a
lookup_data member to the fib_lookup_arg struct.

Signed-off-by: David Ahern <dsahern@gmail.com>
Reviewed-by: Ido Schimmel <idosch@mellanox.com>
Reviewed-by: Nikolay Aleksandrov <nikolay@cumulusnetworks.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv6/anycast.c b/net/ipv6/anycast.c
index d7d0abc..c61718d 100644
--- a/net/ipv6/anycast.c
+++ b/net/ipv6/anycast.c
@@ -78,7 +78,7 @@ int ipv6_sock_ac_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
 	if (ifindex == 0) {
 		struct rt6_info *rt;
 
-		rt = rt6_lookup(net, addr, NULL, 0, 0);
+		rt = rt6_lookup(net, addr, NULL, 0, NULL, 0);
 		if (rt) {
 			dev = rt->dst.dev;
 			ip6_rt_put(rt);
diff --git a/net/ipv6/fib6_rules.c b/net/ipv6/fib6_rules.c
index 04e5f52..00ef946 100644
--- a/net/ipv6/fib6_rules.c
+++ b/net/ipv6/fib6_rules.c
@@ -61,11 +61,13 @@ unsigned int fib6_rules_seq_read(struct net *net)
 }
 
 struct dst_entry *fib6_rule_lookup(struct net *net, struct flowi6 *fl6,
+				   const struct sk_buff *skb,
 				   int flags, pol_lookup_t lookup)
 {
 	if (net->ipv6.fib6_has_custom_rules) {
 		struct fib_lookup_arg arg = {
 			.lookup_ptr = lookup,
+			.lookup_data = skb,
 			.flags = FIB_LOOKUP_NOREF,
 		};
 
@@ -80,11 +82,11 @@ struct dst_entry *fib6_rule_lookup(struct net *net, struct flowi6 *fl6,
 	} else {
 		struct rt6_info *rt;
 
-		rt = lookup(net, net->ipv6.fib6_local_tbl, fl6, flags);
+		rt = lookup(net, net->ipv6.fib6_local_tbl, fl6, skb, flags);
 		if (rt != net->ipv6.ip6_null_entry && rt->dst.error != -EAGAIN)
 			return &rt->dst;
 		ip6_rt_put(rt);
-		rt = lookup(net, net->ipv6.fib6_main_tbl, fl6, flags);
+		rt = lookup(net, net->ipv6.fib6_main_tbl, fl6, skb, flags);
 		if (rt->dst.error != -EAGAIN)
 			return &rt->dst;
 		ip6_rt_put(rt);
@@ -130,7 +132,7 @@ static int fib6_rule_action(struct fib_rule *rule, struct flowi *flp,
 		goto out;
 	}
 
-	rt = lookup(net, table, flp6, flags);
+	rt = lookup(net, table, flp6, arg->lookup_data, flags);
 	if (rt != net->ipv6.ip6_null_entry) {
 		struct fib6_rule *r = (struct fib6_rule *)rule;
 
diff --git a/net/ipv6/icmp.c b/net/ipv6/icmp.c
index b0778d3..a5d9292 100644
--- a/net/ipv6/icmp.c
+++ b/net/ipv6/icmp.c
@@ -629,7 +629,8 @@ int ip6_err_gen_icmpv6_unreach(struct sk_buff *skb, int nhs, int type,
 	skb_pull(skb2, nhs);
 	skb_reset_network_header(skb2);
 
-	rt = rt6_lookup(dev_net(skb->dev), &ipv6_hdr(skb2)->saddr, NULL, 0, 0);
+	rt = rt6_lookup(dev_net(skb->dev), &ipv6_hdr(skb2)->saddr, NULL, 0,
+			skb, 0);
 
 	if (rt && rt->dst.dev)
 		skb2->dev = rt->dst.dev;
diff --git a/net/ipv6/ip6_fib.c b/net/ipv6/ip6_fib.c
index cab95cf..2f995e9 100644
--- a/net/ipv6/ip6_fib.c
+++ b/net/ipv6/ip6_fib.c
@@ -299,11 +299,12 @@ struct fib6_table *fib6_get_table(struct net *net, u32 id)
 }
 
 struct dst_entry *fib6_rule_lookup(struct net *net, struct flowi6 *fl6,
+				   const struct sk_buff *skb,
 				   int flags, pol_lookup_t lookup)
 {
 	struct rt6_info *rt;
 
-	rt = lookup(net, net->ipv6.fib6_main_tbl, fl6, flags);
+	rt = lookup(net, net->ipv6.fib6_main_tbl, fl6, skb, flags);
 	if (rt->dst.error == -EAGAIN) {
 		ip6_rt_put(rt);
 		rt = net->ipv6.ip6_null_entry;
diff --git a/net/ipv6/ip6_gre.c b/net/ipv6/ip6_gre.c
index 4f150a3..83c7766 100644
--- a/net/ipv6/ip6_gre.c
+++ b/net/ipv6/ip6_gre.c
@@ -1053,7 +1053,7 @@ static void ip6gre_tnl_link_config(struct ip6_tnl *t, int set_mtu)
 
 		struct rt6_info *rt = rt6_lookup(t->net,
 						 &p->raddr, &p->laddr,
-						 p->link, strict);
+						 p->link, NULL, strict);
 
 		if (!rt)
 			return;
diff --git a/net/ipv6/ip6_tunnel.c b/net/ipv6/ip6_tunnel.c
index 869e2e6..1124f310 100644
--- a/net/ipv6/ip6_tunnel.c
+++ b/net/ipv6/ip6_tunnel.c
@@ -679,7 +679,7 @@ ip6ip6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
 
 		/* Try to guess incoming interface */
 		rt = rt6_lookup(dev_net(skb->dev), &ipv6_hdr(skb2)->saddr,
-				NULL, 0, 0);
+				NULL, 0, skb2, 0);
 
 		if (rt && rt->dst.dev)
 			skb2->dev = rt->dst.dev;
@@ -1444,7 +1444,7 @@ static void ip6_tnl_link_config(struct ip6_tnl *t)
 
 		struct rt6_info *rt = rt6_lookup(t->net,
 						 &p->raddr, &p->laddr,
-						 p->link, strict);
+						 p->link, NULL, strict);
 
 		if (!rt)
 			return;
diff --git a/net/ipv6/ip6_vti.c b/net/ipv6/ip6_vti.c
index c617ea17..a482b85 100644
--- a/net/ipv6/ip6_vti.c
+++ b/net/ipv6/ip6_vti.c
@@ -645,7 +645,7 @@ static void vti6_link_config(struct ip6_tnl *t)
 			      (IPV6_ADDR_MULTICAST | IPV6_ADDR_LINKLOCAL));
 		struct rt6_info *rt = rt6_lookup(t->net,
 						 &p->raddr, &p->laddr,
-						 p->link, strict);
+						 p->link, NULL, strict);
 
 		if (rt)
 			tdev = rt->dst.dev;
diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c
index d9bb933..d1a0cef 100644
--- a/net/ipv6/mcast.c
+++ b/net/ipv6/mcast.c
@@ -165,7 +165,7 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
 
 	if (ifindex == 0) {
 		struct rt6_info *rt;
-		rt = rt6_lookup(net, addr, NULL, 0, 0);
+		rt = rt6_lookup(net, addr, NULL, 0, NULL, 0);
 		if (rt) {
 			dev = rt->dst.dev;
 			ip6_rt_put(rt);
@@ -254,7 +254,7 @@ static struct inet6_dev *ip6_mc_find_dev_rcu(struct net *net,
 	struct inet6_dev *idev = NULL;
 
 	if (ifindex == 0) {
-		struct rt6_info *rt = rt6_lookup(net, group, NULL, 0, 0);
+		struct rt6_info *rt = rt6_lookup(net, group, NULL, 0, NULL, 0);
 
 		if (rt) {
 			dev = rt->dst.dev;
diff --git a/net/ipv6/netfilter/ip6t_rpfilter.c b/net/ipv6/netfilter/ip6t_rpfilter.c
index 94deb69..910a273 100644
--- a/net/ipv6/netfilter/ip6t_rpfilter.c
+++ b/net/ipv6/netfilter/ip6t_rpfilter.c
@@ -53,7 +53,7 @@ static bool rpfilter_lookup_reverse6(struct net *net, const struct sk_buff *skb,
 		lookup_flags |= RT6_LOOKUP_F_IFACE;
 	}
 
-	rt = (void *) ip6_route_lookup(net, &fl6, lookup_flags);
+	rt = (void *)ip6_route_lookup(net, &fl6, skb, lookup_flags);
 	if (rt->dst.error)
 		goto out;
 
diff --git a/net/ipv6/netfilter/nft_fib_ipv6.c b/net/ipv6/netfilter/nft_fib_ipv6.c
index cc5174c..3230b3d 100644
--- a/net/ipv6/netfilter/nft_fib_ipv6.c
+++ b/net/ipv6/netfilter/nft_fib_ipv6.c
@@ -181,7 +181,8 @@ void nft_fib6_eval(const struct nft_expr *expr, struct nft_regs *regs,
 
 	*dest = 0;
  again:
-	rt = (void *)ip6_route_lookup(nft_net(pkt), &fl6, lookup_flags);
+	rt = (void *)ip6_route_lookup(nft_net(pkt), &fl6, pkt->skb,
+				      lookup_flags);
 	if (rt->dst.error)
 		goto put_rt_err;
 
diff --git a/net/ipv6/route.c b/net/ipv6/route.c
index 5c89af2..d2b8368 100644
--- a/net/ipv6/route.c
+++ b/net/ipv6/route.c
@@ -452,6 +452,7 @@ static bool rt6_check_expired(const struct rt6_info *rt)
 
 static struct rt6_info *rt6_multipath_select(struct rt6_info *match,
 					     struct flowi6 *fl6, int oif,
+					     const struct sk_buff *skb,
 					     int strict)
 {
 	struct rt6_info *sibling, *next_sibling;
@@ -460,7 +461,7 @@ static struct rt6_info *rt6_multipath_select(struct rt6_info *match,
 	 * case it will always be non-zero. Otherwise now is the time to do it.
 	 */
 	if (!fl6->mp_hash)
-		fl6->mp_hash = rt6_multipath_hash(fl6, NULL, NULL);
+		fl6->mp_hash = rt6_multipath_hash(fl6, skb, NULL);
 
 	if (fl6->mp_hash <= atomic_read(&match->rt6i_nh_upper_bound))
 		return match;
@@ -914,7 +915,9 @@ static bool ip6_hold_safe(struct net *net, struct rt6_info **prt,
 
 static struct rt6_info *ip6_pol_route_lookup(struct net *net,
 					     struct fib6_table *table,
-					     struct flowi6 *fl6, int flags)
+					     struct flowi6 *fl6,
+					     const struct sk_buff *skb,
+					     int flags)
 {
 	struct rt6_info *rt, *rt_cache;
 	struct fib6_node *fn;
@@ -929,8 +932,8 @@ static struct rt6_info *ip6_pol_route_lookup(struct net *net,
 		rt = rt6_device_match(net, rt, &fl6->saddr,
 				      fl6->flowi6_oif, flags);
 		if (rt->rt6i_nsiblings && fl6->flowi6_oif == 0)
-			rt = rt6_multipath_select(rt, fl6,
-						  fl6->flowi6_oif, flags);
+			rt = rt6_multipath_select(rt, fl6, fl6->flowi6_oif,
+						  skb, flags);
 	}
 	if (rt == net->ipv6.ip6_null_entry) {
 		fn = fib6_backtrack(fn, &fl6->saddr);
@@ -954,14 +957,15 @@ static struct rt6_info *ip6_pol_route_lookup(struct net *net,
 }
 
 struct dst_entry *ip6_route_lookup(struct net *net, struct flowi6 *fl6,
-				    int flags)
+				   const struct sk_buff *skb, int flags)
 {
-	return fib6_rule_lookup(net, fl6, flags, ip6_pol_route_lookup);
+	return fib6_rule_lookup(net, fl6, skb, flags, ip6_pol_route_lookup);
 }
 EXPORT_SYMBOL_GPL(ip6_route_lookup);
 
 struct rt6_info *rt6_lookup(struct net *net, const struct in6_addr *daddr,
-			    const struct in6_addr *saddr, int oif, int strict)
+			    const struct in6_addr *saddr, int oif,
+			    const struct sk_buff *skb, int strict)
 {
 	struct flowi6 fl6 = {
 		.flowi6_oif = oif,
@@ -975,7 +979,7 @@ struct rt6_info *rt6_lookup(struct net *net, const struct in6_addr *daddr,
 		flags |= RT6_LOOKUP_F_HAS_SADDR;
 	}
 
-	dst = fib6_rule_lookup(net, &fl6, flags, ip6_pol_route_lookup);
+	dst = fib6_rule_lookup(net, &fl6, skb, flags, ip6_pol_route_lookup);
 	if (dst->error == 0)
 		return (struct rt6_info *) dst;
 
@@ -1647,7 +1651,8 @@ void rt6_age_exceptions(struct rt6_info *rt,
 }
 
 struct rt6_info *ip6_pol_route(struct net *net, struct fib6_table *table,
-			       int oif, struct flowi6 *fl6, int flags)
+			       int oif, struct flowi6 *fl6,
+			       const struct sk_buff *skb, int flags)
 {
 	struct fib6_node *fn, *saved_fn;
 	struct rt6_info *rt, *rt_cache;
@@ -1669,7 +1674,7 @@ struct rt6_info *ip6_pol_route(struct net *net, struct fib6_table *table,
 redo_rt6_select:
 	rt = rt6_select(net, fn, oif, strict);
 	if (rt->rt6i_nsiblings)
-		rt = rt6_multipath_select(rt, fl6, oif, strict);
+		rt = rt6_multipath_select(rt, fl6, oif, skb, strict);
 	if (rt == net->ipv6.ip6_null_entry) {
 		fn = fib6_backtrack(fn, &fl6->saddr);
 		if (fn)
@@ -1768,20 +1773,25 @@ struct rt6_info *ip6_pol_route(struct net *net, struct fib6_table *table,
 }
 EXPORT_SYMBOL_GPL(ip6_pol_route);
 
-static struct rt6_info *ip6_pol_route_input(struct net *net, struct fib6_table *table,
-					    struct flowi6 *fl6, int flags)
+static struct rt6_info *ip6_pol_route_input(struct net *net,
+					    struct fib6_table *table,
+					    struct flowi6 *fl6,
+					    const struct sk_buff *skb,
+					    int flags)
 {
-	return ip6_pol_route(net, table, fl6->flowi6_iif, fl6, flags);
+	return ip6_pol_route(net, table, fl6->flowi6_iif, fl6, skb, flags);
 }
 
 struct dst_entry *ip6_route_input_lookup(struct net *net,
 					 struct net_device *dev,
-					 struct flowi6 *fl6, int flags)
+					 struct flowi6 *fl6,
+					 const struct sk_buff *skb,
+					 int flags)
 {
 	if (rt6_need_strict(&fl6->daddr) && dev->type != ARPHRD_PIMREG)
 		flags |= RT6_LOOKUP_F_IFACE;
 
-	return fib6_rule_lookup(net, fl6, flags, ip6_pol_route_input);
+	return fib6_rule_lookup(net, fl6, skb, flags, ip6_pol_route_input);
 }
 EXPORT_SYMBOL_GPL(ip6_route_input_lookup);
 
@@ -1876,13 +1886,17 @@ void ip6_route_input(struct sk_buff *skb)
 	if (unlikely(fl6.flowi6_proto == IPPROTO_ICMPV6))
 		fl6.mp_hash = rt6_multipath_hash(&fl6, skb, flkeys);
 	skb_dst_drop(skb);
-	skb_dst_set(skb, ip6_route_input_lookup(net, skb->dev, &fl6, flags));
+	skb_dst_set(skb,
+		    ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags));
 }
 
-static struct rt6_info *ip6_pol_route_output(struct net *net, struct fib6_table *table,
-					     struct flowi6 *fl6, int flags)
+static struct rt6_info *ip6_pol_route_output(struct net *net,
+					     struct fib6_table *table,
+					     struct flowi6 *fl6,
+					     const struct sk_buff *skb,
+					     int flags)
 {
-	return ip6_pol_route(net, table, fl6->flowi6_oif, fl6, flags);
+	return ip6_pol_route(net, table, fl6->flowi6_oif, fl6, skb, flags);
 }
 
 struct dst_entry *ip6_route_output_flags(struct net *net, const struct sock *sk,
@@ -1910,7 +1924,7 @@ struct dst_entry *ip6_route_output_flags(struct net *net, const struct sock *sk,
 	else if (sk)
 		flags |= rt6_srcprefs2flags(inet6_sk(sk)->srcprefs);
 
-	return fib6_rule_lookup(net, fl6, flags, ip6_pol_route_output);
+	return fib6_rule_lookup(net, fl6, NULL, flags, ip6_pol_route_output);
 }
 EXPORT_SYMBOL_GPL(ip6_route_output_flags);
 
@@ -2159,6 +2173,7 @@ struct ip6rd_flowi {
 static struct rt6_info *__ip6_route_redirect(struct net *net,
 					     struct fib6_table *table,
 					     struct flowi6 *fl6,
+					     const struct sk_buff *skb,
 					     int flags)
 {
 	struct ip6rd_flowi *rdfl = (struct ip6rd_flowi *)fl6;
@@ -2232,8 +2247,9 @@ static struct rt6_info *__ip6_route_redirect(struct net *net,
 };
 
 static struct dst_entry *ip6_route_redirect(struct net *net,
-					const struct flowi6 *fl6,
-					const struct in6_addr *gateway)
+					    const struct flowi6 *fl6,
+					    const struct sk_buff *skb,
+					    const struct in6_addr *gateway)
 {
 	int flags = RT6_LOOKUP_F_HAS_SADDR;
 	struct ip6rd_flowi rdfl;
@@ -2241,7 +2257,7 @@ static struct dst_entry *ip6_route_redirect(struct net *net,
 	rdfl.fl6 = *fl6;
 	rdfl.gateway = *gateway;
 
-	return fib6_rule_lookup(net, &rdfl.fl6,
+	return fib6_rule_lookup(net, &rdfl.fl6, skb,
 				flags, __ip6_route_redirect);
 }
 
@@ -2261,7 +2277,7 @@ void ip6_redirect(struct sk_buff *skb, struct net *net, int oif, u32 mark,
 	fl6.flowlabel = ip6_flowinfo(iph);
 	fl6.flowi6_uid = uid;
 
-	dst = ip6_route_redirect(net, &fl6, &ipv6_hdr(skb)->saddr);
+	dst = ip6_route_redirect(net, &fl6, skb, &ipv6_hdr(skb)->saddr);
 	rt6_do_redirect(dst, NULL, skb);
 	dst_release(dst);
 }
@@ -2283,7 +2299,7 @@ void ip6_redirect_no_header(struct sk_buff *skb, struct net *net, int oif,
 	fl6.saddr = iph->daddr;
 	fl6.flowi6_uid = sock_net_uid(net, NULL);
 
-	dst = ip6_route_redirect(net, &fl6, &iph->saddr);
+	dst = ip6_route_redirect(net, &fl6, skb, &iph->saddr);
 	rt6_do_redirect(dst, NULL, skb);
 	dst_release(dst);
 }
@@ -2485,7 +2501,7 @@ static struct rt6_info *ip6_nh_lookup_table(struct net *net,
 		flags |= RT6_LOOKUP_F_HAS_SADDR;
 
 	flags |= RT6_LOOKUP_F_IGNORE_LINKSTATE;
-	rt = ip6_pol_route(net, table, cfg->fc_ifindex, &fl6, flags);
+	rt = ip6_pol_route(net, table, cfg->fc_ifindex, &fl6, NULL, flags);
 
 	/* if table lookup failed, fall back to full lookup */
 	if (rt == net->ipv6.ip6_null_entry) {
@@ -2548,7 +2564,7 @@ static int ip6_route_check_nh(struct net *net,
 	}
 
 	if (!grt)
-		grt = rt6_lookup(net, gw_addr, NULL, cfg->fc_ifindex, 1);
+		grt = rt6_lookup(net, gw_addr, NULL, cfg->fc_ifindex, NULL, 1);
 
 	if (!grt)
 		goto out;
@@ -4613,7 +4629,7 @@ static int inet6_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh,
 		if (!ipv6_addr_any(&fl6.saddr))
 			flags |= RT6_LOOKUP_F_HAS_SADDR;
 
-		dst = ip6_route_input_lookup(net, dev, &fl6, flags);
+		dst = ip6_route_input_lookup(net, dev, &fl6, NULL, flags);
 
 		rcu_read_unlock();
 	} else {
diff --git a/net/ipv6/seg6_local.c b/net/ipv6/seg6_local.c
index ba3767e..4572232 100644
--- a/net/ipv6/seg6_local.c
+++ b/net/ipv6/seg6_local.c
@@ -161,7 +161,7 @@ static void lookup_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
 		fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH;
 
 	if (!tbl_id) {
-		dst = ip6_route_input_lookup(net, skb->dev, &fl6, flags);
+		dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags);
 	} else {
 		struct fib6_table *table;
 
@@ -169,7 +169,7 @@ static void lookup_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
 		if (!table)
 			goto out;
 
-		rt = ip6_pol_route(net, table, 0, &fl6, flags);
+		rt = ip6_pol_route(net, table, 0, &fl6, skb, flags);
 		dst = &rt->dst;
 	}