xfrm: add support for UDPv6 encapsulation of ESP

This patch adds support for encapsulation of ESP over UDPv6. The code
is very similar to the IPv4 encapsulation implementation, and makes it
easy to add espintcp on IPv6 as a follow-up.

Signed-off-by: Sabrina Dubroca <sd@queasysnail.net>
Signed-off-by: Steffen Klassert <steffen.klassert@secunet.com>
diff --git a/net/ipv6/esp6.c b/net/ipv6/esp6.c
index 11143d0..e880096 100644
--- a/net/ipv6/esp6.c
+++ b/net/ipv6/esp6.c
@@ -26,10 +26,12 @@
 #include <linux/random.h>
 #include <linux/slab.h>
 #include <linux/spinlock.h>
+#include <net/ip6_checksum.h>
 #include <net/ip6_route.h>
 #include <net/icmp.h>
 #include <net/ipv6.h>
 #include <net/protocol.h>
+#include <net/udp.h>
 #include <linux/icmpv6.h>
 
 #include <linux/highmem.h>
@@ -39,6 +41,11 @@ struct esp_skb_cb {
 	void *tmp;
 };
 
+struct esp_output_extra {
+	__be32 seqhi;
+	u32 esphoff;
+};
+
 #define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0]))
 
 /*
@@ -72,9 +79,9 @@ static void *esp_alloc_tmp(struct crypto_aead *aead, int nfrags, int seqihlen)
 	return kmalloc(len, GFP_ATOMIC);
 }
 
-static inline __be32 *esp_tmp_seqhi(void *tmp)
+static inline void *esp_tmp_extra(void *tmp)
 {
-	return PTR_ALIGN((__be32 *)tmp, __alignof__(__be32));
+	return PTR_ALIGN(tmp, __alignof__(struct esp_output_extra));
 }
 
 static inline u8 *esp_tmp_iv(struct crypto_aead *aead, void *tmp, int seqhilen)
@@ -104,16 +111,17 @@ static inline struct scatterlist *esp_req_sg(struct crypto_aead *aead,
 
 static void esp_ssg_unref(struct xfrm_state *x, void *tmp)
 {
+	struct esp_output_extra *extra = esp_tmp_extra(tmp);
 	struct crypto_aead *aead = x->data;
-	int seqhilen = 0;
+	int extralen = 0;
 	u8 *iv;
 	struct aead_request *req;
 	struct scatterlist *sg;
 
 	if (x->props.flags & XFRM_STATE_ESN)
-		seqhilen += sizeof(__be32);
+		extralen += sizeof(*extra);
 
-	iv = esp_tmp_iv(aead, tmp, seqhilen);
+	iv = esp_tmp_iv(aead, tmp, extralen);
 	req = esp_tmp_req(aead, iv);
 
 	/* Unref skb_frag_pages in the src scatterlist if necessary.
@@ -124,6 +132,23 @@ static void esp_ssg_unref(struct xfrm_state *x, void *tmp)
 			put_page(sg_page(sg));
 }
 
+static void esp_output_encap_csum(struct sk_buff *skb)
+{
+	/* UDP encap with IPv6 requires a valid checksum */
+	if (*skb_mac_header(skb) == IPPROTO_UDP) {
+		struct udphdr *uh = udp_hdr(skb);
+		struct ipv6hdr *ip6h = ipv6_hdr(skb);
+		int len = ntohs(uh->len);
+		unsigned int offset = skb_transport_offset(skb);
+		__wsum csum = skb_checksum(skb, offset, skb->len - offset, 0);
+
+		uh->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
+					    len, IPPROTO_UDP, csum);
+		if (uh->check == 0)
+			uh->check = CSUM_MANGLED_0;
+	}
+}
+
 static void esp_output_done(struct crypto_async_request *base, int err)
 {
 	struct sk_buff *skb = base->data;
@@ -143,6 +168,8 @@ static void esp_output_done(struct crypto_async_request *base, int err)
 	esp_ssg_unref(x, tmp);
 	kfree(tmp);
 
+	esp_output_encap_csum(skb);
+
 	if (xo && (xo->flags & XFRM_DEV_RESUME)) {
 		if (err) {
 			XFRM_INC_STATS(xs_net(x), LINUX_MIB_XFRMOUTSTATEPROTOERROR);
@@ -163,7 +190,7 @@ static void esp_restore_header(struct sk_buff *skb, unsigned int offset)
 {
 	struct ip_esp_hdr *esph = (void *)(skb->data + offset);
 	void *tmp = ESP_SKB_CB(skb)->tmp;
-	__be32 *seqhi = esp_tmp_seqhi(tmp);
+	__be32 *seqhi = esp_tmp_extra(tmp);
 
 	esph->seq_no = esph->spi;
 	esph->spi = *seqhi;
@@ -171,27 +198,36 @@ static void esp_restore_header(struct sk_buff *skb, unsigned int offset)
 
 static void esp_output_restore_header(struct sk_buff *skb)
 {
-	esp_restore_header(skb, skb_transport_offset(skb) - sizeof(__be32));
+	void *tmp = ESP_SKB_CB(skb)->tmp;
+	struct esp_output_extra *extra = esp_tmp_extra(tmp);
+
+	esp_restore_header(skb, skb_transport_offset(skb) + extra->esphoff -
+				sizeof(__be32));
 }
 
 static struct ip_esp_hdr *esp_output_set_esn(struct sk_buff *skb,
 					     struct xfrm_state *x,
 					     struct ip_esp_hdr *esph,
-					     __be32 *seqhi)
+					     struct esp_output_extra *extra)
 {
 	/* For ESN we move the header forward by 4 bytes to
 	 * accomodate the high bits.  We will move it back after
 	 * encryption.
 	 */
 	if ((x->props.flags & XFRM_STATE_ESN)) {
+		__u32 seqhi;
 		struct xfrm_offload *xo = xfrm_offload(skb);
 
-		esph = (void *)(skb_transport_header(skb) - sizeof(__be32));
-		*seqhi = esph->spi;
 		if (xo)
-			esph->seq_no = htonl(xo->seq.hi);
+			seqhi = xo->seq.hi;
 		else
-			esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.hi);
+			seqhi = XFRM_SKB_CB(skb)->seq.output.hi;
+
+		extra->esphoff = (unsigned char *)esph -
+				 skb_transport_header(skb);
+		esph = (struct ip_esp_hdr *)((unsigned char *)esph - 4);
+		extra->seqhi = esph->spi;
+		esph->seq_no = htonl(seqhi);
 	}
 
 	esph->spi = x->id.spi;
@@ -207,15 +243,84 @@ static void esp_output_done_esn(struct crypto_async_request *base, int err)
 	esp_output_done(base, err);
 }
 
+static struct ip_esp_hdr *esp6_output_udp_encap(struct sk_buff *skb,
+					       int encap_type,
+					       struct esp_info *esp,
+					       __be16 sport,
+					       __be16 dport)
+{
+	struct udphdr *uh;
+	__be32 *udpdata32;
+	unsigned int len;
+
+	len = skb->len + esp->tailen - skb_transport_offset(skb);
+	if (len > U16_MAX)
+		return ERR_PTR(-EMSGSIZE);
+
+	uh = (struct udphdr *)esp->esph;
+	uh->source = sport;
+	uh->dest = dport;
+	uh->len = htons(len);
+	uh->check = 0;
+
+	*skb_mac_header(skb) = IPPROTO_UDP;
+
+	if (encap_type == UDP_ENCAP_ESPINUDP_NON_IKE) {
+		udpdata32 = (__be32 *)(uh + 1);
+		udpdata32[0] = udpdata32[1] = 0;
+		return (struct ip_esp_hdr *)(udpdata32 + 2);
+	}
+
+	return (struct ip_esp_hdr *)(uh + 1);
+}
+
+static int esp6_output_encap(struct xfrm_state *x, struct sk_buff *skb,
+			    struct esp_info *esp)
+{
+	struct xfrm_encap_tmpl *encap = x->encap;
+	struct ip_esp_hdr *esph;
+	__be16 sport, dport;
+	int encap_type;
+
+	spin_lock_bh(&x->lock);
+	sport = encap->encap_sport;
+	dport = encap->encap_dport;
+	encap_type = encap->encap_type;
+	spin_unlock_bh(&x->lock);
+
+	switch (encap_type) {
+	default:
+	case UDP_ENCAP_ESPINUDP:
+	case UDP_ENCAP_ESPINUDP_NON_IKE:
+		esph = esp6_output_udp_encap(skb, encap_type, esp, sport, dport);
+		break;
+	}
+
+	if (IS_ERR(esph))
+		return PTR_ERR(esph);
+
+	esp->esph = esph;
+
+	return 0;
+}
+
 int esp6_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
 {
 	u8 *tail;
 	u8 *vaddr;
 	int nfrags;
+	int esph_offset;
 	struct page *page;
 	struct sk_buff *trailer;
 	int tailen = esp->tailen;
 
+	if (x->encap) {
+		int err = esp6_output_encap(x, skb, esp);
+
+		if (err < 0)
+			return err;
+	}
+
 	if (!skb_cloned(skb)) {
 		if (tailen <= skb_tailroom(skb)) {
 			nfrags = 1;
@@ -274,10 +379,13 @@ int esp6_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info
 	}
 
 cow:
+	esph_offset = (unsigned char *)esp->esph - skb_transport_header(skb);
+
 	nfrags = skb_cow_data(skb, tailen, &trailer);
 	if (nfrags < 0)
 		goto out;
 	tail = skb_tail_pointer(trailer);
+	esp->esph = (struct ip_esp_hdr *)(skb_transport_header(skb) + esph_offset);
 
 skip_cow:
 	esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto);
@@ -295,20 +403,20 @@ int esp6_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info
 	void *tmp;
 	int ivlen;
 	int assoclen;
-	int seqhilen;
-	__be32 *seqhi;
+	int extralen;
 	struct page *page;
 	struct ip_esp_hdr *esph;
 	struct aead_request *req;
 	struct crypto_aead *aead;
 	struct scatterlist *sg, *dsg;
+	struct esp_output_extra *extra;
 	int err = -ENOMEM;
 
 	assoclen = sizeof(struct ip_esp_hdr);
-	seqhilen = 0;
+	extralen = 0;
 
 	if (x->props.flags & XFRM_STATE_ESN) {
-		seqhilen += sizeof(__be32);
+		extralen += sizeof(*extra);
 		assoclen += sizeof(__be32);
 	}
 
@@ -316,12 +424,12 @@ int esp6_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info
 	alen = crypto_aead_authsize(aead);
 	ivlen = crypto_aead_ivsize(aead);
 
-	tmp = esp_alloc_tmp(aead, esp->nfrags + 2, seqhilen);
+	tmp = esp_alloc_tmp(aead, esp->nfrags + 2, extralen);
 	if (!tmp)
 		goto error;
 
-	seqhi = esp_tmp_seqhi(tmp);
-	iv = esp_tmp_iv(aead, tmp, seqhilen);
+	extra = esp_tmp_extra(tmp);
+	iv = esp_tmp_iv(aead, tmp, extralen);
 	req = esp_tmp_req(aead, iv);
 	sg = esp_req_sg(aead, req);
 
@@ -330,7 +438,8 @@ int esp6_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info
 	else
 		dsg = &sg[esp->nfrags];
 
-	esph = esp_output_set_esn(skb, x, ip_esp_hdr(skb), seqhi);
+	esph = esp_output_set_esn(skb, x, esp->esph, extra);
+	esp->esph = esph;
 
 	sg_init_table(sg, esp->nfrags);
 	err = skb_to_sgvec(skb, sg,
@@ -394,6 +503,7 @@ int esp6_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info
 	case 0:
 		if ((x->props.flags & XFRM_STATE_ESN))
 			esp_output_restore_header(skb);
+		esp_output_encap_csum(skb);
 	}
 
 	if (sg != dsg)
@@ -438,11 +548,13 @@ static int esp6_output(struct xfrm_state *x, struct sk_buff *skb)
 	esp.plen = esp.clen - skb->len - esp.tfclen;
 	esp.tailen = esp.tfclen + esp.plen + alen;
 
+	esp.esph = ip_esp_hdr(skb);
+
 	esp.nfrags = esp6_output_head(x, skb, &esp);
 	if (esp.nfrags < 0)
 		return esp.nfrags;
 
-	esph = ip_esp_hdr(skb);
+	esph = esp.esph;
 	esph->spi = x->id.spi;
 
 	esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low);
@@ -517,6 +629,56 @@ int esp6_input_done2(struct sk_buff *skb, int err)
 	if (unlikely(err < 0))
 		goto out;
 
+	if (x->encap) {
+		const struct ipv6hdr *ip6h = ipv6_hdr(skb);
+		struct xfrm_encap_tmpl *encap = x->encap;
+		struct udphdr *uh = (void *)(skb_network_header(skb) + hdr_len);
+		__be16 source;
+
+		switch (x->encap->encap_type) {
+		case UDP_ENCAP_ESPINUDP:
+		case UDP_ENCAP_ESPINUDP_NON_IKE:
+			source = uh->source;
+			break;
+		default:
+			WARN_ON_ONCE(1);
+			err = -EINVAL;
+			goto out;
+		}
+
+		/*
+		 * 1) if the NAT-T peer's IP or port changed then
+		 *    advertize the change to the keying daemon.
+		 *    This is an inbound SA, so just compare
+		 *    SRC ports.
+		 */
+		if (!ipv6_addr_equal(&ip6h->saddr, &x->props.saddr.in6) ||
+		    source != encap->encap_sport) {
+			xfrm_address_t ipaddr;
+
+			memcpy(&ipaddr.a6, &ip6h->saddr.s6_addr, sizeof(ipaddr.a6));
+			km_new_mapping(x, &ipaddr, source);
+
+			/* XXX: perhaps add an extra
+			 * policy check here, to see
+			 * if we should allow or
+			 * reject a packet from a
+			 * different source
+			 * address/port.
+			 */
+		}
+
+		/*
+		 * 2) ignore UDP/TCP checksums in case
+		 *    of NAT-T in Transport Mode, or
+		 *    perform other post-processing fixes
+		 *    as per draft-ietf-ipsec-udp-encaps-06,
+		 *    section 3.1.2
+		 */
+		if (x->props.mode == XFRM_MODE_TRANSPORT)
+			skb->ip_summed = CHECKSUM_UNNECESSARY;
+	}
+
 	skb_postpull_rcsum(skb, skb_network_header(skb),
 			   skb_network_header_len(skb));
 	skb_pull_rcsum(skb, hlen);
@@ -632,7 +794,7 @@ static int esp6_input(struct xfrm_state *x, struct sk_buff *skb)
 		goto out;
 
 	ESP_SKB_CB(skb)->tmp = tmp;
-	seqhi = esp_tmp_seqhi(tmp);
+	seqhi = esp_tmp_extra(tmp);
 	iv = esp_tmp_iv(aead, tmp, seqhilen);
 	req = esp_tmp_req(aead, iv);
 	sg = esp_req_sg(aead, req);
@@ -836,9 +998,6 @@ static int esp6_init_state(struct xfrm_state *x)
 	u32 align;
 	int err;
 
-	if (x->encap)
-		return -EINVAL;
-
 	x->data = NULL;
 
 	if (x->aead)
@@ -867,6 +1026,22 @@ static int esp6_init_state(struct xfrm_state *x)
 		break;
 	}
 
+	if (x->encap) {
+		struct xfrm_encap_tmpl *encap = x->encap;
+
+		switch (encap->encap_type) {
+		default:
+			err = -EINVAL;
+			goto error;
+		case UDP_ENCAP_ESPINUDP:
+			x->props.header_len += sizeof(struct udphdr);
+			break;
+		case UDP_ENCAP_ESPINUDP_NON_IKE:
+			x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32);
+			break;
+		}
+	}
+
 	align = ALIGN(crypto_aead_blocksize(aead), 4);
 	x->props.trailer_len = align + 1 + crypto_aead_authsize(aead);
 
@@ -893,6 +1068,7 @@ static const struct xfrm_type esp6_type = {
 
 static struct xfrm6_protocol esp6_protocol = {
 	.handler	=	xfrm6_rcv,
+	.input_handler	=	xfrm_input,
 	.cb_handler	=	esp6_rcv_cb,
 	.err_handler	=	esp6_err,
 	.priority	=	0,