Merge branch 'bpf-sockmap-listen'

Jakub Sitnicki says:

====================
This patch set turns SOCK{MAP,HASH} into generic collections for TCP
sockets, both listening and established. Adding support for listening
sockets enables us to use these BPF map types with reuseport BPF programs.

Why? SOCKMAP and SOCKHASH, in comparison to REUSEPORT_SOCKARRAY, allow
the socket to be in more than one map at the same time.

Having a BPF map type that can hold listening sockets and gracefully
co-exist with reuseport BPF is important if, in the future, we want
BPF programs that run at socket lookup time [0]. The cover letter for v1
of this series tells the full story of how we got here [1].

Although SOCK{MAP,HASH} are not a drop-in replacement for SOCKARRAY just
yet, because UDP support is lacking, it's a step in this direction. We're
working with Lorenz on extending SOCK{MAP,HASH} to hold UDP sockets, and
expect to post an RFC series for sockmap + UDP in the near future.

I've dropped Acks from all patches that have been touched since v6.

The audit for missing READ_ONCE annotations for access to sk_prot is
ongoing. Thus far I've found one location specific to TCP listening sockets
that needed annotating. It is fixed in this iteration. I wonder if the
sparse checker could be put to work to identify places where we have
sk_prot access while not holding sk_lock...
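
To illustrate why the lockless reader wants a single annotated load, here is
a minimal userspace sketch. It is an analogue only: it uses C11 relaxed
atomics in place of the kernel's READ_ONCE/WRITE_ONCE, and the names
proto_model, sock_snap, clone_snap and the two functions are hypothetical
stand-ins, not kernel code. The point it shows is that the reader loads the
pointer once and derives both the allocation size and the recorded creator
from that one snapshot.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stddef.h>

/* Hypothetical stand-in for struct proto. */
struct proto_model {
	size_t obj_size;
};

/* Hypothetical stand-in for struct sock; only sk_prot is modelled. */
struct sock_snap {
	_Atomic(const struct proto_model *) sk_prot;
};

/* What a clone derives from the parent's proto. */
struct clone_snap {
	size_t alloc_size;
	const struct proto_model *prot_creator;
};

/* Writer side: userspace analogue of WRITE_ONCE(sk->sk_prot, ops). */
static void sock_snap_set_prot(struct sock_snap *sk,
			       const struct proto_model *ops)
{
	atomic_store_explicit(&sk->sk_prot, ops, memory_order_relaxed);
}

/* Reader side: load the pointer once and use only the snapshot, so that
 * the size and the creator always come from the same proto, even if a
 * writer swaps sk_prot concurrently.
 */
static struct clone_snap sock_snap_clone(struct sock_snap *sk)
{
	const struct proto_model *prot =
		atomic_load_explicit(&sk->sk_prot, memory_order_relaxed);
	struct clone_snap c = {
		.alloc_size = prot->obj_size,
		.prot_creator = prot,
	};

	assert(c.alloc_size == c.prot_creator->obj_size);
	return c;
}
```

Reading sk->sk_prot twice instead could, in principle, pair the size of one
proto with the identity of another if a writer ran in between the two loads.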

The patch series depends on another one, posted earlier [2], that has
been split out of it.

v6 -> v7:

- Extended the series to cover SOCKHASH. (patches 4-8, 10-11) (John)

- Rebased onto recent bpf-next. Resolved conflicts in recent fixes to
  sk_state checks on sockmap/sockhash update path. (patch 4)

- Added missing READ_ONCE annotation in sock_copy. (patch 1)

- Split out patches that simplify sk_psock_restore_proto [2].

v5 -> v6:

- Added a fix-up for patch 1 which I forgot to commit in v5. Sigh.

v4 -> v5:

- Rebase onto recent bpf-next to resolve conflicts. (Daniel)

v3 -> v4:

- Make tcp_bpf_clone parameter names consistent across function declaration
  and definition. (Martin)

- Use sock_map_redirect_okay helper everywhere we need to take a different
  action for listening sockets. (Lorenz)

- Expand comment explaining the need for a callback from reuseport to
  sockarray code in reuseport_detach_sock. (Martin)

- Mention the possibility of using a u64 counter for reuseport IDs in the
  future in the description for patch 10. (Martin)

v2 -> v3:

- Generate reuseport ID when group is created. Please see patch 10
  description for details. (Martin)

- Fix the build when CONFIG_NET_SOCK_MSG is not selected by either
  CONFIG_BPF_STREAM_PARSER or CONFIG_TLS. (kbuild bot & John)

- Allow updating sockmap from BPF on BPF_SOCK_OPS_TCP_LISTEN_CB callback. An
  oversight in previous iterations. Users may want to populate the sockmap with
  listening sockets from BPF as well.

- Removed RCU read lock assertion in sock_map_lookup_sys. (Martin)

- Get rid of a warning when child socket was cloned with parent's psock
  state. (John)

- Check for tcp_bpf_unhash rather than tcp_bpf_recvmsg when deciding if
  sk_proto needs restoring on clone. Checking for recvmsg in the context
  of listening socket cloning was confusing. (Martin)

- Consolidate sock_map_sk_is_suitable with sock_map_update_okay. This led
  to adding dedicated predicates for sockhash. Update self-tests
  accordingly. (John)

- Annotate unlikely branch in bpf_{sk,msg}_redirect_map when socket isn't
  in a map, or isn't a valid redirect target. (John)

- Document paired READ/WRITE_ONCE annotations and cover shared access in
  more detail in patch 2 description. (John)

- Correct a couple of log messages in sockmap_listen self-tests so the
  message reflects the actual failure.

- Rework reuseport tests from sockmap_listen suite so that ENOENT error
  from bpf_sk_select_reuseport handler does not happen on happy path.

v1 -> v2:

- af_ops->syn_recv_sock callback is no longer overridden and burdened with
  restoring sk_prot and clearing sk_user_data in the child socket. As child
  socket is already hashed when syn_recv_sock returns, it is too late to
  put it in the right state. Instead patches 3 & 4 address restoring
  sk_prot and clearing sk_user_data before we hash the child socket.
  (Pointed out by Martin Lau)

- Annotate shared access to sk->sk_prot with READ_ONCE/WRITE_ONCE macros as
  we write to it from sk_msg while socket might be getting cloned on
  another CPU. (Suggested by John Fastabend)

- Convert tests for SOCKMAP holding listening sockets to return-on-error
  style, and hook them up to test_progs. Also use BPF skeleton for setup.
  Add new tests to cover the race scenario discovered during v1 review.
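
The "clear sk_user_data in the child" step relies on tagging the pointer's
bottom bit. A minimal userspace sketch of that idea follows. All names here
(tagged_sock, user_data_assign and friends) are hypothetical, and the model
drops RCU entirely; it only shows the bit manipulation and the
clear-on-clone decision.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#define USER_DATA_NOCOPY	((uintptr_t)1)
#define USER_DATA_PTRMASK	(~USER_DATA_NOCOPY)

/* Hypothetical stand-in for struct sock; only user data is modelled. */
struct tagged_sock {
	uintptr_t user_data;	/* pointer bits plus the tag in bit 0 */
};

static void user_data_assign(struct tagged_sock *sk, void *ptr, bool nocopy)
{
	uintptr_t raw = (uintptr_t)ptr;

	/* Bit 0 must be free, which holds for any object aligned to at
	 * least 2 bytes.
	 */
	assert((raw & USER_DATA_NOCOPY) == 0);
	sk->user_data = nocopy ? (raw | USER_DATA_NOCOPY) : raw;
}

/* Readers always mask the tag off before using the pointer. */
static void *user_data_get(const struct tagged_sock *sk)
{
	return (void *)(sk->user_data & USER_DATA_PTRMASK);
}

static bool user_data_is_nocopy(const struct tagged_sock *sk)
{
	return sk->user_data & USER_DATA_NOCOPY;
}

/* Clone: copy everything, then drop the pointer if the parent tagged it
 * as not suitable for copying.
 */
static struct tagged_sock tagged_sock_clone(const struct tagged_sock *parent)
{
	struct tagged_sock child = *parent;

	if (user_data_is_nocopy(&child))
		child.user_data = 0;
	return child;
}
```

Untagged pointers are still inherited by the child, so existing users of
sk_user_data keep their behaviour; only owners that opt in get the pointer
cleared.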

RFC -> v1:

- Switch from overriding proto->accept to af_ops->syn_recv_sock, which
  happens earlier. Clearing the psock state after accept() does not work
  for child sockets that become orphaned (never got accepted). v4-mapped
  sockets need special care.

- Return the socket cookie on SOCKMAP lookup from syscall to be on par with
  REUSEPORT_SOCKARRAY. Requires SOCKMAP to take u64 on lookup/update from
  syscall.

- Make bpf_sk_redirect_map (ingress) and bpf_msg_redirect_map (egress)
  SOCKMAP helpers fail when target socket is a listening one.

- Make bpf_sk_select_reuseport helper fail when target is a TCP established
  socket.

- Teach libbpf to recognize SK_REUSEPORT program type from section name.

- Add a dedicated set of tests for SOCKMAP holding listening sockets,
  covering map operations, overridden socket callbacks, and BPF helpers.
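
The two rules above (listening sockets may live in the map, but are never a
valid redirect target) boil down to a pair of state predicates. Below is a
standalone sketch of them. The function names are hypothetical, and the TCP
state numbers are copied by hand from the kernel's include/net/tcp_states.h
so that the block compiles in userspace.

```c
#include <assert.h>
#include <stdbool.h>

/* TCP state numbers, as in the kernel's include/net/tcp_states.h. */
enum {
	TCP_ESTABLISHED = 1,
	TCP_SYN_SENT    = 2,
	TCP_CLOSE       = 7,
	TCP_LISTEN      = 10,
};

#define TCPF_ESTABLISHED	(1 << TCP_ESTABLISHED)
#define TCPF_LISTEN		(1 << TCP_LISTEN)

static_assert((TCPF_ESTABLISHED | TCPF_LISTEN) == 0x402,
	      "mask covers bits 1 and 10");

/* Map update: only established or listening sockets may be inserted. */
static bool state_allowed_in_map(int sk_state)
{
	return (1 << sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
}

/* Redirect helpers: a listening socket is never a valid target. */
static bool redirect_allowed(int sk_state)
{
	return sk_state != TCP_LISTEN;
}
```

Turning the state into a one-bit mask lets a single AND test membership in
the allowed set, so adding another state later means adding one more flag
to the mask.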

[0] https://lore.kernel.org/bpf/20190828072250.29828-1-jakub@cloudflare.com/
[1] https://lore.kernel.org/bpf/20191123110751.6729-1-jakub@cloudflare.com/
[2] https://lore.kernel.org/bpf/20200217121530.754315-1-jakub@cloudflare.com/
====================

Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index d90ef61..112765b 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -352,7 +352,8 @@ static inline void sk_psock_update_proto(struct sock *sk,
 	psock->saved_write_space = sk->sk_write_space;
 
 	psock->sk_proto = sk->sk_prot;
-	sk->sk_prot = ops;
+	/* Pairs with lockless read in sk_clone_lock() */
+	WRITE_ONCE(sk->sk_prot, ops);
 }
 
 static inline void sk_psock_restore_proto(struct sock *sk,
diff --git a/include/net/sock.h b/include/net/sock.h
index 02162b0..9f37fdf 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -502,10 +502,43 @@ enum sk_pacing {
 	SK_PACING_FQ		= 2,
 };
 
+/* Pointer stored in sk_user_data might not be suitable for copying
+ * when cloning the socket. For instance, it can point to a reference
+ * counted object. sk_user_data bottom bit is set if pointer must not
+ * be copied.
+ */
+#define SK_USER_DATA_NOCOPY	1UL
+#define SK_USER_DATA_PTRMASK	~(SK_USER_DATA_NOCOPY)
+
+/**
+ * sk_user_data_is_nocopy - Test if sk_user_data pointer must not be copied
+ * @sk: socket
+ */
+static inline bool sk_user_data_is_nocopy(const struct sock *sk)
+{
+	return ((uintptr_t)sk->sk_user_data & SK_USER_DATA_NOCOPY);
+}
+
 #define __sk_user_data(sk) ((*((void __rcu **)&(sk)->sk_user_data)))
 
-#define rcu_dereference_sk_user_data(sk)	rcu_dereference(__sk_user_data((sk)))
-#define rcu_assign_sk_user_data(sk, ptr)	rcu_assign_pointer(__sk_user_data((sk)), ptr)
+#define rcu_dereference_sk_user_data(sk)				\
+({									\
+	void *__tmp = rcu_dereference(__sk_user_data((sk)));		\
+	(void *)((uintptr_t)__tmp & SK_USER_DATA_PTRMASK);		\
+})
+#define rcu_assign_sk_user_data(sk, ptr)				\
+({									\
+	uintptr_t __tmp = (uintptr_t)(ptr);				\
+	WARN_ON_ONCE(__tmp & ~SK_USER_DATA_PTRMASK);			\
+	rcu_assign_pointer(__sk_user_data((sk)), __tmp);		\
+})
+#define rcu_assign_sk_user_data_nocopy(sk, ptr)				\
+({									\
+	uintptr_t __tmp = (uintptr_t)(ptr);				\
+	WARN_ON_ONCE(__tmp & ~SK_USER_DATA_PTRMASK);			\
+	rcu_assign_pointer(__sk_user_data((sk)),			\
+			   __tmp | SK_USER_DATA_NOCOPY);		\
+})
 
 /*
  * SK_CAN_REUSE and SK_NO_REUSE on a socket mean that the socket is OK
diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h
index 43f4a81..3ecaa15 100644
--- a/include/net/sock_reuseport.h
+++ b/include/net/sock_reuseport.h
@@ -55,6 +55,4 @@ static inline bool reuseport_has_conns(struct sock *sk, bool set)
 	return ret;
 }
 
-int reuseport_get_id(struct sock_reuseport *reuse);
-
 #endif  /* _SOCK_REUSEPORT_H */
diff --git a/include/net/tcp.h b/include/net/tcp.h
index a5ea27d..07f947c 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2203,6 +2203,13 @@ int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		    int nonblock, int flags, int *addr_len);
 int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
 		      struct msghdr *msg, int len, int flags);
+#ifdef CONFIG_NET_SOCK_MSG
+void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
+#else
+static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
+{
+}
+#endif
 
 /* Call BPF_SOCK_OPS program that returns an int. If the return value
  * is < 0, then the BPF op failed (for example if the loaded BPF
diff --git a/kernel/bpf/reuseport_array.c b/kernel/bpf/reuseport_array.c
index 50c083b..01badd3 100644
--- a/kernel/bpf/reuseport_array.c
+++ b/kernel/bpf/reuseport_array.c
@@ -305,11 +305,6 @@ int bpf_fd_reuseport_array_update_elem(struct bpf_map *map, void *key,
 	if (err)
 		goto put_file_unlock;
 
-	/* Ensure reuse->reuseport_id is set */
-	err = reuseport_get_id(reuse);
-	if (err < 0)
-		goto put_file_unlock;
-
 	WRITE_ONCE(nsk->sk_user_data, &array->ptrs[index]);
 	rcu_assign_pointer(array->ptrs[index], nsk);
 	free_osk = osk;
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 1cc945d..6d15dfb 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -3693,14 +3693,16 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
 		if (func_id != BPF_FUNC_sk_redirect_map &&
 		    func_id != BPF_FUNC_sock_map_update &&
 		    func_id != BPF_FUNC_map_delete_elem &&
-		    func_id != BPF_FUNC_msg_redirect_map)
+		    func_id != BPF_FUNC_msg_redirect_map &&
+		    func_id != BPF_FUNC_sk_select_reuseport)
 			goto error;
 		break;
 	case BPF_MAP_TYPE_SOCKHASH:
 		if (func_id != BPF_FUNC_sk_redirect_hash &&
 		    func_id != BPF_FUNC_sock_hash_update &&
 		    func_id != BPF_FUNC_map_delete_elem &&
-		    func_id != BPF_FUNC_msg_redirect_hash)
+		    func_id != BPF_FUNC_msg_redirect_hash &&
+		    func_id != BPF_FUNC_sk_select_reuseport)
 			goto error;
 		break;
 	case BPF_MAP_TYPE_REUSEPORT_SOCKARRAY:
@@ -3774,7 +3776,9 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
 			goto error;
 		break;
 	case BPF_FUNC_sk_select_reuseport:
-		if (map->map_type != BPF_MAP_TYPE_REUSEPORT_SOCKARRAY)
+		if (map->map_type != BPF_MAP_TYPE_REUSEPORT_SOCKARRAY &&
+		    map->map_type != BPF_MAP_TYPE_SOCKMAP &&
+		    map->map_type != BPF_MAP_TYPE_SOCKHASH)
 			goto error;
 		break;
 	case BPF_FUNC_map_peek_elem:
diff --git a/net/core/filter.c b/net/core/filter.c
index c180871..925b23d 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -8620,6 +8620,7 @@ struct sock *bpf_run_sk_reuseport(struct sock_reuseport *reuse, struct sock *sk,
 BPF_CALL_4(sk_select_reuseport, struct sk_reuseport_kern *, reuse_kern,
 	   struct bpf_map *, map, void *, key, u32, flags)
 {
+	bool is_sockarray = map->map_type == BPF_MAP_TYPE_REUSEPORT_SOCKARRAY;
 	struct sock_reuseport *reuse;
 	struct sock *selected_sk;
 
@@ -8628,26 +8629,20 @@ BPF_CALL_4(sk_select_reuseport, struct sk_reuseport_kern *, reuse_kern,
 		return -ENOENT;
 
 	reuse = rcu_dereference(selected_sk->sk_reuseport_cb);
-	if (!reuse)
-		/* selected_sk is unhashed (e.g. by close()) after the
-		 * above map_lookup_elem().  Treat selected_sk has already
-		 * been removed from the map.
+	if (!reuse) {
+		/* reuseport_array has only sk with non NULL sk_reuseport_cb.
+		 * The only (!reuse) case here is - the sk has already been
+		 * unhashed (e.g. by close()), so treat it as -ENOENT.
+		 *
+		 * Other maps (e.g. sock_map) do not provide this guarantee and
+		 * the sk may never be in the reuseport group to begin with.
 		 */
-		return -ENOENT;
+		return is_sockarray ? -ENOENT : -EINVAL;
+	}
 
 	if (unlikely(reuse->reuseport_id != reuse_kern->reuseport_id)) {
-		struct sock *sk;
+		struct sock *sk = reuse_kern->sk;
 
-		if (unlikely(!reuse_kern->reuseport_id))
-			/* There is a small race between adding the
-			 * sk to the map and setting the
-			 * reuse_kern->reuseport_id.
-			 * Treat it as the sk has not been added to
-			 * the bpf map yet.
-			 */
-			return -ENOENT;
-
-		sk = reuse_kern->sk;
 		if (sk->sk_protocol != selected_sk->sk_protocol)
 			return -EPROTOTYPE;
 		else if (sk->sk_family != selected_sk->sk_family)
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index ded2d52..eeb28cb 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -512,7 +512,7 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
 	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
 	refcount_set(&psock->refcnt, 1);
 
-	rcu_assign_sk_user_data(sk, psock);
+	rcu_assign_sk_user_data_nocopy(sk, psock);
 	sock_hold(sk);
 
 	return psock;
diff --git a/net/core/sock.c b/net/core/sock.c
index a4c8fac..e4af4dbc 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1572,13 +1572,14 @@ static inline void sock_lock_init(struct sock *sk)
  */
 static void sock_copy(struct sock *nsk, const struct sock *osk)
 {
+	const struct proto *prot = READ_ONCE(osk->sk_prot);
 #ifdef CONFIG_SECURITY_NETWORK
 	void *sptr = nsk->sk_security;
 #endif
 	memcpy(nsk, osk, offsetof(struct sock, sk_dontcopy_begin));
 
 	memcpy(&nsk->sk_dontcopy_end, &osk->sk_dontcopy_end,
-	       osk->sk_prot->obj_size - offsetof(struct sock, sk_dontcopy_end));
+	       prot->obj_size - offsetof(struct sock, sk_dontcopy_end));
 
 #ifdef CONFIG_SECURITY_NETWORK
 	nsk->sk_security = sptr;
@@ -1792,16 +1793,17 @@ static void sk_init_common(struct sock *sk)
  */
 struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
 {
+	struct proto *prot = READ_ONCE(sk->sk_prot);
 	struct sock *newsk;
 	bool is_charged = true;
 
-	newsk = sk_prot_alloc(sk->sk_prot, priority, sk->sk_family);
+	newsk = sk_prot_alloc(prot, priority, sk->sk_family);
 	if (newsk != NULL) {
 		struct sk_filter *filter;
 
 		sock_copy(newsk, sk);
 
-		newsk->sk_prot_creator = sk->sk_prot;
+		newsk->sk_prot_creator = prot;
 
 		/* SANITY */
 		if (likely(newsk->sk_net_refcnt))
@@ -1863,6 +1865,12 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
 			goto out;
 		}
 
+		/* Clear sk_user_data if parent had the pointer tagged
+		 * as not suitable for copying when cloning.
+		 */
+		if (sk_user_data_is_nocopy(newsk))
+			RCU_INIT_POINTER(newsk->sk_user_data, NULL);
+
 		newsk->sk_err	   = 0;
 		newsk->sk_err_soft = 0;
 		newsk->sk_priority = 0;
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index 3a7a96a..2e0f465 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -10,6 +10,7 @@
 #include <linux/skmsg.h>
 #include <linux/list.h>
 #include <linux/jhash.h>
+#include <linux/sock_diag.h>
 
 struct bpf_stab {
 	struct bpf_map map;
@@ -31,7 +32,8 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
 		return ERR_PTR(-EPERM);
 	if (attr->max_entries == 0 ||
 	    attr->key_size    != 4 ||
-	    attr->value_size  != 4 ||
+	    (attr->value_size != sizeof(u32) &&
+	     attr->value_size != sizeof(u64)) ||
 	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
 		return ERR_PTR(-EINVAL);
 
@@ -228,6 +230,30 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 	return ret;
 }
 
+static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
+{
+	struct sk_psock *psock;
+	int ret;
+
+	psock = sk_psock_get_checked(sk);
+	if (IS_ERR(psock))
+		return PTR_ERR(psock);
+
+	if (psock) {
+		tcp_bpf_reinit(sk);
+		return 0;
+	}
+
+	psock = sk_psock_init(sk, map->numa_node);
+	if (!psock)
+		return -ENOMEM;
+
+	ret = tcp_bpf_init(sk);
+	if (ret < 0)
+		sk_psock_put(sk, psock);
+	return ret;
+}
+
 static void sock_map_free(struct bpf_map *map)
 {
 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
@@ -275,7 +301,22 @@ static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
 
 static void *sock_map_lookup(struct bpf_map *map, void *key)
 {
-	return ERR_PTR(-EOPNOTSUPP);
+	return __sock_map_lookup_elem(map, *(u32 *)key);
+}
+
+static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
+{
+	struct sock *sk;
+
+	if (map->value_size != sizeof(u64))
+		return ERR_PTR(-ENOSPC);
+
+	sk = __sock_map_lookup_elem(map, *(u32 *)key);
+	if (!sk)
+		return ERR_PTR(-ENOENT);
+
+	sock_gen_cookie(sk);
+	return &sk->sk_cookie;
 }
 
 static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
@@ -334,6 +375,11 @@ static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
 	return 0;
 }
 
+static bool sock_map_redirect_allowed(const struct sock *sk)
+{
+	return sk->sk_state != TCP_LISTEN;
+}
+
 static int sock_map_update_common(struct bpf_map *map, u32 idx,
 				  struct sock *sk, u64 flags)
 {
@@ -356,7 +402,14 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
 	if (!link)
 		return -ENOMEM;
 
-	ret = sock_map_link(map, &stab->progs, sk);
+	/* Only sockets we can redirect into/from in BPF need to hold
+	 * refs to parser/verdict progs and have their sk_data_ready
+	 * and sk_write_space callbacks overridden.
+	 */
+	if (sock_map_redirect_allowed(sk))
+		ret = sock_map_link(map, &stab->progs, sk);
+	else
+		ret = sock_map_link_no_progs(map, sk);
 	if (ret < 0)
 		goto out_free;
 
@@ -391,7 +444,8 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
 static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
 {
 	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
-	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
+	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
+	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
 }
 
 static bool sock_map_sk_is_suitable(const struct sock *sk)
@@ -400,14 +454,26 @@ static bool sock_map_sk_is_suitable(const struct sock *sk)
 	       sk->sk_protocol == IPPROTO_TCP;
 }
 
+static bool sock_map_sk_state_allowed(const struct sock *sk)
+{
+	return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
+}
+
 static int sock_map_update_elem(struct bpf_map *map, void *key,
 				void *value, u64 flags)
 {
-	u32 ufd = *(u32 *)value;
 	u32 idx = *(u32 *)key;
 	struct socket *sock;
 	struct sock *sk;
 	int ret;
+	u64 ufd;
+
+	if (map->value_size == sizeof(u64))
+		ufd = *(u64 *)value;
+	else
+		ufd = *(u32 *)value;
+	if (ufd > S32_MAX)
+		return -EINVAL;
 
 	sock = sockfd_lookup(ufd, &ret);
 	if (!sock)
@@ -423,7 +489,7 @@ static int sock_map_update_elem(struct bpf_map *map, void *key,
 	}
 
 	sock_map_sk_acquire(sk);
-	if (sk->sk_state != TCP_ESTABLISHED)
+	if (!sock_map_sk_state_allowed(sk))
 		ret = -EOPNOTSUPP;
 	else
 		ret = sock_map_update_common(map, idx, sk, flags);
@@ -460,13 +526,17 @@ BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
 	   struct bpf_map *, map, u32, key, u64, flags)
 {
 	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
+	struct sock *sk;
 
 	if (unlikely(flags & ~(BPF_F_INGRESS)))
 		return SK_DROP;
-	tcb->bpf.flags = flags;
-	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
-	if (!tcb->bpf.sk_redir)
+
+	sk = __sock_map_lookup_elem(map, key);
+	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
 		return SK_DROP;
+
+	tcb->bpf.flags = flags;
+	tcb->bpf.sk_redir = sk;
 	return SK_PASS;
 }
 
@@ -483,12 +553,17 @@ const struct bpf_func_proto bpf_sk_redirect_map_proto = {
 BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
 	   struct bpf_map *, map, u32, key, u64, flags)
 {
+	struct sock *sk;
+
 	if (unlikely(flags & ~(BPF_F_INGRESS)))
 		return SK_DROP;
-	msg->flags = flags;
-	msg->sk_redir = __sock_map_lookup_elem(map, key);
-	if (!msg->sk_redir)
+
+	sk = __sock_map_lookup_elem(map, key);
+	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
 		return SK_DROP;
+
+	msg->flags = flags;
+	msg->sk_redir = sk;
 	return SK_PASS;
 }
 
@@ -506,6 +581,7 @@ const struct bpf_map_ops sock_map_ops = {
 	.map_alloc		= sock_map_alloc,
 	.map_free		= sock_map_free,
 	.map_get_next_key	= sock_map_get_next_key,
+	.map_lookup_elem_sys_only = sock_map_lookup_sys,
 	.map_update_elem	= sock_map_update_elem,
 	.map_delete_elem	= sock_map_delete_elem,
 	.map_lookup_elem	= sock_map_lookup,
@@ -680,7 +756,14 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
 	if (!link)
 		return -ENOMEM;
 
-	ret = sock_map_link(map, &htab->progs, sk);
+	/* Only sockets we can redirect into/from in BPF need to hold
+	 * refs to parser/verdict progs and have their sk_data_ready
+	 * and sk_write_space callbacks overridden.
+	 */
+	if (sock_map_redirect_allowed(sk))
+		ret = sock_map_link(map, &htab->progs, sk);
+	else
+		ret = sock_map_link_no_progs(map, sk);
 	if (ret < 0)
 		goto out_free;
 
@@ -729,10 +812,17 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
 static int sock_hash_update_elem(struct bpf_map *map, void *key,
 				 void *value, u64 flags)
 {
-	u32 ufd = *(u32 *)value;
 	struct socket *sock;
 	struct sock *sk;
 	int ret;
+	u64 ufd;
+
+	if (map->value_size == sizeof(u64))
+		ufd = *(u64 *)value;
+	else
+		ufd = *(u32 *)value;
+	if (ufd > S32_MAX)
+		return -EINVAL;
 
 	sock = sockfd_lookup(ufd, &ret);
 	if (!sock)
@@ -748,7 +838,7 @@ static int sock_hash_update_elem(struct bpf_map *map, void *key,
 	}
 
 	sock_map_sk_acquire(sk);
-	if (sk->sk_state != TCP_ESTABLISHED)
+	if (!sock_map_sk_state_allowed(sk))
 		ret = -EOPNOTSUPP;
 	else
 		ret = sock_hash_update_common(map, key, sk, flags);
@@ -808,7 +898,8 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
 		return ERR_PTR(-EPERM);
 	if (attr->max_entries == 0 ||
 	    attr->key_size    == 0 ||
-	    attr->value_size  != 4 ||
+	    (attr->value_size != sizeof(u32) &&
+	     attr->value_size != sizeof(u64)) ||
 	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
 		return ERR_PTR(-EINVAL);
 	if (attr->key_size > MAX_BPF_STACK)
@@ -885,6 +976,26 @@ static void sock_hash_free(struct bpf_map *map)
 	kfree(htab);
 }
 
+static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
+{
+	struct sock *sk;
+
+	if (map->value_size != sizeof(u64))
+		return ERR_PTR(-ENOSPC);
+
+	sk = __sock_hash_lookup_elem(map, key);
+	if (!sk)
+		return ERR_PTR(-ENOENT);
+
+	sock_gen_cookie(sk);
+	return &sk->sk_cookie;
+}
+
+static void *sock_hash_lookup(struct bpf_map *map, void *key)
+{
+	return __sock_hash_lookup_elem(map, key);
+}
+
 static void sock_hash_release_progs(struct bpf_map *map)
 {
 	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
@@ -916,13 +1027,17 @@ BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
 	   struct bpf_map *, map, void *, key, u64, flags)
 {
 	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
+	struct sock *sk;
 
 	if (unlikely(flags & ~(BPF_F_INGRESS)))
 		return SK_DROP;
-	tcb->bpf.flags = flags;
-	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
-	if (!tcb->bpf.sk_redir)
+
+	sk = __sock_hash_lookup_elem(map, key);
+	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
 		return SK_DROP;
+
+	tcb->bpf.flags = flags;
+	tcb->bpf.sk_redir = sk;
 	return SK_PASS;
 }
 
@@ -939,12 +1054,17 @@ const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
 BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
 	   struct bpf_map *, map, void *, key, u64, flags)
 {
+	struct sock *sk;
+
 	if (unlikely(flags & ~(BPF_F_INGRESS)))
 		return SK_DROP;
-	msg->flags = flags;
-	msg->sk_redir = __sock_hash_lookup_elem(map, key);
-	if (!msg->sk_redir)
+
+	sk = __sock_hash_lookup_elem(map, key);
+	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
 		return SK_DROP;
+
+	msg->flags = flags;
+	msg->sk_redir = sk;
 	return SK_PASS;
 }
 
@@ -964,7 +1084,8 @@ const struct bpf_map_ops sock_hash_ops = {
 	.map_get_next_key	= sock_hash_get_next_key,
 	.map_update_elem	= sock_hash_update_elem,
 	.map_delete_elem	= sock_hash_delete_elem,
-	.map_lookup_elem	= sock_map_lookup,
+	.map_lookup_elem	= sock_hash_lookup,
+	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
 	.map_release_uref	= sock_hash_release_progs,
 	.map_check_btf		= map_check_no_btf,
 };
diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c
index 91e9f22..adcb3ae 100644
--- a/net/core/sock_reuseport.c
+++ b/net/core/sock_reuseport.c
@@ -16,27 +16,8 @@
 
 DEFINE_SPINLOCK(reuseport_lock);
 
-#define REUSEPORT_MIN_ID 1
 static DEFINE_IDA(reuseport_ida);
 
-int reuseport_get_id(struct sock_reuseport *reuse)
-{
-	int id;
-
-	if (reuse->reuseport_id)
-		return reuse->reuseport_id;
-
-	id = ida_simple_get(&reuseport_ida, REUSEPORT_MIN_ID, 0,
-			    /* Called under reuseport_lock */
-			    GFP_ATOMIC);
-	if (id < 0)
-		return id;
-
-	reuse->reuseport_id = id;
-
-	return reuse->reuseport_id;
-}
-
 static struct sock_reuseport *__reuseport_alloc(unsigned int max_socks)
 {
 	unsigned int size = sizeof(struct sock_reuseport) +
@@ -55,6 +36,7 @@ static struct sock_reuseport *__reuseport_alloc(unsigned int max_socks)
 int reuseport_alloc(struct sock *sk, bool bind_inany)
 {
 	struct sock_reuseport *reuse;
+	int id, ret = 0;
 
 	/* bh lock used since this function call may precede hlist lock in
 	 * soft irq of receive path or setsockopt from process context
@@ -78,10 +60,18 @@ int reuseport_alloc(struct sock *sk, bool bind_inany)
 
 	reuse = __reuseport_alloc(INIT_SOCKS);
 	if (!reuse) {
-		spin_unlock_bh(&reuseport_lock);
-		return -ENOMEM;
+		ret = -ENOMEM;
+		goto out;
 	}
 
+	id = ida_alloc(&reuseport_ida, GFP_ATOMIC);
+	if (id < 0) {
+		kfree(reuse);
+		ret = id;
+		goto out;
+	}
+
+	reuse->reuseport_id = id;
 	reuse->socks[0] = sk;
 	reuse->num_socks = 1;
 	reuse->bind_inany = bind_inany;
@@ -90,7 +80,7 @@ int reuseport_alloc(struct sock *sk, bool bind_inany)
 out:
 	spin_unlock_bh(&reuseport_lock);
 
-	return 0;
+	return ret;
 }
 EXPORT_SYMBOL(reuseport_alloc);
 
@@ -134,8 +124,7 @@ static void reuseport_free_rcu(struct rcu_head *head)
 
 	reuse = container_of(head, struct sock_reuseport, rcu);
 	sk_reuseport_prog_free(rcu_dereference_protected(reuse->prog, 1));
-	if (reuse->reuseport_id)
-		ida_simple_remove(&reuseport_ida, reuse->reuseport_id);
+	ida_free(&reuseport_ida, reuse->reuseport_id);
 	kfree(reuse);
 }
 
@@ -199,12 +188,15 @@ void reuseport_detach_sock(struct sock *sk)
 	reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
 					  lockdep_is_held(&reuseport_lock));
 
-	/* At least one of the sk in this reuseport group is added to
-	 * a bpf map.  Notify the bpf side.  The bpf map logic will
-	 * remove the sk if it is indeed added to a bpf map.
+	/* Notify the bpf side. The sk may be added to a sockarray
+	 * map. If so, sockarray logic will remove it from the map.
+	 *
+	 * Other bpf map types that work with reuseport, like sockmap,
+	 * don't need an explicit callback from here. They override sk
+	 * unhash/close ops to remove the sk from the map before we
+	 * get to this point.
 	 */
-	if (reuse->reuseport_id)
-		bpf_sk_reuseport_detach(sk);
+	bpf_sk_reuseport_detach(sk);
 
 	rcu_assign_pointer(sk->sk_reuseport_cb, NULL);
 
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 8a01428..7d6e1b7 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -645,8 +645,10 @@ static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
 	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
 	 * or added requiring sk_prot hook updates. We keep original saved
 	 * hooks in this case.
+	 *
+	 * Pairs with lockless read in sk_clone_lock().
 	 */
-	sk->sk_prot = &tcp_bpf_prots[family][config];
+	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
 }
 
 static int tcp_bpf_assert_proto_ops(struct proto *ops)
@@ -691,3 +693,17 @@ int tcp_bpf_init(struct sock *sk)
 	rcu_read_unlock();
 	return 0;
 }
+
+/* If a child got cloned from a listening socket that had tcp_bpf
+ * protocol callbacks installed, we need to restore the callbacks to
+ * the default ones because the child does not inherit the psock state
+ * that tcp_bpf callbacks expect.
+ */
+void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
+{
+	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
+	struct proto *prot = newsk->sk_prot;
+
+	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
+		newsk->sk_prot = sk->sk_prot_creator;
+}
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index ad3b56d..c827437 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -548,6 +548,8 @@ struct sock *tcp_create_openreq_child(const struct sock *sk,
 	newtp->fastopen_req = NULL;
 	RCU_INIT_POINTER(newtp->fastopen_rsk, NULL);
 
+	tcp_bpf_clone(sk, newsk);
+
 	__TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS);
 
 	return newsk;
diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c
index 38d3ad1..6c43fa1 100644
--- a/net/ipv4/tcp_ulp.c
+++ b/net/ipv4/tcp_ulp.c
@@ -106,7 +106,8 @@ void tcp_update_ulp(struct sock *sk, struct proto *proto,
 
 	if (!icsk->icsk_ulp_ops) {
 		sk->sk_write_space = write_space;
-		sk->sk_prot = proto;
+		/* Pairs with lockless read in sk_clone_lock() */
+		WRITE_ONCE(sk->sk_prot, proto);
 		return;
 	}
 
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 94774c0..82225bc 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -742,7 +742,8 @@ static void tls_update(struct sock *sk, struct proto *p,
 		ctx->sk_write_space = write_space;
 		ctx->sk_proto = p;
 	} else {
-		sk->sk_prot = p;
+		/* Pairs with lockless read in sk_clone_lock(). */
+		WRITE_ONCE(sk->sk_prot, p);
 		sk->sk_write_space = write_space;
 	}
 }
diff --git a/tools/testing/selftests/bpf/prog_tests/select_reuseport.c b/tools/testing/selftests/bpf/prog_tests/select_reuseport.c
index 098bcae..9ed0ab0 100644
--- a/tools/testing/selftests/bpf/prog_tests/select_reuseport.c
+++ b/tools/testing/selftests/bpf/prog_tests/select_reuseport.c
@@ -36,6 +36,7 @@ static int result_map, tmp_index_ovr_map, linum_map, data_check_map;
 static __u32 expected_results[NR_RESULTS];
 static int sk_fds[REUSEPORT_ARRAY_SIZE];
 static int reuseport_array = -1, outer_map = -1;
+static enum bpf_map_type inner_map_type;
 static int select_by_skb_data_prog;
 static int saved_tcp_syncookie = -1;
 static struct bpf_object *obj;
@@ -63,13 +64,15 @@ static union sa46 {
 	}								\
 })
 
-static int create_maps(void)
+static int create_maps(enum bpf_map_type inner_type)
 {
 	struct bpf_create_map_attr attr = {};
 
+	inner_map_type = inner_type;
+
 	/* Creating reuseport_array */
 	attr.name = "reuseport_array";
-	attr.map_type = BPF_MAP_TYPE_REUSEPORT_SOCKARRAY;
+	attr.map_type = inner_type;
 	attr.key_size = sizeof(__u32);
 	attr.value_size = sizeof(__u32);
 	attr.max_entries = REUSEPORT_ARRAY_SIZE;
@@ -726,12 +729,36 @@ static void cleanup_per_test(bool no_inner_map)
 
 static void cleanup(void)
 {
-	if (outer_map != -1)
+	if (outer_map != -1) {
 		close(outer_map);
-	if (reuseport_array != -1)
+		outer_map = -1;
+	}
+
+	if (reuseport_array != -1) {
 		close(reuseport_array);
-	if (obj)
+		reuseport_array = -1;
+	}
+
+	if (obj) {
 		bpf_object__close(obj);
+		obj = NULL;
+	}
+
+	memset(expected_results, 0, sizeof(expected_results));
+}
+
+static const char *maptype_str(enum bpf_map_type type)
+{
+	switch (type) {
+	case BPF_MAP_TYPE_REUSEPORT_SOCKARRAY:
+		return "reuseport_sockarray";
+	case BPF_MAP_TYPE_SOCKMAP:
+		return "sockmap";
+	case BPF_MAP_TYPE_SOCKHASH:
+		return "sockhash";
+	default:
+		return "unknown";
+	}
 }
 
 static const char *family_str(sa_family_t family)
@@ -779,13 +806,21 @@ static void test_config(int sotype, sa_family_t family, bool inany)
 	const struct test *t;
 
 	for (t = tests; t < tests + ARRAY_SIZE(tests); t++) {
-		snprintf(s, sizeof(s), "%s/%s %s %s",
+		snprintf(s, sizeof(s), "%s %s/%s %s %s",
+			 maptype_str(inner_map_type),
 			 family_str(family), sotype_str(sotype),
 			 inany ? "INANY" : "LOOPBACK", t->name);
 
 		if (!test__start_subtest(s))
 			continue;
 
+		if (sotype == SOCK_DGRAM &&
+		    inner_map_type != BPF_MAP_TYPE_REUSEPORT_SOCKARRAY) {
+			/* SOCKMAP/SOCKHASH don't support UDP yet */
+			test__skip();
+			continue;
+		}
+
 		setup_per_test(sotype, family, inany, t->no_inner_map);
 		t->fn(sotype, family);
 		cleanup_per_test(t->no_inner_map);
@@ -814,13 +849,20 @@ static void test_all(void)
 		test_config(c->sotype, c->family, c->inany);
 }
 
-void test_select_reuseport(void)
+static void test_map_type(enum bpf_map_type mt)
 {
-	if (create_maps())
+	if (create_maps(mt))
 		goto out;
 	if (prepare_bpf_obj())
 		goto out;
 
+	test_all();
+out:
+	cleanup();
+}
+
+void test_select_reuseport(void)
+{
 	saved_tcp_fo = read_int_sysctl(TCP_FO_SYSCTL);
 	saved_tcp_syncookie = read_int_sysctl(TCP_SYNCOOKIE_SYSCTL);
 	if (saved_tcp_syncookie < 0 || saved_tcp_syncookie < 0)
@@ -831,8 +873,9 @@ void test_select_reuseport(void)
 	if (disable_syncookie())
 		goto out;
 
-	test_all();
+	test_map_type(BPF_MAP_TYPE_REUSEPORT_SOCKARRAY);
+	test_map_type(BPF_MAP_TYPE_SOCKMAP);
+	test_map_type(BPF_MAP_TYPE_SOCKHASH);
 out:
-	cleanup();
 	restore_sysctls();
 }
diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_listen.c b/tools/testing/selftests/bpf/prog_tests/sockmap_listen.c
new file mode 100644
index 0000000..b1b2ace
--- /dev/null
+++ b/tools/testing/selftests/bpf/prog_tests/sockmap_listen.c
@@ -0,0 +1,1496 @@
+// SPDX-License-Identifier: GPL-2.0
+// Copyright (c) 2020 Cloudflare
+/*
+ * Test suite for SOCKMAP/SOCKHASH holding listening sockets.
+ * Covers:
+ *  1. BPF map operations - bpf_map_{update,lookup,delete}_elem
+ *  2. BPF redirect helpers - bpf_{sk,msg}_redirect_map
+ *  3. BPF reuseport helper - bpf_sk_select_reuseport
+ */
+
+#include <linux/compiler.h>
+#include <errno.h>
+#include <error.h>
+#include <limits.h>
+#include <netinet/in.h>
+#include <pthread.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <bpf/bpf.h>
+#include <bpf/libbpf.h>
+
+#include "bpf_util.h"
+#include "test_progs.h"
+#include "test_sockmap_listen.skel.h"
+
+#define MAX_STRERR_LEN 256
+#define MAX_TEST_NAME 80
+
+#define _FAIL(errnum, fmt...)                                                  \
+	({                                                                     \
+		error_at_line(0, (errnum), __func__, __LINE__, fmt);           \
+		CHECK_FAIL(true);                                              \
+	})
+#define FAIL(fmt...) _FAIL(0, fmt)
+#define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
+#define FAIL_LIBBPF(err, msg)                                                  \
+	({                                                                     \
+		char __buf[MAX_STRERR_LEN];                                    \
+		libbpf_strerror((err), __buf, sizeof(__buf));                  \
+		FAIL("%s: %s", (msg), __buf);                                  \
+	})
+
+/* Wrappers that fail the test on error and report it. */
+
+#define xaccept(fd, addr, len)                                                 \
+	({                                                                     \
+		int __ret = accept((fd), (addr), (len));                       \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("accept");                                  \
+		__ret;                                                         \
+	})
+
+#define xbind(fd, addr, len)                                                   \
+	({                                                                     \
+		int __ret = bind((fd), (addr), (len));                         \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("bind");                                    \
+		__ret;                                                         \
+	})
+
+#define xclose(fd)                                                             \
+	({                                                                     \
+		int __ret = close((fd));                                       \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("close");                                   \
+		__ret;                                                         \
+	})
+
+#define xconnect(fd, addr, len)                                                \
+	({                                                                     \
+		int __ret = connect((fd), (addr), (len));                      \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("connect");                                 \
+		__ret;                                                         \
+	})
+
+#define xgetsockname(fd, addr, len)                                            \
+	({                                                                     \
+		int __ret = getsockname((fd), (addr), (len));                  \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("getsockname");                             \
+		__ret;                                                         \
+	})
+
+#define xgetsockopt(fd, level, name, val, len)                                 \
+	({                                                                     \
+		int __ret = getsockopt((fd), (level), (name), (val), (len));   \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("getsockopt(" #name ")");                   \
+		__ret;                                                         \
+	})
+
+#define xlisten(fd, backlog)                                                   \
+	({                                                                     \
+		int __ret = listen((fd), (backlog));                           \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("listen");                                  \
+		__ret;                                                         \
+	})
+
+#define xsetsockopt(fd, level, name, val, len)                                 \
+	({                                                                     \
+		int __ret = setsockopt((fd), (level), (name), (val), (len));   \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("setsockopt(" #name ")");                   \
+		__ret;                                                         \
+	})
+
+#define xsocket(family, sotype, flags)                                         \
+	({                                                                     \
+		int __ret = socket(family, sotype, flags);                     \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("socket");                                  \
+		__ret;                                                         \
+	})
+
+#define xbpf_map_delete_elem(fd, key)                                          \
+	({                                                                     \
+		int __ret = bpf_map_delete_elem((fd), (key));                  \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("map_delete");                              \
+		__ret;                                                         \
+	})
+
+#define xbpf_map_lookup_elem(fd, key, val)                                     \
+	({                                                                     \
+		int __ret = bpf_map_lookup_elem((fd), (key), (val));           \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("map_lookup");                              \
+		__ret;                                                         \
+	})
+
+#define xbpf_map_update_elem(fd, key, val, flags)                              \
+	({                                                                     \
+		int __ret = bpf_map_update_elem((fd), (key), (val), (flags));  \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("map_update");                              \
+		__ret;                                                         \
+	})
+
+#define xbpf_prog_attach(prog, target, type, flags)                            \
+	({                                                                     \
+		int __ret =                                                    \
+			bpf_prog_attach((prog), (target), (type), (flags));    \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("prog_attach(" #type ")");                  \
+		__ret;                                                         \
+	})
+
+#define xbpf_prog_detach2(prog, target, type)                                  \
+	({                                                                     \
+		int __ret = bpf_prog_detach2((prog), (target), (type));        \
+		if (__ret == -1)                                               \
+			FAIL_ERRNO("prog_detach2(" #type ")");                 \
+		__ret;                                                         \
+	})
+
+#define xpthread_create(thread, attr, func, arg)                               \
+	({                                                                     \
+		int __ret = pthread_create((thread), (attr), (func), (arg));   \
+		errno = __ret;                                                 \
+		if (__ret)                                                     \
+			FAIL_ERRNO("pthread_create");                          \
+		__ret;                                                         \
+	})
+
+#define xpthread_join(thread, retval)                                          \
+	({                                                                     \
+		int __ret = pthread_join((thread), (retval));                  \
+		errno = __ret;                                                 \
+		if (__ret)                                                     \
+			FAIL_ERRNO("pthread_join");                            \
+		__ret;                                                         \
+	})
+
+static void init_addr_loopback4(struct sockaddr_storage *ss, socklen_t *len)
+{
+	struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss));
+
+	addr4->sin_family = AF_INET;
+	addr4->sin_port = 0;
+	addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+	*len = sizeof(*addr4);
+}
+
+static void init_addr_loopback6(struct sockaddr_storage *ss, socklen_t *len)
+{
+	struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss));
+
+	addr6->sin6_family = AF_INET6;
+	addr6->sin6_port = 0;
+	addr6->sin6_addr = in6addr_loopback;
+	*len = sizeof(*addr6);
+}
+
+static void init_addr_loopback(int family, struct sockaddr_storage *ss,
+			       socklen_t *len)
+{
+	switch (family) {
+	case AF_INET:
+		init_addr_loopback4(ss, len);
+		return;
+	case AF_INET6:
+		init_addr_loopback6(ss, len);
+		return;
+	default:
+		FAIL("unsupported address family %d", family);
+	}
+}
+
+static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
+{
+	return (struct sockaddr *)ss;
+}
+
+static int enable_reuseport(int s, int progfd)
+{
+	int err, one = 1;
+
+	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
+	if (err)
+		return -1;
+	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
+			  sizeof(progfd));
+	if (err)
+		return -1;
+
+	return 0;
+}
+
+static int listen_loopback_reuseport(int family, int sotype, int progfd)
+{
+	struct sockaddr_storage addr;
+	socklen_t len;
+	int err, s;
+
+	init_addr_loopback(family, &addr, &len);
+
+	s = xsocket(family, sotype, 0);
+	if (s == -1)
+		return -1;
+
+	if (progfd >= 0 && enable_reuseport(s, progfd))
+		goto close;
+
+	err = xbind(s, sockaddr(&addr), len);
+	if (err)
+		goto close;
+
+	err = xlisten(s, SOMAXCONN);
+	if (err)
+		goto close;
+
+	return s;
+close:
+	xclose(s);
+	return -1;
+}
+
+static int listen_loopback(int family, int sotype)
+{
+	return listen_loopback_reuseport(family, sotype, -1);
+}
+
+static void test_insert_invalid(int family, int sotype, int mapfd)
+{
+	u32 key = 0;
+	u64 value;
+	int err;
+
+	value = -1;
+	err = bpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+	if (!err || errno != EINVAL)
+		FAIL_ERRNO("map_update: expected EINVAL");
+
+	value = INT_MAX;
+	err = bpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+	if (!err || errno != EBADF)
+		FAIL_ERRNO("map_update: expected EBADF");
+}
+
+static void test_insert_opened(int family, int sotype, int mapfd)
+{
+	u32 key = 0;
+	u64 value;
+	int err, s;
+
+	s = xsocket(family, sotype, 0);
+	if (s == -1)
+		return;
+
+	errno = 0;
+	value = s;
+	err = bpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+	if (!err || errno != EOPNOTSUPP)
+		FAIL_ERRNO("map_update: expected EOPNOTSUPP");
+
+	xclose(s);
+}
+
+static void test_insert_bound(int family, int sotype, int mapfd)
+{
+	struct sockaddr_storage addr;
+	socklen_t len;
+	u32 key = 0;
+	u64 value;
+	int err, s;
+
+	init_addr_loopback(family, &addr, &len);
+
+	s = xsocket(family, sotype, 0);
+	if (s == -1)
+		return;
+
+	err = xbind(s, sockaddr(&addr), len);
+	if (err)
+		goto close;
+
+	errno = 0;
+	value = s;
+	err = bpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+	if (!err || errno != EOPNOTSUPP)
+		FAIL_ERRNO("map_update: expected EOPNOTSUPP");
+close:
+	xclose(s);
+}
+
+static void test_insert_listening(int family, int sotype, int mapfd)
+{
+	u64 value;
+	u32 key;
+	int s;
+
+	s = listen_loopback(family, sotype);
+	if (s < 0)
+		return;
+
+	key = 0;
+	value = s;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+	xclose(s);
+}
+
+static void test_delete_after_insert(int family, int sotype, int mapfd)
+{
+	u64 value;
+	u32 key;
+	int s;
+
+	s = listen_loopback(family, sotype);
+	if (s < 0)
+		return;
+
+	key = 0;
+	value = s;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+	xbpf_map_delete_elem(mapfd, &key);
+	xclose(s);
+}
+
+static void test_delete_after_close(int family, int sotype, int mapfd)
+{
+	int err, s;
+	u64 value;
+	u32 key;
+
+	s = listen_loopback(family, sotype);
+	if (s < 0)
+		return;
+
+	key = 0;
+	value = s;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+
+	xclose(s);
+
+	errno = 0;
+	err = bpf_map_delete_elem(mapfd, &key);
+	if (!err || (errno != EINVAL && errno != ENOENT))
+		/* SOCKMAP and SOCKHASH return different error codes */
+		FAIL_ERRNO("map_delete: expected EINVAL/ENOENT");
+}
+
+static void test_lookup_after_insert(int family, int sotype, int mapfd)
+{
+	u64 cookie, value;
+	socklen_t len;
+	u32 key;
+	int s;
+
+	s = listen_loopback(family, sotype);
+	if (s < 0)
+		return;
+
+	key = 0;
+	value = s;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+
+	len = sizeof(cookie);
+	xgetsockopt(s, SOL_SOCKET, SO_COOKIE, &cookie, &len);
+
+	xbpf_map_lookup_elem(mapfd, &key, &value);
+
+	if (value != cookie) {
+		FAIL("map_lookup: have %#llx, want %#llx",
+		     (unsigned long long)value, (unsigned long long)cookie);
+	}
+
+	xclose(s);
+}
+
+static void test_lookup_after_delete(int family, int sotype, int mapfd)
+{
+	int err, s;
+	u64 value;
+	u32 key;
+
+	s = listen_loopback(family, sotype);
+	if (s < 0)
+		return;
+
+	key = 0;
+	value = s;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+	xbpf_map_delete_elem(mapfd, &key);
+
+	errno = 0;
+	err = bpf_map_lookup_elem(mapfd, &key, &value);
+	if (!err || errno != ENOENT)
+		FAIL_ERRNO("map_lookup: expected ENOENT");
+
+	xclose(s);
+}
+
+static void test_lookup_32_bit_value(int family, int sotype, int mapfd)
+{
+	u32 key, value32;
+	int err, s;
+
+	s = listen_loopback(family, sotype);
+	if (s < 0)
+		return;
+
+	mapfd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP, sizeof(key),
+			       sizeof(value32), 1, 0);
+	if (mapfd < 0) {
+		FAIL_ERRNO("map_create");
+		goto close;
+	}
+
+	key = 0;
+	value32 = s;
+	xbpf_map_update_elem(mapfd, &key, &value32, BPF_NOEXIST);
+
+	errno = 0;
+	err = bpf_map_lookup_elem(mapfd, &key, &value32);
+	if (!err || errno != ENOSPC)
+		FAIL_ERRNO("map_lookup: expected ENOSPC");
+
+	xclose(mapfd);
+close:
+	xclose(s);
+}
+
+static void test_update_listening(int family, int sotype, int mapfd)
+{
+	int s1, s2;
+	u64 value;
+	u32 key;
+
+	s1 = listen_loopback(family, sotype);
+	if (s1 < 0)
+		return;
+
+	s2 = listen_loopback(family, sotype);
+	if (s2 < 0)
+		goto close_s1;
+
+	key = 0;
+	value = s1;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+
+	value = s2;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_EXIST);
+	xclose(s2);
+close_s1:
+	xclose(s1);
+}
+
+/* Exercise the code path where we destroy child sockets that never
+ * got accept()'ed, aka orphans, when parent socket gets closed.
+ */
+static void test_destroy_orphan_child(int family, int sotype, int mapfd)
+{
+	struct sockaddr_storage addr;
+	socklen_t len;
+	int err, s, c;
+	u64 value;
+	u32 key;
+
+	s = listen_loopback(family, sotype);
+	if (s < 0)
+		return;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv;
+
+	key = 0;
+	value = s;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+
+	c = xsocket(family, sotype, 0);
+	if (c == -1)
+		goto close_srv;
+
+	xconnect(c, sockaddr(&addr), len);
+	xclose(c);
+close_srv:
+	xclose(s);
+}
+
+/* Perform a passive open after removing listening socket from SOCKMAP
+ * to ensure that callbacks get restored properly.
+ */
+static void test_clone_after_delete(int family, int sotype, int mapfd)
+{
+	struct sockaddr_storage addr;
+	socklen_t len;
+	int err, s, c;
+	u64 value;
+	u32 key;
+
+	s = listen_loopback(family, sotype);
+	if (s < 0)
+		return;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv;
+
+	key = 0;
+	value = s;
+	xbpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
+	xbpf_map_delete_elem(mapfd, &key);
+
+	c = xsocket(family, sotype, 0);
+	if (c < 0)
+		goto close_srv;
+
+	xconnect(c, sockaddr(&addr), len);
+	xclose(c);
+close_srv:
+	xclose(s);
+}
+
+/* Check that child socket that got created while parent was in a
+ * SOCKMAP, but got accept()'ed only after the parent has been removed
+ * from SOCKMAP, gets cloned without parent psock state or callbacks.
+ */
+static void test_accept_after_delete(int family, int sotype, int mapfd)
+{
+	struct sockaddr_storage addr;
+	const u32 zero = 0;
+	int err, s, c, p;
+	socklen_t len;
+	u64 value;
+
+	s = listen_loopback(family, sotype);
+	if (s == -1)
+		return;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv;
+
+	value = s;
+	err = xbpf_map_update_elem(mapfd, &zero, &value, BPF_NOEXIST);
+	if (err)
+		goto close_srv;
+
+	c = xsocket(family, sotype, 0);
+	if (c == -1)
+		goto close_srv;
+
+	/* Create child while parent is in sockmap */
+	err = xconnect(c, sockaddr(&addr), len);
+	if (err)
+		goto close_cli;
+
+	/* Remove parent from sockmap */
+	err = xbpf_map_delete_elem(mapfd, &zero);
+	if (err)
+		goto close_cli;
+
+	p = xaccept(s, NULL, NULL);
+	if (p == -1)
+		goto close_cli;
+
+	/* Check that child sk_user_data is not set */
+	value = p;
+	xbpf_map_update_elem(mapfd, &zero, &value, BPF_NOEXIST);
+
+	xclose(p);
+close_cli:
+	xclose(c);
+close_srv:
+	xclose(s);
+}
+
+/* Check that child socket that got created and accepted while parent
+ * was in a SOCKMAP is cloned without parent psock state or callbacks.
+ */
+static void test_accept_before_delete(int family, int sotype, int mapfd)
+{
+	struct sockaddr_storage addr;
+	const u32 zero = 0, one = 1;
+	int err, s, c, p;
+	socklen_t len;
+	u64 value;
+
+	s = listen_loopback(family, sotype);
+	if (s == -1)
+		return;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv;
+
+	value = s;
+	err = xbpf_map_update_elem(mapfd, &zero, &value, BPF_NOEXIST);
+	if (err)
+		goto close_srv;
+
+	c = xsocket(family, sotype, 0);
+	if (c == -1)
+		goto close_srv;
+
+	/* Create & accept child while parent is in sockmap */
+	err = xconnect(c, sockaddr(&addr), len);
+	if (err)
+		goto close_cli;
+
+	p = xaccept(s, NULL, NULL);
+	if (p == -1)
+		goto close_cli;
+
+	/* Check that child sk_user_data is not set */
+	value = p;
+	xbpf_map_update_elem(mapfd, &one, &value, BPF_NOEXIST);
+
+	xclose(p);
+close_cli:
+	xclose(c);
+close_srv:
+	xclose(s);
+}
+
+struct connect_accept_ctx {
+	int sockfd;
+	unsigned int done;
+	unsigned int nr_iter;
+};
+
+static bool is_thread_done(struct connect_accept_ctx *ctx)
+{
+	return READ_ONCE(ctx->done);
+}
+
+static void *connect_accept_thread(void *arg)
+{
+	struct connect_accept_ctx *ctx = arg;
+	struct sockaddr_storage addr;
+	int family, socktype;
+	socklen_t len;
+	int err, i, s;
+
+	s = ctx->sockfd;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto done;
+
+	len = sizeof(family);
+	err = xgetsockopt(s, SOL_SOCKET, SO_DOMAIN, &family, &len);
+	if (err)
+		goto done;
+
+	len = sizeof(socktype);
+	err = xgetsockopt(s, SOL_SOCKET, SO_TYPE, &socktype, &len);
+	if (err)
+		goto done;
+
+	for (i = 0; i < ctx->nr_iter; i++) {
+		int c, p;
+
+		c = xsocket(family, socktype, 0);
+		if (c < 0)
+			break;
+
+		err = xconnect(c, (struct sockaddr *)&addr, sizeof(addr));
+		if (err) {
+			xclose(c);
+			break;
+		}
+
+		p = xaccept(s, NULL, NULL);
+		if (p < 0) {
+			xclose(c);
+			break;
+		}
+
+		xclose(p);
+		xclose(c);
+	}
+done:
+	WRITE_ONCE(ctx->done, 1);
+	return NULL;
+}
+
+static void test_syn_recv_insert_delete(int family, int sotype, int mapfd)
+{
+	struct connect_accept_ctx ctx = { 0 };
+	struct sockaddr_storage addr;
+	socklen_t len;
+	u32 zero = 0;
+	pthread_t t;
+	int err, s;
+	u64 value;
+
+	s = listen_loopback(family, sotype | SOCK_NONBLOCK);
+	if (s < 0)
+		return;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close;
+
+	ctx.sockfd = s;
+	ctx.nr_iter = 1000;
+
+	err = xpthread_create(&t, NULL, connect_accept_thread, &ctx);
+	if (err)
+		goto close;
+
+	value = s;
+	while (!is_thread_done(&ctx)) {
+		err = xbpf_map_update_elem(mapfd, &zero, &value, BPF_NOEXIST);
+		if (err)
+			break;
+
+		err = xbpf_map_delete_elem(mapfd, &zero);
+		if (err)
+			break;
+	}
+
+	xpthread_join(t, NULL);
+close:
+	xclose(s);
+}
+
+static void *listen_thread(void *arg)
+{
+	struct sockaddr unspec = { AF_UNSPEC };
+	struct connect_accept_ctx *ctx = arg;
+	int err, i, s;
+
+	s = ctx->sockfd;
+
+	for (i = 0; i < ctx->nr_iter; i++) {
+		err = xlisten(s, 1);
+		if (err)
+			break;
+		err = xconnect(s, &unspec, sizeof(unspec));
+		if (err)
+			break;
+	}
+
+	WRITE_ONCE(ctx->done, 1);
+	return NULL;
+}
+
+static void test_race_insert_listen(int family, int socktype, int mapfd)
+{
+	struct connect_accept_ctx ctx = { 0 };
+	const u32 zero = 0;
+	const int one = 1;
+	pthread_t t;
+	int err, s;
+	u64 value;
+
+	s = xsocket(family, socktype, 0);
+	if (s < 0)
+		return;
+
+	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
+	if (err)
+		goto close;
+
+	ctx.sockfd = s;
+	ctx.nr_iter = 10000;
+
+	err = xpthread_create(&t, NULL, listen_thread, &ctx);
+	if (err)
+		goto close;
+
+	value = s;
+	while (!is_thread_done(&ctx)) {
+		err = bpf_map_update_elem(mapfd, &zero, &value, BPF_NOEXIST);
+		/* Expecting EOPNOTSUPP before listen() */
+		if (err && errno != EOPNOTSUPP) {
+			FAIL_ERRNO("map_update");
+			break;
+		}
+
+		err = bpf_map_delete_elem(mapfd, &zero);
+		/* Expecting no entry after unhash on connect(AF_UNSPEC) */
+		if (err && errno != EINVAL && errno != ENOENT) {
+			FAIL_ERRNO("map_delete");
+			break;
+		}
+	}
+
+	xpthread_join(t, NULL);
+close:
+	xclose(s);
+}
+
+static void zero_verdict_count(int mapfd)
+{
+	unsigned int zero = 0;
+	int key;
+
+	key = SK_DROP;
+	xbpf_map_update_elem(mapfd, &key, &zero, BPF_ANY);
+	key = SK_PASS;
+	xbpf_map_update_elem(mapfd, &key, &zero, BPF_ANY);
+}
+
+enum redir_mode {
+	REDIR_INGRESS,
+	REDIR_EGRESS,
+};
+
+static const char *redir_mode_str(enum redir_mode mode)
+{
+	switch (mode) {
+	case REDIR_INGRESS:
+		return "ingress";
+	case REDIR_EGRESS:
+		return "egress";
+	default:
+		return "unknown";
+	}
+}
+
+static void redir_to_connected(int family, int sotype, int sock_mapfd,
+			       int verd_mapfd, enum redir_mode mode)
+{
+	const char *log_prefix = redir_mode_str(mode);
+	struct sockaddr_storage addr;
+	int s, c0, c1, p0, p1;
+	unsigned int pass;
+	socklen_t len;
+	int err, n;
+	u64 value;
+	u32 key;
+	char b;
+
+	zero_verdict_count(verd_mapfd);
+
+	s = listen_loopback(family, sotype | SOCK_NONBLOCK);
+	if (s < 0)
+		return;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv;
+
+	c0 = xsocket(family, sotype, 0);
+	if (c0 < 0)
+		goto close_srv;
+	err = xconnect(c0, sockaddr(&addr), len);
+	if (err)
+		goto close_cli0;
+
+	p0 = xaccept(s, NULL, NULL);
+	if (p0 < 0)
+		goto close_cli0;
+
+	c1 = xsocket(family, sotype, 0);
+	if (c1 < 0)
+		goto close_peer0;
+	err = xconnect(c1, sockaddr(&addr), len);
+	if (err)
+		goto close_cli1;
+
+	p1 = xaccept(s, NULL, NULL);
+	if (p1 < 0)
+		goto close_cli1;
+
+	key = 0;
+	value = p0;
+	err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
+	if (err)
+		goto close_peer1;
+
+	key = 1;
+	value = p1;
+	err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
+	if (err)
+		goto close_peer1;
+
+	n = write(mode == REDIR_INGRESS ? c1 : p1, "a", 1);
+	if (n < 0)
+		FAIL_ERRNO("%s: write", log_prefix);
+	if (n == 0)
+		FAIL("%s: incomplete write", log_prefix);
+	if (n < 1)
+		goto close_peer1;
+
+	key = SK_PASS;
+	err = xbpf_map_lookup_elem(verd_mapfd, &key, &pass);
+	if (err)
+		goto close_peer1;
+	if (pass != 1)
+		FAIL("%s: want pass count 1, have %u", log_prefix, pass);
+
+	n = read(c0, &b, 1);
+	if (n < 0)
+		FAIL_ERRNO("%s: read", log_prefix);
+	if (n == 0)
+		FAIL("%s: incomplete read", log_prefix);
+
+close_peer1:
+	xclose(p1);
+close_cli1:
+	xclose(c1);
+close_peer0:
+	xclose(p0);
+close_cli0:
+	xclose(c0);
+close_srv:
+	xclose(s);
+}
+
+static void test_skb_redir_to_connected(struct test_sockmap_listen *skel,
+					struct bpf_map *inner_map, int family,
+					int sotype)
+{
+	int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
+	int parser = bpf_program__fd(skel->progs.prog_skb_parser);
+	int verdict_map = bpf_map__fd(skel->maps.verdict_map);
+	int sock_map = bpf_map__fd(inner_map);
+	int err;
+
+	err = xbpf_prog_attach(parser, sock_map, BPF_SK_SKB_STREAM_PARSER, 0);
+	if (err)
+		return;
+	err = xbpf_prog_attach(verdict, sock_map, BPF_SK_SKB_STREAM_VERDICT, 0);
+	if (err)
+		goto detach;
+
+	redir_to_connected(family, sotype, sock_map, verdict_map,
+			   REDIR_INGRESS);
+
+	xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_STREAM_VERDICT);
+detach:
+	xbpf_prog_detach2(parser, sock_map, BPF_SK_SKB_STREAM_PARSER);
+}
+
+static void test_msg_redir_to_connected(struct test_sockmap_listen *skel,
+					struct bpf_map *inner_map, int family,
+					int sotype)
+{
+	int verdict = bpf_program__fd(skel->progs.prog_msg_verdict);
+	int verdict_map = bpf_map__fd(skel->maps.verdict_map);
+	int sock_map = bpf_map__fd(inner_map);
+	int err;
+
+	err = xbpf_prog_attach(verdict, sock_map, BPF_SK_MSG_VERDICT, 0);
+	if (err)
+		return;
+
+	redir_to_connected(family, sotype, sock_map, verdict_map, REDIR_EGRESS);
+
+	xbpf_prog_detach2(verdict, sock_map, BPF_SK_MSG_VERDICT);
+}
+
+static void redir_to_listening(int family, int sotype, int sock_mapfd,
+			       int verd_mapfd, enum redir_mode mode)
+{
+	const char *log_prefix = redir_mode_str(mode);
+	struct sockaddr_storage addr;
+	int s, c, p, err, n;
+	unsigned int drop;
+	socklen_t len;
+	u64 value;
+	u32 key;
+
+	zero_verdict_count(verd_mapfd);
+
+	s = listen_loopback(family, sotype | SOCK_NONBLOCK);
+	if (s < 0)
+		return;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv;
+
+	c = xsocket(family, sotype, 0);
+	if (c < 0)
+		goto close_srv;
+	err = xconnect(c, sockaddr(&addr), len);
+	if (err)
+		goto close_cli;
+
+	p = xaccept(s, NULL, NULL);
+	if (p < 0)
+		goto close_cli;
+
+	key = 0;
+	value = s;
+	err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
+	if (err)
+		goto close_peer;
+
+	key = 1;
+	value = p;
+	err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
+	if (err)
+		goto close_peer;
+
+	n = write(mode == REDIR_INGRESS ? c : p, "a", 1);
+	if (n < 0 && errno != EACCES)
+		FAIL_ERRNO("%s: write", log_prefix);
+	if (n == 0)
+		FAIL("%s: incomplete write", log_prefix);
+	if (n < 1)
+		goto close_peer;
+
+	key = SK_DROP;
+	err = xbpf_map_lookup_elem(verd_mapfd, &key, &drop);
+	if (err)
+		goto close_peer;
+	if (drop != 1)
+		FAIL("%s: want drop count 1, have %u", log_prefix, drop);
+
+close_peer:
+	xclose(p);
+close_cli:
+	xclose(c);
+close_srv:
+	xclose(s);
+}
+
+static void test_skb_redir_to_listening(struct test_sockmap_listen *skel,
+					struct bpf_map *inner_map, int family,
+					int sotype)
+{
+	int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
+	int parser = bpf_program__fd(skel->progs.prog_skb_parser);
+	int verdict_map = bpf_map__fd(skel->maps.verdict_map);
+	int sock_map = bpf_map__fd(inner_map);
+	int err;
+
+	err = xbpf_prog_attach(parser, sock_map, BPF_SK_SKB_STREAM_PARSER, 0);
+	if (err)
+		return;
+	err = xbpf_prog_attach(verdict, sock_map, BPF_SK_SKB_STREAM_VERDICT, 0);
+	if (err)
+		goto detach;
+
+	redir_to_listening(family, sotype, sock_map, verdict_map,
+			   REDIR_INGRESS);
+
+	xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_STREAM_VERDICT);
+detach:
+	xbpf_prog_detach2(parser, sock_map, BPF_SK_SKB_STREAM_PARSER);
+}
+
+static void test_msg_redir_to_listening(struct test_sockmap_listen *skel,
+					struct bpf_map *inner_map, int family,
+					int sotype)
+{
+	int verdict = bpf_program__fd(skel->progs.prog_msg_verdict);
+	int verdict_map = bpf_map__fd(skel->maps.verdict_map);
+	int sock_map = bpf_map__fd(inner_map);
+	int err;
+
+	err = xbpf_prog_attach(verdict, sock_map, BPF_SK_MSG_VERDICT, 0);
+	if (err)
+		return;
+
+	redir_to_listening(family, sotype, sock_map, verdict_map, REDIR_EGRESS);
+
+	xbpf_prog_detach2(verdict, sock_map, BPF_SK_MSG_VERDICT);
+}
+
+static void test_reuseport_select_listening(int family, int sotype,
+					    int sock_map, int verd_map,
+					    int reuseport_prog)
+{
+	struct sockaddr_storage addr;
+	unsigned int pass;
+	int s, c, p, err;
+	socklen_t len;
+	u64 value;
+	u32 key;
+
+	zero_verdict_count(verd_map);
+
+	s = listen_loopback_reuseport(family, sotype, reuseport_prog);
+	if (s < 0)
+		return;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv;
+
+	key = 0;
+	value = s;
+	err = xbpf_map_update_elem(sock_map, &key, &value, BPF_NOEXIST);
+	if (err)
+		goto close_srv;
+
+	c = xsocket(family, sotype, 0);
+	if (c < 0)
+		goto close_srv;
+	err = xconnect(c, sockaddr(&addr), len);
+	if (err)
+		goto close_cli;
+
+	p = xaccept(s, NULL, NULL);
+	if (p < 0)
+		goto close_cli;
+
+	key = SK_PASS;
+	err = xbpf_map_lookup_elem(verd_map, &key, &pass);
+	if (err)
+		goto close_peer;
+	if (pass != 1)
+		FAIL("want pass count 1, have %u", pass);
+
+close_peer:
+	xclose(p);
+close_cli:
+	xclose(c);
+close_srv:
+	xclose(s);
+}
+
+static void test_reuseport_select_connected(int family, int sotype,
+					    int sock_map, int verd_map,
+					    int reuseport_prog)
+{
+	struct sockaddr_storage addr;
+	int s, c0, c1, p0, err;
+	unsigned int drop;
+	socklen_t len;
+	u64 value;
+	u32 key;
+
+	zero_verdict_count(verd_map);
+
+	s = listen_loopback_reuseport(family, sotype, reuseport_prog);
+	if (s < 0)
+		return;
+
+	/* Populate sock_map[0] to avoid ENOENT on first connection */
+	key = 0;
+	value = s;
+	err = xbpf_map_update_elem(sock_map, &key, &value, BPF_NOEXIST);
+	if (err)
+		goto close_srv;
+
+	len = sizeof(addr);
+	err = xgetsockname(s, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv;
+
+	c0 = xsocket(family, sotype, 0);
+	if (c0 < 0)
+		goto close_srv;
+
+	err = xconnect(c0, sockaddr(&addr), len);
+	if (err)
+		goto close_cli0;
+
+	p0 = xaccept(s, NULL, NULL);
+	if (p0 < 0)
+		goto close_cli0;
+
+	/* Update sock_map[0] to redirect to a connected socket */
+	key = 0;
+	value = p0;
+	err = xbpf_map_update_elem(sock_map, &key, &value, BPF_EXIST);
+	if (err)
+		goto close_peer0;
+
+	c1 = xsocket(family, sotype, 0);
+	if (c1 < 0)
+		goto close_peer0;
+
+	errno = 0;
+	err = connect(c1, sockaddr(&addr), len);
+	if (!err || errno != ECONNREFUSED)
+		FAIL_ERRNO("connect: expected ECONNREFUSED");
+
+	key = SK_DROP;
+	err = xbpf_map_lookup_elem(verd_map, &key, &drop);
+	if (err)
+		goto close_cli1;
+	if (drop != 1)
+		FAIL("want drop count 1, have %u", drop);
+
+close_cli1:
+	xclose(c1);
+close_peer0:
+	xclose(p0);
+close_cli0:
+	xclose(c0);
+close_srv:
+	xclose(s);
+}
+
+/* Check that redirecting across reuseport groups is not allowed. */
+static void test_reuseport_mixed_groups(int family, int sotype, int sock_map,
+					int verd_map, int reuseport_prog)
+{
+	struct sockaddr_storage addr;
+	int s1, s2, c, err;
+	unsigned int drop;
+	socklen_t len;
+	u64 value;
+	u32 key;
+
+	zero_verdict_count(verd_map);
+
+	/* Create two listeners, each in its own reuseport group */
+	s1 = listen_loopback_reuseport(family, sotype, reuseport_prog);
+	if (s1 < 0)
+		return;
+
+	s2 = listen_loopback_reuseport(family, sotype, reuseport_prog);
+	if (s2 < 0)
+		goto close_srv1;
+
+	key = 0;
+	value = s1;
+	err = xbpf_map_update_elem(sock_map, &key, &value, BPF_NOEXIST);
+	if (err)
+		goto close_srv2;
+
+	key = 1;
+	value = s2;
+	err = xbpf_map_update_elem(sock_map, &key, &value, BPF_NOEXIST);
+
+	/* Connect to s2, reuseport BPF selects s1 via sock_map[0] */
+	len = sizeof(addr);
+	err = xgetsockname(s2, sockaddr(&addr), &len);
+	if (err)
+		goto close_srv2;
+
+	c = xsocket(family, sotype, 0);
+	if (c < 0)
+		goto close_srv2;
+
+	err = connect(c, sockaddr(&addr), len);
+	if (!err || errno != ECONNREFUSED) {
+		FAIL_ERRNO("connect: expected ECONNREFUSED");
+		goto close_cli;
+	}
+
+	/* Expect drop, can't redirect outside of reuseport group */
+	key = SK_DROP;
+	err = xbpf_map_lookup_elem(verd_map, &key, &drop);
+	if (err)
+		goto close_cli;
+	if (drop != 1)
+		FAIL("want drop count 1, have %u", drop);
+
+close_cli:
+	xclose(c);
+close_srv2:
+	xclose(s2);
+close_srv1:
+	xclose(s1);
+}
+
+#define TEST(fn)                                                               \
+	{                                                                      \
+		fn, #fn                                                        \
+	}
+
+static void test_ops_cleanup(const struct bpf_map *map)
+{
+	const struct bpf_map_def *def;
+	int err, mapfd;
+	u32 key;
+
+	def = bpf_map__def(map);
+	mapfd = bpf_map__fd(map);
+
+	for (key = 0; key < def->max_entries; key++) {
+		err = bpf_map_delete_elem(mapfd, &key);
+		if (err && errno != EINVAL && errno != ENOENT)
+			FAIL_ERRNO("map_delete: expected EINVAL/ENOENT");
+	}
+}
+
+static const char *family_str(sa_family_t family)
+{
+	switch (family) {
+	case AF_INET:
+		return "IPv4";
+	case AF_INET6:
+		return "IPv6";
+	default:
+		return "unknown";
+	}
+}
+
+static const char *map_type_str(const struct bpf_map *map)
+{
+	const struct bpf_map_def *def;
+
+	def = bpf_map__def(map);
+	if (IS_ERR(def))
+		return "invalid";
+
+	switch (def->type) {
+	case BPF_MAP_TYPE_SOCKMAP:
+		return "sockmap";
+	case BPF_MAP_TYPE_SOCKHASH:
+		return "sockhash";
+	default:
+		return "unknown";
+	}
+}
+
+static void test_ops(struct test_sockmap_listen *skel, struct bpf_map *map,
+		     int family, int sotype)
+{
+	const struct op_test {
+		void (*fn)(int family, int sotype, int mapfd);
+		const char *name;
+	} tests[] = {
+		/* insert */
+		TEST(test_insert_invalid),
+		TEST(test_insert_opened),
+		TEST(test_insert_bound),
+		TEST(test_insert_listening),
+		/* delete */
+		TEST(test_delete_after_insert),
+		TEST(test_delete_after_close),
+		/* lookup */
+		TEST(test_lookup_after_insert),
+		TEST(test_lookup_after_delete),
+		TEST(test_lookup_32_bit_value),
+		/* update */
+		TEST(test_update_listening),
+		/* races with insert/delete */
+		TEST(test_destroy_orphan_child),
+		TEST(test_syn_recv_insert_delete),
+		TEST(test_race_insert_listen),
+		/* child clone */
+		TEST(test_clone_after_delete),
+		TEST(test_accept_after_delete),
+		TEST(test_accept_before_delete),
+	};
+	const char *family_name, *map_name;
+	const struct op_test *t;
+	char s[MAX_TEST_NAME];
+	int map_fd;
+
+	family_name = family_str(family);
+	map_name = map_type_str(map);
+	map_fd = bpf_map__fd(map);
+
+	for (t = tests; t < tests + ARRAY_SIZE(tests); t++) {
+		snprintf(s, sizeof(s), "%s %s %s", map_name, family_name,
+			 t->name);
+
+		if (!test__start_subtest(s))
+			continue;
+
+		t->fn(family, sotype, map_fd);
+		test_ops_cleanup(map);
+	}
+}
+
+static void test_redir(struct test_sockmap_listen *skel, struct bpf_map *map,
+		       int family, int sotype)
+{
+	const struct redir_test {
+		void (*fn)(struct test_sockmap_listen *skel,
+			   struct bpf_map *map, int family, int sotype);
+		const char *name;
+	} tests[] = {
+		TEST(test_skb_redir_to_connected),
+		TEST(test_skb_redir_to_listening),
+		TEST(test_msg_redir_to_connected),
+		TEST(test_msg_redir_to_listening),
+	};
+	const char *family_name, *map_name;
+	const struct redir_test *t;
+	char s[MAX_TEST_NAME];
+
+	family_name = family_str(family);
+	map_name = map_type_str(map);
+
+	for (t = tests; t < tests + ARRAY_SIZE(tests); t++) {
+		snprintf(s, sizeof(s), "%s %s %s", map_name, family_name,
+			 t->name);
+		if (!test__start_subtest(s))
+			continue;
+
+		t->fn(skel, map, family, sotype);
+	}
+}
+
+static void test_reuseport(struct test_sockmap_listen *skel,
+			   struct bpf_map *map, int family, int sotype)
+{
+	const struct reuseport_test {
+		void (*fn)(int family, int sotype, int socket_map,
+			   int verdict_map, int reuseport_prog);
+		const char *name;
+	} tests[] = {
+		TEST(test_reuseport_select_listening),
+		TEST(test_reuseport_select_connected),
+		TEST(test_reuseport_mixed_groups),
+	};
+	int socket_map, verdict_map, reuseport_prog;
+	const char *family_name, *map_name;
+	const struct reuseport_test *t;
+	char s[MAX_TEST_NAME];
+
+	family_name = family_str(family);
+	map_name = map_type_str(map);
+
+	socket_map = bpf_map__fd(map);
+	verdict_map = bpf_map__fd(skel->maps.verdict_map);
+	reuseport_prog = bpf_program__fd(skel->progs.prog_reuseport);
+
+	for (t = tests; t < tests + ARRAY_SIZE(tests); t++) {
+		snprintf(s, sizeof(s), "%s %s %s", map_name, family_name,
+			 t->name);
+
+		if (!test__start_subtest(s))
+			continue;
+
+		t->fn(family, sotype, socket_map, verdict_map, reuseport_prog);
+	}
+}
+
+static void run_tests(struct test_sockmap_listen *skel, struct bpf_map *map,
+		      int family)
+{
+	test_ops(skel, map, family, SOCK_STREAM);
+	test_redir(skel, map, family, SOCK_STREAM);
+	test_reuseport(skel, map, family, SOCK_STREAM);
+}
+
+void test_sockmap_listen(void)
+{
+	struct test_sockmap_listen *skel;
+
+	skel = test_sockmap_listen__open_and_load();
+	if (!skel) {
+		FAIL("skeleton open/load failed");
+		return;
+	}
+
+	skel->bss->test_sockmap = true;
+	run_tests(skel, skel->maps.sock_map, AF_INET);
+	run_tests(skel, skel->maps.sock_map, AF_INET6);
+
+	skel->bss->test_sockmap = false;
+	run_tests(skel, skel->maps.sock_hash, AF_INET);
+	run_tests(skel, skel->maps.sock_hash, AF_INET6);
+
+	test_sockmap_listen__destroy(skel);
+}
diff --git a/tools/testing/selftests/bpf/progs/test_sockmap_listen.c b/tools/testing/selftests/bpf/progs/test_sockmap_listen.c
new file mode 100644
index 0000000..a3a366c
--- /dev/null
+++ b/tools/testing/selftests/bpf/progs/test_sockmap_listen.c
@@ -0,0 +1,98 @@
+// SPDX-License-Identifier: GPL-2.0
+// Copyright (c) 2020 Cloudflare
+
+#include <errno.h>
+#include <stdbool.h>
+#include <linux/bpf.h>
+
+#include <bpf/bpf_helpers.h>
+
+struct {
+	__uint(type, BPF_MAP_TYPE_SOCKMAP);
+	__uint(max_entries, 2);
+	__type(key, __u32);
+	__type(value, __u64);
+} sock_map SEC(".maps");
+
+struct {
+	__uint(type, BPF_MAP_TYPE_SOCKHASH);
+	__uint(max_entries, 2);
+	__type(key, __u32);
+	__type(value, __u64);
+} sock_hash SEC(".maps");
+
+struct {
+	__uint(type, BPF_MAP_TYPE_ARRAY);
+	__uint(max_entries, 2);
+	__type(key, int);
+	__type(value, unsigned int);
+} verdict_map SEC(".maps");
+
+static volatile bool test_sockmap; /* toggled by user-space */
+
+SEC("sk_skb/stream_parser")
+int prog_skb_parser(struct __sk_buff *skb)
+{
+	return skb->len;
+}
+
+SEC("sk_skb/stream_verdict")
+int prog_skb_verdict(struct __sk_buff *skb)
+{
+	unsigned int *count;
+	__u32 zero = 0;
+	int verdict;
+
+	if (test_sockmap)
+		verdict = bpf_sk_redirect_map(skb, &sock_map, zero, 0);
+	else
+		verdict = bpf_sk_redirect_hash(skb, &sock_hash, &zero, 0);
+
+	count = bpf_map_lookup_elem(&verdict_map, &verdict);
+	if (count)
+		(*count)++;
+
+	return verdict;
+}
+
+SEC("sk_msg")
+int prog_msg_verdict(struct sk_msg_md *msg)
+{
+	unsigned int *count;
+	__u32 zero = 0;
+	int verdict;
+
+	if (test_sockmap)
+		verdict = bpf_msg_redirect_map(msg, &sock_map, zero, 0);
+	else
+		verdict = bpf_msg_redirect_hash(msg, &sock_hash, &zero, 0);
+
+	count = bpf_map_lookup_elem(&verdict_map, &verdict);
+	if (count)
+		(*count)++;
+
+	return verdict;
+}
+
+SEC("sk_reuseport")
+int prog_reuseport(struct sk_reuseport_md *reuse)
+{
+	unsigned int *count;
+	int err, verdict;
+	__u32 zero = 0;
+
+	if (test_sockmap)
+		err = bpf_sk_select_reuseport(reuse, &sock_map, &zero, 0);
+	else
+		err = bpf_sk_select_reuseport(reuse, &sock_hash, &zero, 0);
+	verdict = err ? SK_DROP : SK_PASS;
+
+	count = bpf_map_lookup_elem(&verdict_map, &verdict);
+	if (count)
+		(*count)++;
+
+	return verdict;
+}
+
+int _version SEC("version") = 1;
+char _license[] SEC("license") = "GPL";
diff --git a/tools/testing/selftests/bpf/test_maps.c b/tools/testing/selftests/bpf/test_maps.c
index 02eae1e..c6766b2 100644
--- a/tools/testing/selftests/bpf/test_maps.c
+++ b/tools/testing/selftests/bpf/test_maps.c
@@ -756,11 +756,7 @@ static void test_sockmap(unsigned int tasks, void *data)
 	/* Test update without programs */
 	for (i = 0; i < 6; i++) {
 		err = bpf_map_update_elem(fd, &i, &sfd[i], BPF_ANY);
-		if (i < 2 && !err) {
-			printf("Allowed update sockmap '%i:%i' not in ESTABLISHED\n",
-			       i, sfd[i]);
-			goto out_sockmap;
-		} else if (i >= 2 && err) {
+		if (err) {
 			printf("Failed noprog update sockmap '%i:%i'\n",
 			       i, sfd[i]);
 			goto out_sockmap;