net: sched: introduce per-egress action device callbacks

Introduce infrastructure that allows drivers to register callbacks that
are called whenever tc would offload an inserted rule and the specified
device acts as the tc action egress device.

Signed-off-by: Jiri Pirko <jiri@mellanox.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/sched/act_api.c b/net/sched/act_api.c
index da6fa82..ac97db9 100644
--- a/net/sched/act_api.c
+++ b/net/sched/act_api.c
@@ -21,6 +21,8 @@
 #include <linux/kmod.h>
 #include <linux/err.h>
 #include <linux/module.h>
+#include <linux/rhashtable.h>
+#include <linux/list.h>
 #include <net/net_namespace.h>
 #include <net/sock.h>
 #include <net/sch_generic.h>
@@ -1249,8 +1251,226 @@ static int tc_dump_action(struct sk_buff *skb, struct netlink_callback *cb)
 	return skb->len;
 }
 
+struct tcf_action_net {	/* per-netns state, fetched with net_generic() */
+	struct rhashtable egdev_ht;	/* struct tcf_action_egdev entries, keyed by ->dev */
+};
+
+static unsigned int tcf_action_net_id;	/* net_generic() id, wired up via tcf_action_net_ops.id */
+
+struct tcf_action_egdev_cb {	/* one registered (cb, cb_priv) pair */
+	struct list_head list;	/* linked on tcf_action_egdev.cb_list */
+	tc_setup_cb_t *cb;	/* invoked as cb(type, type_data, cb_priv) */
+	void *cb_priv;	/* handed back to cb unchanged */
+};
+
+struct tcf_action_egdev {	/* per-device entry stored in tcf_action_net.egdev_ht */
+	struct rhash_head ht_node;	/* hash table linkage (see head_offset) */
+	const struct net_device *dev;	/* hash key: the device pointer itself */
+	unsigned int refcnt;	/* plain counter: ++ in tcf_action_egdev_get(), -- in _put() */
+	struct list_head cb_list;	/* list of struct tcf_action_egdev_cb */
+};
+
+static const struct rhashtable_params tcf_action_egdev_ht_params = {
+	.key_offset = offsetof(struct tcf_action_egdev, dev),	/* key lives in ->dev */
+	.head_offset = offsetof(struct tcf_action_egdev, ht_node),
+	.key_len = sizeof(const struct net_device *),	/* the pointer value is the key */
+};
+
+static struct tcf_action_egdev *
+tcf_action_egdev_lookup(const struct net_device *dev)
+{	/* Find the entry for @dev in its netns table; refcnt is not touched. */
+	struct net *net = dev_net(dev);
+	struct tcf_action_net *tan = net_generic(net, tcf_action_net_id);
+
+	return rhashtable_lookup_fast(&tan->egdev_ht, &dev,	/* key = pointer value, hence &dev */
+				      tcf_action_egdev_ht_params);
+}
+
+static struct tcf_action_egdev *
+tcf_action_egdev_get(const struct net_device *dev)
+{	/* Look up, or allocate and hash, the entry for @dev; takes one reference. */
+	struct tcf_action_egdev *egdev = tcf_action_egdev_lookup(dev);
+	struct tcf_action_net *tan;
+
+	if (egdev)
+		goto inc_ref;
+
+	egdev = kzalloc(sizeof(*egdev), GFP_KERNEL);
+	if (!egdev)
+		return NULL;	/* allocation failure is the only NULL return */
+	INIT_LIST_HEAD(&egdev->cb_list);
+	egdev->dev = dev;	/* hash key, and read back by tcf_action_egdev_put(): set before insert */
+	tan = net_generic(dev_net(dev), tcf_action_net_id);
+	rhashtable_insert_fast(&tan->egdev_ht, &egdev->ht_node,
+			       tcf_action_egdev_ht_params);	/* NOTE(review): insert result is not checked */
+
+inc_ref:
+	egdev->refcnt++;
+	return egdev;
+}
+
+static void tcf_action_egdev_put(struct tcf_action_egdev *egdev)
+{	/* Drop one reference; the final put unhashes and frees @egdev. */
+	struct tcf_action_net *tan;
+
+	if (--egdev->refcnt)	/* other registrations still hold it */
+		return;
+	tan = net_generic(dev_net(egdev->dev), tcf_action_net_id);	/* dereferences ->dev */
+	rhashtable_remove_fast(&tan->egdev_ht, &egdev->ht_node,
+			       tcf_action_egdev_ht_params);
+	kfree(egdev);
+}
+
+static struct tcf_action_egdev_cb *
+tcf_action_egdev_cb_lookup(struct tcf_action_egdev *egdev,
+			   tc_setup_cb_t *cb, void *cb_priv)
+{	/* Linear search of @egdev's cb_list for an exact (cb, cb_priv) pair. */
+	struct tcf_action_egdev_cb *egdev_cb;
+
+	list_for_each_entry(egdev_cb, &egdev->cb_list, list)
+		if (egdev_cb->cb == cb && egdev_cb->cb_priv == cb_priv)	/* both must match */
+			return egdev_cb;
+	return NULL;	/* no such pair registered */
+}
+
+static int tcf_action_egdev_cb_call(struct tcf_action_egdev *egdev,
+				    enum tc_setup_type type,
+				    void *type_data, bool err_stop)
+{	/* Invoke every callback registered on @egdev with (type, type_data). */
+	struct tcf_action_egdev_cb *egdev_cb;
+	int ok_count = 0;
+	int err;
+
+	list_for_each_entry(egdev_cb, &egdev->cb_list, list) {
+		err = egdev_cb->cb(type, type_data, egdev_cb->cb_priv);
+		if (err) {
+			if (err_stop)	/* abort on the first failure and propagate it */
+				return err;
+		} else {	/* with !err_stop a failing callback is simply not counted */
+			ok_count++;
+		}
+	}
+	return ok_count;	/* number of callbacks that returned 0 */
+}
+
+static int tcf_action_egdev_cb_add(struct tcf_action_egdev *egdev,
+				   tc_setup_cb_t *cb, void *cb_priv)
+{	/* Attach a new (cb, cb_priv) pair to @egdev; returns 0 or a negative errno. */
+	struct tcf_action_egdev_cb *egdev_cb;
+
+	egdev_cb = tcf_action_egdev_cb_lookup(egdev, cb, cb_priv);
+	if (WARN_ON(egdev_cb))	/* registering the same pair twice is treated as a bug */
+		return -EEXIST;
+	egdev_cb = kzalloc(sizeof(*egdev_cb), GFP_KERNEL);
+	if (!egdev_cb)
+		return -ENOMEM;
+	egdev_cb->cb = cb;
+	egdev_cb->cb_priv = cb_priv;
+	list_add(&egdev_cb->list, &egdev->cb_list);
+	return 0;
+}
+
+static void tcf_action_egdev_cb_del(struct tcf_action_egdev *egdev,
+				    tc_setup_cb_t *cb, void *cb_priv)
+{	/* Unlink and free the (cb, cb_priv) pair previously added to @egdev. */
+	struct tcf_action_egdev_cb *egdev_cb;
+
+	egdev_cb = tcf_action_egdev_cb_lookup(egdev, cb, cb_priv);
+	if (WARN_ON(!egdev_cb))	/* pair was never added: warn and do nothing */
+		return;
+	list_del(&egdev_cb->list);
+	kfree(egdev_cb);
+}
+
+static int __tc_setup_cb_egdev_register(const struct net_device *dev,
+					tc_setup_cb_t *cb, void *cb_priv)
+{	/* Take a reference on @dev's egdev entry and attach the callback to it. */
+	struct tcf_action_egdev *egdev = tcf_action_egdev_get(dev);
+	int err;
+
+	if (!egdev)	/* the get helper returns NULL when its allocation fails */
+		return -ENOMEM;
+	err = tcf_action_egdev_cb_add(egdev, cb, cb_priv);
+	if (err)
+		goto err_cb_add;
+	return 0;	/* the reference stays held until unregister */
+
+err_cb_add:
+	tcf_action_egdev_put(egdev);	/* undo the reference taken above */
+	return err;
+}
+int tc_setup_cb_egdev_register(const struct net_device *dev,
+			       tc_setup_cb_t *cb, void *cb_priv)
+{	/* Exported wrapper: performs the registration with rtnl_lock() held. */
+	int err;
+
+	rtnl_lock();
+	err = __tc_setup_cb_egdev_register(dev, cb, cb_priv);
+	rtnl_unlock();
+	return err;	/* 0 on success, otherwise the helper's negative errno */
+}
+EXPORT_SYMBOL_GPL(tc_setup_cb_egdev_register);
+
+static void __tc_setup_cb_egdev_unregister(const struct net_device *dev,
+					   tc_setup_cb_t *cb, void *cb_priv)
+{	/* Detach (cb, cb_priv) from @dev's entry and drop one reference on it. */
+	struct tcf_action_egdev *egdev = tcf_action_egdev_lookup(dev);
+
+	if (WARN_ON(!egdev))	/* nothing is registered for @dev */
+		return;
+	tcf_action_egdev_cb_del(egdev, cb, cb_priv);	/* NOTE(review): put below runs even if this found nothing */
+	tcf_action_egdev_put(egdev);	/* may free egdev */
+}
+void tc_setup_cb_egdev_unregister(const struct net_device *dev,
+				  tc_setup_cb_t *cb, void *cb_priv)
+{	/* Exported wrapper: performs the unregistration with rtnl_lock() held. */
+	rtnl_lock();
+	__tc_setup_cb_egdev_unregister(dev, cb, cb_priv);
+	rtnl_unlock();
+}
+EXPORT_SYMBOL_GPL(tc_setup_cb_egdev_unregister);
+
+int tc_setup_cb_egdev_call(const struct net_device *dev,
+			   enum tc_setup_type type, void *type_data,
+			   bool err_stop)
+{	/* Run the callbacks registered for @dev; takes no lock itself (caller presumably holds RTNL - TODO confirm). */
+	struct tcf_action_egdev *egdev = tcf_action_egdev_lookup(dev);
+
+	if (!egdev)	/* no callbacks registered for this device */
+		return 0;
+	return tcf_action_egdev_cb_call(egdev, type, type_data, err_stop);	/* ok count, or cb error if err_stop */
+}
+EXPORT_SYMBOL_GPL(tc_setup_cb_egdev_call);
+
+static __net_init int tcf_action_net_init(struct net *net)
+{	/* pernet .init hook: set up this namespace's egdev hash table. */
+	struct tcf_action_net *tan = net_generic(net, tcf_action_net_id);
+
+	return rhashtable_init(&tan->egdev_ht, &tcf_action_egdev_ht_params);	/* result becomes the init status */
+}
+
+static void __net_exit tcf_action_net_exit(struct net *net)
+{	/* pernet .exit hook: tear down this namespace's egdev hash table. */
+	struct tcf_action_net *tan = net_generic(net, tcf_action_net_id);
+
+	rhashtable_destroy(&tan->egdev_ht);	/* NOTE(review): presumably no entries remain by now - confirm */
+}
+
+static struct pernet_operations tcf_action_net_ops = {
+	.init = tcf_action_net_init,
+	.exit = tcf_action_net_exit,
+	.id = &tcf_action_net_id,	/* id later passed to net_generic() */
+	.size = sizeof(struct tcf_action_net),	/* size of the per-netns private area */
+};
+
 static int __init tc_action_init(void)
 {
+	int err;
+
+	err = register_pernet_subsys(&tcf_action_net_ops);
+	if (err)
+		return err;
+
 	rtnl_register(PF_UNSPEC, RTM_NEWACTION, tc_ctl_action, NULL, 0);
 	rtnl_register(PF_UNSPEC, RTM_DELACTION, tc_ctl_action, NULL, 0);
 	rtnl_register(PF_UNSPEC, RTM_GETACTION, tc_ctl_action, tc_dump_action,
diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c
index 450873b..99f9432 100644
--- a/net/sched/cls_api.c
+++ b/net/sched/cls_api.c
@@ -1026,6 +1026,36 @@ int tcf_exts_get_dev(struct net_device *dev, struct tcf_exts *exts,
 }
 EXPORT_SYMBOL(tcf_exts_get_dev);
 
+int tcf_exts_egdev_cb_call(struct tcf_exts *exts, enum tc_setup_type type,
+			   void *type_data, bool err_stop)
+{	/* For each action in @exts that reports a device, run that device's egdev callbacks. */
+	int ok_count = 0;	/* remains 0 when CONFIG_NET_CLS_ACT is not set */
+#ifdef CONFIG_NET_CLS_ACT
+	const struct tc_action *a;
+	struct net_device *dev;
+	LIST_HEAD(actions);
+	int ret;
+
+	if (!tcf_exts_has_actions(exts))
+		return 0;
+
+	tcf_exts_to_list(exts, &actions);
+	list_for_each_entry(a, &actions, list) {
+		if (!a->ops->get_dev)	/* this action type exposes no device */
+			continue;
+		dev = a->ops->get_dev(a);
+		if (!dev || !tc_can_offload(dev))	/* no device, or tc_can_offload() rejects it */
+			continue;
+		ret = tc_setup_cb_egdev_call(dev, type, type_data, err_stop);
+		if (ret < 0)	/* a callback failed with err_stop set: stop and propagate */
+			return ret;
+		ok_count += ret;
+	}
+#endif
+	return ok_count;	/* successful callbacks summed over all actions */
+}
+EXPORT_SYMBOL(tcf_exts_egdev_cb_call);
+
 static int __init tc_filter_init(void)
 {
 	rtnl_register(PF_UNSPEC, RTM_NEWTFILTER, tc_ctl_tfilter, NULL, 0);