crypto: poly1305 - Pass key as first two message blocks to each desc_ctx

The Poly1305 authenticator requires a unique key for each generated tag. This
implies that we can't set the key per tfm, as multiple users of the same tfm
each need their own key. Instead we pass a desc-specific key as the first two
blocks of the message to authenticate in update().

Signed-off-by: Martin Willi <martin@strongswan.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
diff --git a/crypto/chacha20poly1305.c b/crypto/chacha20poly1305.c
index 05fbc59..7b46ed7 100644
--- a/crypto/chacha20poly1305.c
+++ b/crypto/chacha20poly1305.c
@@ -54,14 +54,14 @@
 };
 
 struct chacha_req {
-	/* the key we generate for Poly1305 using Chacha20 */
-	u8 key[POLY1305_KEY_SIZE];
 	u8 iv[CHACHA20_IV_SIZE];
 	struct scatterlist src[1];
 	struct ablkcipher_request req; /* must be last member */
 };
 
 struct chachapoly_req_ctx {
+	/* the key we generate for Poly1305 using Chacha20 */
+	u8 key[POLY1305_KEY_SIZE];
 	/* calculated Poly1305 tag */
 	u8 tag[POLY1305_DIGEST_SIZE];
 	/* length of data to en/decrypt, without ICV */
@@ -294,11 +294,38 @@
 	return poly_adpad(req);
 }
 
-static void poly_init_done(struct crypto_async_request *areq, int err)
+static void poly_setkey_done(struct crypto_async_request *areq, int err)
 {
 	async_done_continue(areq->data, err, poly_ad);
 }
 
+static int poly_setkey(struct aead_request *req)
+{
+	struct chachapoly_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
+	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
+	struct poly_req *preq = &rctx->u.poly;
+	int err;
+
+	sg_init_table(preq->src, 1); /* single entry: rctx->key */
+	sg_set_buf(preq->src, rctx->key, sizeof(rctx->key));
+
+	ahash_request_set_callback(&preq->req, aead_request_flags(req),
+				   poly_setkey_done, req);
+	ahash_request_set_tfm(&preq->req, ctx->poly);
+	ahash_request_set_crypt(&preq->req, preq->src, NULL, sizeof(rctx->key));
+
+	err = crypto_ahash_update(&preq->req); /* key as first blocks */
+	if (err)
+		return err; /* may resume via poly_setkey_done */
+
+	return poly_ad(req); /* key hashed; next step */
+}
+
+static void poly_init_done(struct crypto_async_request *areq, int err)
+{
+	async_done_continue(areq->data, err, poly_setkey); /* next: setkey */
+}
+
 static int poly_init(struct aead_request *req)
 {
 	struct chachapoly_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
@@ -314,33 +341,12 @@
 	if (err)
 		return err;
 
-	return poly_ad(req);
-}
-
-static int poly_genkey_continue(struct aead_request *req)
-{
-	struct crypto_aead *aead = crypto_aead_reqtfm(req);
-	struct chachapoly_ctx *ctx = crypto_aead_ctx(aead);
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-	struct chacha_req *creq = &rctx->u.chacha;
-	int err;
-
-	crypto_ahash_clear_flags(ctx->poly, CRYPTO_TFM_REQ_MASK);
-	crypto_ahash_set_flags(ctx->poly, crypto_aead_get_flags(aead) &
-			       CRYPTO_TFM_REQ_MASK);
-
-	err = crypto_ahash_setkey(ctx->poly, creq->key, sizeof(creq->key));
-	crypto_aead_set_flags(aead, crypto_ahash_get_flags(ctx->poly) &
-			      CRYPTO_TFM_RES_MASK);
-	if (err)
-		return err;
-
-	return poly_init(req);
+	return poly_setkey(req);
 }
 
 static void poly_genkey_done(struct crypto_async_request *areq, int err)
 {
-	async_done_continue(areq->data, err, poly_genkey_continue);
+	async_done_continue(areq->data, err, poly_init);
 }
 
 static int poly_genkey(struct aead_request *req)
@@ -351,8 +357,8 @@
 	int err;
 
 	sg_init_table(creq->src, 1);
-	memset(creq->key, 0, sizeof(creq->key));
-	sg_set_buf(creq->src, creq->key, sizeof(creq->key));
+	memset(rctx->key, 0, sizeof(rctx->key));
+	sg_set_buf(creq->src, rctx->key, sizeof(rctx->key));
 
 	chacha_iv(creq->iv, req, 0);
 
@@ -366,7 +372,7 @@
 	if (err)
 		return err;
 
-	return poly_genkey_continue(req);
+	return poly_init(req);
 }
 
 static void chacha_encrypt_done(struct crypto_async_request *areq, int err)
@@ -403,8 +409,9 @@
 
 	/* encrypt call chain:
 	 * - chacha_encrypt/done()
-	 * - poly_genkey/done/continue()
+	 * - poly_genkey/done()
 	 * - poly_init/done()
+	 * - poly_setkey/done()
 	 * - poly_ad/done()
 	 * - poly_adpad/done()
 	 * - poly_cipher/done()
@@ -424,8 +431,9 @@
 	rctx->cryptlen = req->cryptlen - POLY1305_DIGEST_SIZE;
 
 	/* decrypt call chain:
-	 * - poly_genkey/done/continue()
+	 * - poly_genkey/done()
 	 * - poly_init/done()
+	 * - poly_setkey/done()
 	 * - poly_ad/done()
 	 * - poly_adpad/done()
 	 * - poly_cipher/done()