netfilter: synproxy: only register hooks when needed

Defer registration of the synproxy hooks until the first SYNPROXY rule is
added.  Also means we only register hooks in namespaces that need it.

Signed-off-by: Florian Westphal <fw@strlen.de>
Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
diff --git a/include/net/netfilter/nf_conntrack_synproxy.h b/include/net/netfilter/nf_conntrack_synproxy.h
index b0ca402..a2fcb52 100644
--- a/include/net/netfilter/nf_conntrack_synproxy.h
+++ b/include/net/netfilter/nf_conntrack_synproxy.h
@@ -52,6 +52,8 @@ struct synproxy_stats {
 struct synproxy_net {
 	struct nf_conn			*tmpl;
 	struct synproxy_stats __percpu	*stats;
+	unsigned int			hook_ref4;
+	unsigned int			hook_ref6;
 };
 
 extern unsigned int synproxy_net_id;
diff --git a/net/ipv4/netfilter/ipt_SYNPROXY.c b/net/ipv4/netfilter/ipt_SYNPROXY.c
index 3240a26..c308ee0 100644
--- a/net/ipv4/netfilter/ipt_SYNPROXY.c
+++ b/net/ipv4/netfilter/ipt_SYNPROXY.c
@@ -409,33 +409,6 @@ static unsigned int ipv4_synproxy_hook(void *priv,
 	return NF_ACCEPT;
 }
 
-static int synproxy_tg4_check(const struct xt_tgchk_param *par)
-{
-	const struct ipt_entry *e = par->entryinfo;
-
-	if (e->ip.proto != IPPROTO_TCP ||
-	    e->ip.invflags & XT_INV_PROTO)
-		return -EINVAL;
-
-	return nf_ct_netns_get(par->net, par->family);
-}
-
-static void synproxy_tg4_destroy(const struct xt_tgdtor_param *par)
-{
-	nf_ct_netns_put(par->net, par->family);
-}
-
-static struct xt_target synproxy_tg4_reg __read_mostly = {
-	.name		= "SYNPROXY",
-	.family		= NFPROTO_IPV4,
-	.hooks		= (1 << NF_INET_LOCAL_IN) | (1 << NF_INET_FORWARD),
-	.target		= synproxy_tg4,
-	.targetsize	= sizeof(struct xt_synproxy_info),
-	.checkentry	= synproxy_tg4_check,
-	.destroy	= synproxy_tg4_destroy,
-	.me		= THIS_MODULE,
-};
-
 static struct nf_hook_ops ipv4_synproxy_ops[] __read_mostly = {
 	{
 		.hook		= ipv4_synproxy_hook,
@@ -451,31 +424,63 @@ static struct nf_hook_ops ipv4_synproxy_ops[] __read_mostly = {
 	},
 };
 
-static int __init synproxy_tg4_init(void)
+static int synproxy_tg4_check(const struct xt_tgchk_param *par)
 {
+	struct synproxy_net *snet = synproxy_pernet(par->net);
+	const struct ipt_entry *e = par->entryinfo;
 	int err;
 
-	err = nf_register_hooks(ipv4_synproxy_ops,
-				ARRAY_SIZE(ipv4_synproxy_ops));
-	if (err < 0)
-		goto err1;
+	if (e->ip.proto != IPPROTO_TCP ||
+	    e->ip.invflags & XT_INV_PROTO)
+		return -EINVAL;
 
-	err = xt_register_target(&synproxy_tg4_reg);
-	if (err < 0)
-		goto err2;
+	err = nf_ct_netns_get(par->net, par->family);
+	if (err)
+		return err;
 
-	return 0;
+	if (snet->hook_ref4 == 0) {
+		err = nf_register_net_hooks(par->net, ipv4_synproxy_ops,
+					    ARRAY_SIZE(ipv4_synproxy_ops));
+		if (err) {
+			nf_ct_netns_put(par->net, par->family);
+			return err;
+		}
+	}
 
-err2:
-	nf_unregister_hooks(ipv4_synproxy_ops, ARRAY_SIZE(ipv4_synproxy_ops));
-err1:
+	snet->hook_ref4++;
 	return err;
 }
 
+static void synproxy_tg4_destroy(const struct xt_tgdtor_param *par)
+{
+	struct synproxy_net *snet = synproxy_pernet(par->net);
+
+	snet->hook_ref4--;
+	if (snet->hook_ref4 == 0)
+		nf_unregister_net_hooks(par->net, ipv4_synproxy_ops,
+					ARRAY_SIZE(ipv4_synproxy_ops));
+	nf_ct_netns_put(par->net, par->family);
+}
+
+static struct xt_target synproxy_tg4_reg __read_mostly = {
+	.name		= "SYNPROXY",
+	.family		= NFPROTO_IPV4,
+	.hooks		= (1 << NF_INET_LOCAL_IN) | (1 << NF_INET_FORWARD),
+	.target		= synproxy_tg4,
+	.targetsize	= sizeof(struct xt_synproxy_info),
+	.checkentry	= synproxy_tg4_check,
+	.destroy	= synproxy_tg4_destroy,
+	.me		= THIS_MODULE,
+};
+
+static int __init synproxy_tg4_init(void)
+{
+	return xt_register_target(&synproxy_tg4_reg);
+}
+
 static void __exit synproxy_tg4_exit(void)
 {
 	xt_unregister_target(&synproxy_tg4_reg);
-	nf_unregister_hooks(ipv4_synproxy_ops, ARRAY_SIZE(ipv4_synproxy_ops));
 }
 
 module_init(synproxy_tg4_init);
diff --git a/net/ipv6/netfilter/ip6t_SYNPROXY.c b/net/ipv6/netfilter/ip6t_SYNPROXY.c
index 4ef1ddd..1252537 100644
--- a/net/ipv6/netfilter/ip6t_SYNPROXY.c
+++ b/net/ipv6/netfilter/ip6t_SYNPROXY.c
@@ -430,34 +430,6 @@ static unsigned int ipv6_synproxy_hook(void *priv,
 	return NF_ACCEPT;
 }
 
-static int synproxy_tg6_check(const struct xt_tgchk_param *par)
-{
-	const struct ip6t_entry *e = par->entryinfo;
-
-	if (!(e->ipv6.flags & IP6T_F_PROTO) ||
-	    e->ipv6.proto != IPPROTO_TCP ||
-	    e->ipv6.invflags & XT_INV_PROTO)
-		return -EINVAL;
-
-	return nf_ct_netns_get(par->net, par->family);
-}
-
-static void synproxy_tg6_destroy(const struct xt_tgdtor_param *par)
-{
-	nf_ct_netns_put(par->net, par->family);
-}
-
-static struct xt_target synproxy_tg6_reg __read_mostly = {
-	.name		= "SYNPROXY",
-	.family		= NFPROTO_IPV6,
-	.hooks		= (1 << NF_INET_LOCAL_IN) | (1 << NF_INET_FORWARD),
-	.target		= synproxy_tg6,
-	.targetsize	= sizeof(struct xt_synproxy_info),
-	.checkentry	= synproxy_tg6_check,
-	.destroy	= synproxy_tg6_destroy,
-	.me		= THIS_MODULE,
-};
-
 static struct nf_hook_ops ipv6_synproxy_ops[] __read_mostly = {
 	{
 		.hook		= ipv6_synproxy_hook,
@@ -473,31 +445,64 @@ static struct nf_hook_ops ipv6_synproxy_ops[] __read_mostly = {
 	},
 };
 
-static int __init synproxy_tg6_init(void)
+static int synproxy_tg6_check(const struct xt_tgchk_param *par)
 {
+	struct synproxy_net *snet = synproxy_pernet(par->net);
+	const struct ip6t_entry *e = par->entryinfo;
 	int err;
 
-	err = nf_register_hooks(ipv6_synproxy_ops,
-				ARRAY_SIZE(ipv6_synproxy_ops));
-	if (err < 0)
-		goto err1;
+	if (!(e->ipv6.flags & IP6T_F_PROTO) ||
+	    e->ipv6.proto != IPPROTO_TCP ||
+	    e->ipv6.invflags & XT_INV_PROTO)
+		return -EINVAL;
 
-	err = xt_register_target(&synproxy_tg6_reg);
-	if (err < 0)
-		goto err2;
+	err = nf_ct_netns_get(par->net, par->family);
+	if (err)
+		return err;
 
-	return 0;
+	if (snet->hook_ref6 == 0) {
+		err = nf_register_net_hooks(par->net, ipv6_synproxy_ops,
+					    ARRAY_SIZE(ipv6_synproxy_ops));
+		if (err) {
+			nf_ct_netns_put(par->net, par->family);
+			return err;
+		}
+	}
 
-err2:
-	nf_unregister_hooks(ipv6_synproxy_ops, ARRAY_SIZE(ipv6_synproxy_ops));
-err1:
+	snet->hook_ref6++;
 	return err;
 }
 
+static void synproxy_tg6_destroy(const struct xt_tgdtor_param *par)
+{
+	struct synproxy_net *snet = synproxy_pernet(par->net);
+
+	snet->hook_ref6--;
+	if (snet->hook_ref6 == 0)
+		nf_unregister_net_hooks(par->net, ipv6_synproxy_ops,
+					ARRAY_SIZE(ipv6_synproxy_ops));
+	nf_ct_netns_put(par->net, par->family);
+}
+
+static struct xt_target synproxy_tg6_reg __read_mostly = {
+	.name		= "SYNPROXY",
+	.family		= NFPROTO_IPV6,
+	.hooks		= (1 << NF_INET_LOCAL_IN) | (1 << NF_INET_FORWARD),
+	.target		= synproxy_tg6,
+	.targetsize	= sizeof(struct xt_synproxy_info),
+	.checkentry	= synproxy_tg6_check,
+	.destroy	= synproxy_tg6_destroy,
+	.me		= THIS_MODULE,
+};
+
+static int __init synproxy_tg6_init(void)
+{
+	return xt_register_target(&synproxy_tg6_reg);
+}
+
 static void __exit synproxy_tg6_exit(void)
 {
 	xt_unregister_target(&synproxy_tg6_reg);
-	nf_unregister_hooks(ipv6_synproxy_ops, ARRAY_SIZE(ipv6_synproxy_ops));
 }
 
 module_init(synproxy_tg6_init);