cifs: push rfc1002 generation down the stack

Move the generation of the 4 byte length field down the stack and
generate it immediately before we start writing the data to the socket.

Signed-off-by: Ronnie Sahlberg <lsahlber@redhat.com>
Signed-off-by: Aurelien Aptel <aaptel@suse.com>
Signed-off-by: Steve French <smfrench@gmail.com>
diff --git a/fs/cifs/smb2ops.c b/fs/cifs/smb2ops.c
index 682bcfa..9153407 100644
--- a/fs/cifs/smb2ops.c
+++ b/fs/cifs/smb2ops.c
@@ -2167,7 +2167,7 @@ fill_transform_hdr(struct smb2_transform_hdr *tr_hdr, unsigned int orig_len,
 		   struct smb_rqst *old_rq)
 {
 	struct smb2_sync_hdr *shdr =
-			(struct smb2_sync_hdr *)old_rq->rq_iov[1].iov_base;
+			(struct smb2_sync_hdr *)old_rq->rq_iov[0].iov_base;
 
 	memset(tr_hdr, 0, sizeof(struct smb2_transform_hdr));
 	tr_hdr->ProtocolId = SMB2_TRANSFORM_PROTO_NUM;
@@ -2187,14 +2187,13 @@ static inline void smb2_sg_set_buf(struct scatterlist *sg, const void *buf,
 }
 
 /* Assumes:
- * rqst->rq_iov[0]  is rfc1002 length
- * rqst->rq_iov[1]  is tranform header
- * rqst->rq_iov[2+] data to be encrypted/decrypted
+ * rqst->rq_iov[0]  is transform header
+ * rqst->rq_iov[1+] data to be encrypted/decrypted
  */
 static struct scatterlist *
 init_sg(struct smb_rqst *rqst, u8 *sign)
 {
-	unsigned int sg_len = rqst->rq_nvec + rqst->rq_npages;
+	unsigned int sg_len = rqst->rq_nvec + rqst->rq_npages + 1;
 	unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 20;
 	struct scatterlist *sg;
 	unsigned int i;
@@ -2205,10 +2204,10 @@ init_sg(struct smb_rqst *rqst, u8 *sign)
 		return NULL;
 
 	sg_init_table(sg, sg_len);
-	smb2_sg_set_buf(&sg[0], rqst->rq_iov[1].iov_base + 20, assoc_data_len);
-	for (i = 1; i < rqst->rq_nvec - 1; i++)
-		smb2_sg_set_buf(&sg[i], rqst->rq_iov[i+1].iov_base,
-						rqst->rq_iov[i+1].iov_len);
+	smb2_sg_set_buf(&sg[0], rqst->rq_iov[0].iov_base + 20, assoc_data_len);
+	for (i = 1; i < rqst->rq_nvec; i++)
+		smb2_sg_set_buf(&sg[i], rqst->rq_iov[i].iov_base,
+						rqst->rq_iov[i].iov_len);
 	for (j = 0; i < sg_len - 1; i++, j++) {
 		unsigned int len, offset;
 
@@ -2240,11 +2239,10 @@ smb2_get_enc_key(struct TCP_Server_Info *server, __u64 ses_id, int enc, u8 *key)
 	return 1;
 }
 /*
- * Encrypt or decrypt @rqst message. @rqst has the following format:
- * iov[0] - rfc1002 length
- * iov[1] - transform header (associate data),
- * iov[2-N] and pages - data to encrypt.
- * On success return encrypted data in iov[2-N] and pages, leave iov[0-1]
+ * Encrypt or decrypt @rqst message. @rqst[0] has the following format:
+ * iov[0]   - transform header (associate data),
+ * iov[1-N] - SMB2 header and pages - data to encrypt.
+ * On success return encrypted data in iov[1-N] and pages, leave iov[0]
  * untouched.
  */
 static int
@@ -2339,10 +2337,6 @@ crypt_message(struct TCP_Server_Info *server, struct smb_rqst *rqst, int enc)
 	return rc;
 }
 
-/*
- * This is called from smb_send_rqst. At this point we have the rfc1002
- * header as the first element in the vector.
- */
 static int
 smb3_init_transform_rq(struct TCP_Server_Info *server, struct smb_rqst *new_rq,
 		       struct smb_rqst *old_rq)
@@ -2351,7 +2345,7 @@ smb3_init_transform_rq(struct TCP_Server_Info *server, struct smb_rqst *new_rq,
 	struct page **pages;
 	struct smb2_transform_hdr *tr_hdr;
 	unsigned int npages = old_rq->rq_npages;
-	unsigned int orig_len = get_rfc1002_length(old_rq->rq_iov[0].iov_base);
+	unsigned int orig_len = 0;
 	int i;
 	int rc = -ENOMEM;
 
@@ -2365,24 +2359,23 @@ smb3_init_transform_rq(struct TCP_Server_Info *server, struct smb_rqst *new_rq,
 	new_rq->rq_pagesz = old_rq->rq_pagesz;
 	new_rq->rq_tailsz = old_rq->rq_tailsz;
 
+	for (i = 0; i < old_rq->rq_nvec; i++)
+		orig_len += old_rq->rq_iov[i].iov_len;
+
 	for (i = 0; i < npages; i++) {
 		pages[i] = alloc_page(GFP_KERNEL|__GFP_HIGHMEM);
 		if (!pages[i])
 			goto err_free_pages;
 	}
 
-	/* Make space for one extra iov to hold the transform header */
 	iov = kmalloc_array(old_rq->rq_nvec + 1, sizeof(struct kvec),
 			    GFP_KERNEL);
 	if (!iov)
 		goto err_free_pages;
 
-	/* copy all iovs from the old except the 1st one (rfc1002 length) */
-	memcpy(&iov[2], &old_rq->rq_iov[1],
-				sizeof(struct kvec) * (old_rq->rq_nvec - 1));
-	/* copy the rfc1002 iov */
-	iov[0].iov_base = old_rq->rq_iov[0].iov_base;
-	iov[0].iov_len  = old_rq->rq_iov[0].iov_len;
+	/* copy all iovs from the old */
+	memcpy(&iov[1], &old_rq->rq_iov[0],
+				sizeof(struct kvec) * old_rq->rq_nvec);
 
 	new_rq->rq_iov = iov;
 	new_rq->rq_nvec = old_rq->rq_nvec + 1;
@@ -2393,12 +2386,8 @@ smb3_init_transform_rq(struct TCP_Server_Info *server, struct smb_rqst *new_rq,
 
-	/* fill the 2nd iov with a transform header */
+	/* fill the 1st iov with a transform header */
 	fill_transform_hdr(tr_hdr, orig_len, old_rq);
-	new_rq->rq_iov[1].iov_base = tr_hdr;
-	new_rq->rq_iov[1].iov_len = sizeof(struct smb2_transform_hdr);
-
-	/* Update rfc1002 header */
-	inc_rfc1001_len(new_rq->rq_iov[0].iov_base,
-			sizeof(struct smb2_transform_hdr));
+	new_rq->rq_iov[0].iov_base = tr_hdr;
+	new_rq->rq_iov[0].iov_len = sizeof(struct smb2_transform_hdr);
 
 	/* copy pages form the old */
 	for (i = 0; i < npages; i++) {
@@ -2442,7 +2431,7 @@ smb3_free_transform_rq(struct smb_rqst *rqst)
 		put_page(rqst->rq_pages[i]);
 	kfree(rqst->rq_pages);
 	/* free transform header */
-	kfree(rqst->rq_iov[1].iov_base);
+	kfree(rqst->rq_iov[0].iov_base);
 	kfree(rqst->rq_iov);
 }
 
@@ -2459,19 +2448,17 @@ decrypt_raw_data(struct TCP_Server_Info *server, char *buf,
 		 unsigned int buf_data_size, struct page **pages,
 		 unsigned int npages, unsigned int page_data_size)
 {
-	struct kvec iov[3];
+	struct kvec iov[2];
 	struct smb_rqst rqst = {NULL};
 	int rc;
 
-	iov[0].iov_base = NULL;
-	iov[0].iov_len = 0;
-	iov[1].iov_base = buf;
-	iov[1].iov_len = sizeof(struct smb2_transform_hdr);
-	iov[2].iov_base = buf + sizeof(struct smb2_transform_hdr);
-	iov[2].iov_len = buf_data_size;
+	iov[0].iov_base = buf;
+	iov[0].iov_len = sizeof(struct smb2_transform_hdr);
+	iov[1].iov_base = buf + sizeof(struct smb2_transform_hdr);
+	iov[1].iov_len = buf_data_size;
 
 	rqst.rq_iov = iov;
-	rqst.rq_nvec = 3;
+	rqst.rq_nvec = 2;
 	rqst.rq_pages = pages;
 	rqst.rq_npages = npages;
 	rqst.rq_pagesz = PAGE_SIZE;
@@ -2483,7 +2470,7 @@ decrypt_raw_data(struct TCP_Server_Info *server, char *buf,
 	if (rc)
 		return rc;
 
-	memmove(buf, iov[2].iov_base, buf_data_size);
+	memmove(buf, iov[1].iov_base, buf_data_size);
 
 	server->total_read = buf_data_size + page_data_size;