net: tls: Support 256 bit keys

Wire up support for 256 bit keys from the setsockopt to the crypto
framework

Signed-off-by: Dave Watson <davejwatson@fb.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/tls.h b/include/net/tls.h
index 4592606..da616db 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -206,7 +206,10 @@ struct cipher_context {
 
 union tls_crypto_context {
 	struct tls_crypto_info info;
-	struct tls12_crypto_info_aes_gcm_128 aes_gcm_128;
+	union {
+		struct tls12_crypto_info_aes_gcm_128 aes_gcm_128;
+		struct tls12_crypto_info_aes_gcm_256 aes_gcm_256;
+	};
 };
 
 struct tls_context {
diff --git a/include/uapi/linux/tls.h b/include/uapi/linux/tls.h
index ff02287..9affcea 100644
--- a/include/uapi/linux/tls.h
+++ b/include/uapi/linux/tls.h
@@ -59,6 +59,13 @@
 #define TLS_CIPHER_AES_GCM_128_TAG_SIZE		16
 #define TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE		8
 
+#define TLS_CIPHER_AES_GCM_256				52
+#define TLS_CIPHER_AES_GCM_256_IV_SIZE			8
+#define TLS_CIPHER_AES_GCM_256_KEY_SIZE		32
+#define TLS_CIPHER_AES_GCM_256_SALT_SIZE		4
+#define TLS_CIPHER_AES_GCM_256_TAG_SIZE		16
+#define TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE		8
+
 #define TLS_SET_RECORD_TYPE	1
 #define TLS_GET_RECORD_TYPE	2
 
@@ -75,4 +82,12 @@ struct tls12_crypto_info_aes_gcm_128 {
 	unsigned char rec_seq[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE];
 };
 
+struct tls12_crypto_info_aes_gcm_256 {
+	struct tls_crypto_info info;
+	unsigned char iv[TLS_CIPHER_AES_GCM_256_IV_SIZE];
+	unsigned char key[TLS_CIPHER_AES_GCM_256_KEY_SIZE];
+	unsigned char salt[TLS_CIPHER_AES_GCM_256_SALT_SIZE];
+	unsigned char rec_seq[TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE];
+};
+
 #endif /* _UAPI_LINUX_TLS_H */
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index d36d095..0f028cf 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -372,6 +372,30 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
 			rc = -EFAULT;
 		break;
 	}
+	case TLS_CIPHER_AES_GCM_256: {
+		struct tls12_crypto_info_aes_gcm_256 *
+		  crypto_info_aes_gcm_256 =
+		  container_of(crypto_info,
+			       struct tls12_crypto_info_aes_gcm_256,
+			       info);
+
+		if (len != sizeof(*crypto_info_aes_gcm_256)) {
+			rc = -EINVAL;
+			goto out;
+		}
+		lock_sock(sk);
+		memcpy(crypto_info_aes_gcm_256->iv,
+		       ctx->tx.iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
+		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
+		memcpy(crypto_info_aes_gcm_256->rec_seq, ctx->tx.rec_seq,
+		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
+		release_sock(sk);
+		if (copy_to_user(optval,
+				 crypto_info_aes_gcm_256,
+				 sizeof(*crypto_info_aes_gcm_256)))
+			rc = -EFAULT;
+		break;
+	}
 	default:
 		rc = -EINVAL;
 	}
@@ -412,6 +436,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 {
 	struct tls_crypto_info *crypto_info;
 	struct tls_context *ctx = tls_get_ctx(sk);
+	size_t optsize;
 	int rc = 0;
 	int conf;
 
@@ -444,8 +469,12 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 	}
 
 	switch (crypto_info->cipher_type) {
-	case TLS_CIPHER_AES_GCM_128: {
-		if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) {
+	case TLS_CIPHER_AES_GCM_128:
+	case TLS_CIPHER_AES_GCM_256: {
+		optsize = crypto_info->cipher_type == TLS_CIPHER_AES_GCM_128 ?
+			sizeof(struct tls12_crypto_info_aes_gcm_128) :
+			sizeof(struct tls12_crypto_info_aes_gcm_256);
+		if (optlen != optsize) {
 			rc = -EINVAL;
 			goto err_crypto_info;
 		}
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 3f2a6af..9326c06 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -1999,6 +1999,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
 	struct tls_crypto_info *crypto_info;
 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
+	struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
 	struct tls_sw_context_tx *sw_ctx_tx = NULL;
 	struct tls_sw_context_rx *sw_ctx_rx = NULL;
 	struct cipher_context *cctx;
@@ -2006,7 +2007,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	struct strp_callbacks cb;
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
 	struct crypto_tfm *tfm;
-	char *iv, *rec_seq;
+	char *iv, *rec_seq, *key, *salt;
+	size_t keysize;
 	int rc = 0;
 
 	if (!ctx) {
@@ -2067,6 +2069,24 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
 		gcm_128_info =
 			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
+		key = gcm_128_info->key;
+		salt = gcm_128_info->salt;
+		break;
+	}
+	case TLS_CIPHER_AES_GCM_256: {
+		nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
+		iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
+		rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
+		rec_seq =
+		 ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
+		gcm_256_info =
+			(struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
+		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
+		key = gcm_256_info->key;
+		salt = gcm_256_info->salt;
 		break;
 	}
 	default:
@@ -2090,7 +2110,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		rc = -ENOMEM;
 		goto free_priv;
 	}
-	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	/* Note: 128 & 256 bit salt are the same size */
+	memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
 	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
 	cctx->rec_seq_size = rec_seq_size;
 	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
@@ -2110,8 +2131,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 
 	ctx->push_pending_record = tls_sw_push_pending_record;
 
-	rc = crypto_aead_setkey(*aead, gcm_128_info->key,
-				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+	rc = crypto_aead_setkey(*aead, key, keysize);
+
 	if (rc)
 		goto free_aead;
 
diff --git a/tools/testing/selftests/net/tls.c b/tools/testing/selftests/net/tls.c
index ff68ed1..c356f48 100644
--- a/tools/testing/selftests/net/tls.c
+++ b/tools/testing/selftests/net/tls.c
@@ -763,4 +763,66 @@ TEST_F(tls, control_msg)
 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 }
 
+TEST(keysizes) {
+	struct tls12_crypto_info_aes_gcm_256 tls12;
+	struct sockaddr_in addr;
+	int sfd, ret, fd, cfd;
+	socklen_t len;
+	bool notls;
+
+	notls = false;
+	len = sizeof(addr);
+
+	memset(&tls12, 0, sizeof(tls12));
+	tls12.info.version = TLS_1_2_VERSION;
+	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
+
+	addr.sin_family = AF_INET;
+	addr.sin_addr.s_addr = htonl(INADDR_ANY);
+	addr.sin_port = 0;
+
+	fd = socket(AF_INET, SOCK_STREAM, 0);
+	sfd = socket(AF_INET, SOCK_STREAM, 0);
+
+	ret = bind(sfd, &addr, sizeof(addr));
+	ASSERT_EQ(ret, 0);
+	ret = listen(sfd, 10);
+	ASSERT_EQ(ret, 0);
+
+	ret = getsockname(sfd, &addr, &len);
+	ASSERT_EQ(ret, 0);
+
+	ret = connect(fd, &addr, sizeof(addr));
+	ASSERT_EQ(ret, 0);
+
+	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
+	if (ret != 0) {
+		notls = true;
+		printf("Failure setting TCP_ULP, testing without tls\n");
+	}
+
+	if (!notls) {
+		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
+				 sizeof(tls12));
+		EXPECT_EQ(ret, 0);
+	}
+
+	cfd = accept(sfd, &addr, &len);
+	ASSERT_GE(cfd, 0);
+
+	if (!notls) {
+		ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls",
+				 sizeof("tls"));
+		EXPECT_EQ(ret, 0);
+
+		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
+				 sizeof(tls12));
+		EXPECT_EQ(ret, 0);
+	}
+
+	close(sfd);
+	close(fd);
+	close(cfd);
+}
+
 TEST_HARNESS_MAIN