usb: xhci: merge xhci_queue_bulk_tx and queue_bulk_sg_tx functions

In drivers/usb/host/xhci-ring.c there are two functions
(xhci_queue_bulk_tx and queue_bulk_sg_tx) that are very similar,
so a lot of code duplication.
This patch merges these functions into to one xhci_queue_bulk_tx.
Also counting the needed TRBs is merged and refactored.

Signed-off-by: Alexandr Ivanov <alexandr.sky@gmail.com>
Signed-off-by: Mathias Nyman <mathias.nyman@linux.intel.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
diff --git a/drivers/usb/host/xhci-ring.c b/drivers/usb/host/xhci-ring.c
index 251e299..d7dd32e 100644
--- a/drivers/usb/host/xhci-ring.c
+++ b/drivers/usb/host/xhci-ring.c
@@ -2938,46 +2938,55 @@
 	return 0;
 }
 
-static unsigned int count_sg_trbs_needed(struct xhci_hcd *xhci, struct urb *urb)
+static unsigned int count_trbs(u64 addr, u64 len)
 {
-	int num_sgs, num_trbs, running_total, temp, i;
-	struct scatterlist *sg;
+	unsigned int num_trbs;
 
-	sg = NULL;
-	num_sgs = urb->num_mapped_sgs;
-	temp = urb->transfer_buffer_length;
+	num_trbs = DIV_ROUND_UP(len + (addr & (TRB_MAX_BUFF_SIZE - 1)),
+			TRB_MAX_BUFF_SIZE);
+	if (num_trbs == 0)
+		num_trbs++;
 
-	num_trbs = 0;
-	for_each_sg(urb->sg, sg, num_sgs, i) {
-		unsigned int len = sg_dma_len(sg);
-
-		/* Scatter gather list entries may cross 64KB boundaries */
-		running_total = TRB_MAX_BUFF_SIZE -
-			(sg_dma_address(sg) & (TRB_MAX_BUFF_SIZE - 1));
-		running_total &= TRB_MAX_BUFF_SIZE - 1;
-		if (running_total != 0)
-			num_trbs++;
-
-		/* How many more 64KB chunks to transfer, how many more TRBs? */
-		while (running_total < sg_dma_len(sg) && running_total < temp) {
-			num_trbs++;
-			running_total += TRB_MAX_BUFF_SIZE;
-		}
-		len = min_t(int, len, temp);
-		temp -= len;
-		if (temp == 0)
-			break;
-	}
 	return num_trbs;
 }
 
-static void check_trb_math(struct urb *urb, int num_trbs, int running_total)
+static inline unsigned int count_trbs_needed(struct urb *urb)
 {
-	if (num_trbs != 0)
-		dev_err(&urb->dev->dev, "%s - ep %#x - Miscalculated number of "
-				"TRBs, %d left\n", __func__,
-				urb->ep->desc.bEndpointAddress, num_trbs);
-	if (running_total != urb->transfer_buffer_length)
+	return count_trbs(urb->transfer_dma, urb->transfer_buffer_length);
+}
+
+static unsigned int count_sg_trbs_needed(struct urb *urb)
+{
+	struct scatterlist *sg;
+	unsigned int i, len, full_len, num_trbs = 0;
+
+	full_len = urb->transfer_buffer_length;
+
+	for_each_sg(urb->sg, sg, urb->num_mapped_sgs, i) {
+		len = sg_dma_len(sg);
+		num_trbs += count_trbs(sg_dma_address(sg), len);
+		len = min_t(unsigned int, len, full_len);
+		full_len -= len;
+		if (full_len == 0)
+			break;
+	}
+
+	return num_trbs;
+}
+
+static unsigned int count_isoc_trbs_needed(struct urb *urb, int i)
+{
+	u64 addr, len;
+
+	addr = (u64) (urb->transfer_dma + urb->iso_frame_desc[i].offset);
+	len = urb->iso_frame_desc[i].length;
+
+	return count_trbs(addr, len);
+}
+
+static void check_trb_math(struct urb *urb, int running_total)
+{
+	if (unlikely(running_total != urb->transfer_buffer_length))
 		dev_err(&urb->dev->dev, "%s - ep %#x - Miscalculated tx length, "
 				"queued %#x (%d), asked for %#x (%d)\n",
 				__func__,
@@ -3086,177 +3095,6 @@
 	return (total_packet_count - ((transferred + trb_buff_len) / maxp));
 }
 
-
-static int queue_bulk_sg_tx(struct xhci_hcd *xhci, gfp_t mem_flags,
-		struct urb *urb, int slot_id, unsigned int ep_index)
-{
-	struct xhci_ring *ep_ring;
-	unsigned int num_trbs;
-	struct urb_priv *urb_priv;
-	struct xhci_td *td;
-	struct scatterlist *sg;
-	int num_sgs;
-	int trb_buff_len, this_sg_len, running_total, ret;
-	unsigned int total_packet_count;
-	bool zero_length_needed;
-	bool first_trb;
-	int last_trb_num;
-	u64 addr;
-	bool more_trbs_coming;
-
-	struct xhci_generic_trb *start_trb;
-	int start_cycle;
-
-	ep_ring = xhci_urb_to_transfer_ring(xhci, urb);
-	if (!ep_ring)
-		return -EINVAL;
-
-	num_trbs = count_sg_trbs_needed(xhci, urb);
-	num_sgs = urb->num_mapped_sgs;
-	total_packet_count = DIV_ROUND_UP(urb->transfer_buffer_length,
-			usb_endpoint_maxp(&urb->ep->desc));
-
-	ret = prepare_transfer(xhci, xhci->devs[slot_id],
-			ep_index, urb->stream_id,
-			num_trbs, urb, 0, mem_flags);
-	if (ret < 0)
-		return ret;
-
-	urb_priv = urb->hcpriv;
-
-	/* Deal with URB_ZERO_PACKET - need one more td/trb */
-	zero_length_needed = urb->transfer_flags & URB_ZERO_PACKET &&
-		urb_priv->length == 2;
-	if (zero_length_needed) {
-		num_trbs++;
-		xhci_dbg(xhci, "Creating zero length td.\n");
-		ret = prepare_transfer(xhci, xhci->devs[slot_id],
-				ep_index, urb->stream_id,
-				1, urb, 1, mem_flags);
-		if (ret < 0)
-			return ret;
-	}
-
-	td = urb_priv->td[0];
-
-	/*
-	 * Don't give the first TRB to the hardware (by toggling the cycle bit)
-	 * until we've finished creating all the other TRBs.  The ring's cycle
-	 * state may change as we enqueue the other TRBs, so save it too.
-	 */
-	start_trb = &ep_ring->enqueue->generic;
-	start_cycle = ep_ring->cycle_state;
-
-	running_total = 0;
-	/*
-	 * How much data is in the first TRB?
-	 *
-	 * There are three forces at work for TRB buffer pointers and lengths:
-	 * 1. We don't want to walk off the end of this sg-list entry buffer.
-	 * 2. The transfer length that the driver requested may be smaller than
-	 *    the amount of memory allocated for this scatter-gather list.
-	 * 3. TRBs buffers can't cross 64KB boundaries.
-	 */
-	sg = urb->sg;
-	addr = (u64) sg_dma_address(sg);
-	this_sg_len = sg_dma_len(sg);
-	trb_buff_len = TRB_MAX_BUFF_SIZE - (addr & (TRB_MAX_BUFF_SIZE - 1));
-	trb_buff_len = min_t(int, trb_buff_len, this_sg_len);
-	if (trb_buff_len > urb->transfer_buffer_length)
-		trb_buff_len = urb->transfer_buffer_length;
-
-	first_trb = true;
-	last_trb_num = zero_length_needed ? 2 : 1;
-	/* Queue the first TRB, even if it's zero-length */
-	do {
-		u32 field = 0;
-		u32 length_field = 0;
-		u32 remainder = 0;
-
-		/* Don't change the cycle bit of the first TRB until later */
-		if (first_trb) {
-			first_trb = false;
-			if (start_cycle == 0)
-				field |= 0x1;
-		} else
-			field |= ep_ring->cycle_state;
-
-		/* Chain all the TRBs together; clear the chain bit in the last
-		 * TRB to indicate it's the last TRB in the chain.
-		 */
-		if (num_trbs > last_trb_num) {
-			field |= TRB_CHAIN;
-		} else if (num_trbs == last_trb_num) {
-			td->last_trb = ep_ring->enqueue;
-			field |= TRB_IOC;
-		} else if (zero_length_needed && num_trbs == 1) {
-			trb_buff_len = 0;
-			urb_priv->td[1]->last_trb = ep_ring->enqueue;
-			field |= TRB_IOC;
-		}
-
-		/* Only set interrupt on short packet for IN endpoints */
-		if (usb_urb_dir_in(urb))
-			field |= TRB_ISP;
-
-		if (TRB_MAX_BUFF_SIZE -
-				(addr & (TRB_MAX_BUFF_SIZE - 1)) < trb_buff_len) {
-			xhci_warn(xhci, "WARN: sg dma xfer crosses 64KB boundaries!\n");
-			xhci_dbg(xhci, "Next boundary at %#x, end dma = %#x\n",
-					(unsigned int) (addr + TRB_MAX_BUFF_SIZE) & ~(TRB_MAX_BUFF_SIZE - 1),
-					(unsigned int) addr + trb_buff_len);
-		}
-
-		/* Set the TRB length, TD size, and interrupter fields. */
-		remainder = xhci_td_remainder(xhci, running_total, trb_buff_len,
-					   urb->transfer_buffer_length,
-					   urb, num_trbs - 1);
-
-		length_field = TRB_LEN(trb_buff_len) |
-			TRB_TD_SIZE(remainder) |
-			TRB_INTR_TARGET(0);
-
-		if (num_trbs > 1)
-			more_trbs_coming = true;
-		else
-			more_trbs_coming = false;
-		queue_trb(xhci, ep_ring, more_trbs_coming,
-				lower_32_bits(addr),
-				upper_32_bits(addr),
-				length_field,
-				field | TRB_TYPE(TRB_NORMAL));
-		--num_trbs;
-		running_total += trb_buff_len;
-
-		/* Calculate length for next transfer --
-		 * Are we done queueing all the TRBs for this sg entry?
-		 */
-		this_sg_len -= trb_buff_len;
-		if (this_sg_len == 0) {
-			--num_sgs;
-			if (num_sgs == 0)
-				break;
-			sg = sg_next(sg);
-			addr = (u64) sg_dma_address(sg);
-			this_sg_len = sg_dma_len(sg);
-		} else {
-			addr += trb_buff_len;
-		}
-
-		trb_buff_len = TRB_MAX_BUFF_SIZE -
-			(addr & (TRB_MAX_BUFF_SIZE - 1));
-		trb_buff_len = min_t(int, trb_buff_len, this_sg_len);
-		if (running_total + trb_buff_len > urb->transfer_buffer_length)
-			trb_buff_len =
-				urb->transfer_buffer_length - running_total;
-	} while (num_trbs > 0);
-
-	check_trb_math(urb, num_trbs, running_total);
-	giveback_first_trb(xhci, slot_id, ep_index, urb->stream_id,
-			start_cycle, start_trb);
-	return 0;
-}
-
 /* This is very similar to what ehci-q.c qtd_fill() does */
 int xhci_queue_bulk_tx(struct xhci_hcd *xhci, gfp_t mem_flags,
 		struct urb *urb, int slot_id, unsigned int ep_index)
@@ -3264,51 +3102,40 @@
 	struct xhci_ring *ep_ring;
 	struct urb_priv *urb_priv;
 	struct xhci_td *td;
-	int num_trbs;
 	struct xhci_generic_trb *start_trb;
-	bool first_trb;
-	int last_trb_num;
+	struct scatterlist *sg = NULL;
 	bool more_trbs_coming;
 	bool zero_length_needed;
-	int start_cycle;
-	u32 field, length_field;
-
-	int running_total, trb_buff_len, ret;
-	unsigned int total_packet_count;
+	unsigned int num_trbs, last_trb_num, i;
+	unsigned int start_cycle, num_sgs = 0;
+	unsigned int running_total, block_len, trb_buff_len;
+	unsigned int full_len;
+	int ret;
+	u32 field, length_field, remainder;
 	u64 addr;
 
-	if (urb->num_sgs)
-		return queue_bulk_sg_tx(xhci, mem_flags, urb, slot_id, ep_index);
-
 	ep_ring = xhci_urb_to_transfer_ring(xhci, urb);
 	if (!ep_ring)
 		return -EINVAL;
 
-	num_trbs = 0;
-	/* How much data is (potentially) left before the 64KB boundary? */
-	running_total = TRB_MAX_BUFF_SIZE -
-		(urb->transfer_dma & (TRB_MAX_BUFF_SIZE - 1));
-	running_total &= TRB_MAX_BUFF_SIZE - 1;
-
-	/* If there's some data on this 64KB chunk, or we have to send a
-	 * zero-length transfer, we need at least one TRB
-	 */
-	if (running_total != 0 || urb->transfer_buffer_length == 0)
-		num_trbs++;
-	/* How many more 64KB chunks to transfer, how many more TRBs? */
-	while (running_total < urb->transfer_buffer_length) {
-		num_trbs++;
-		running_total += TRB_MAX_BUFF_SIZE;
-	}
+	/* If we have scatter/gather list, we use it. */
+	if (urb->num_sgs) {
+		num_sgs = urb->num_mapped_sgs;
+		sg = urb->sg;
+		num_trbs = count_sg_trbs_needed(urb);
+	} else
+		num_trbs = count_trbs_needed(urb);
 
 	ret = prepare_transfer(xhci, xhci->devs[slot_id],
 			ep_index, urb->stream_id,
 			num_trbs, urb, 0, mem_flags);
-	if (ret < 0)
+	if (unlikely(ret < 0))
 		return ret;
 
 	urb_priv = urb->hcpriv;
 
+	last_trb_num = num_trbs - 1;
+
 	/* Deal with URB_ZERO_PACKET - need one more td/trb */
 	zero_length_needed = urb->transfer_flags & URB_ZERO_PACKET &&
 		urb_priv->length == 2;
@@ -3318,7 +3145,7 @@
 		ret = prepare_transfer(xhci, xhci->devs[slot_id],
 				ep_index, urb->stream_id,
 				1, urb, 1, mem_flags);
-		if (ret < 0)
+		if (unlikely(ret < 0))
 			return ret;
 	}
 
@@ -3332,43 +3159,58 @@
 	start_trb = &ep_ring->enqueue->generic;
 	start_cycle = ep_ring->cycle_state;
 
+	full_len = urb->transfer_buffer_length;
 	running_total = 0;
-	total_packet_count = DIV_ROUND_UP(urb->transfer_buffer_length,
-			usb_endpoint_maxp(&urb->ep->desc));
-	/* How much data is in the first TRB? */
-	addr = (u64) urb->transfer_dma;
-	trb_buff_len = TRB_MAX_BUFF_SIZE -
-		(urb->transfer_dma & (TRB_MAX_BUFF_SIZE - 1));
-	if (trb_buff_len > urb->transfer_buffer_length)
-		trb_buff_len = urb->transfer_buffer_length;
+	block_len = 0;
 
-	first_trb = true;
-	last_trb_num = zero_length_needed ? 2 : 1;
-	/* Queue the first TRB, even if it's zero-length */
-	do {
-		u32 remainder = 0;
-		field = 0;
+	/* Queue the TRBs, even if they are zero-length */
+	for (i = 0; i < num_trbs; i++) {
+		field = TRB_TYPE(TRB_NORMAL);
+
+		if (block_len == 0) {
+			/* A new contiguous block. */
+			if (sg) {
+				addr = (u64) sg_dma_address(sg);
+				block_len = sg_dma_len(sg);
+			} else {
+				addr = (u64) urb->transfer_dma;
+				block_len = full_len;
+			}
+			/* TRB buffer should not cross 64KB boundaries */
+			trb_buff_len = TRB_BUFF_LEN_UP_TO_BOUNDARY(addr);
+			trb_buff_len = min_t(unsigned int,
+								trb_buff_len,
+								block_len);
+		} else {
+			/* Further through the contiguous block. */
+			trb_buff_len = block_len;
+			if (trb_buff_len > TRB_MAX_BUFF_SIZE)
+				trb_buff_len = TRB_MAX_BUFF_SIZE;
+		}
+
+		if (running_total + trb_buff_len > full_len)
+			trb_buff_len = full_len - running_total;
 
 		/* Don't change the cycle bit of the first TRB until later */
-		if (first_trb) {
-			first_trb = false;
+		if (i == 0) {
 			if (start_cycle == 0)
-				field |= 0x1;
+				field |= TRB_CYCLE;
 		} else
 			field |= ep_ring->cycle_state;
 
 		/* Chain all the TRBs together; clear the chain bit in the last
 		 * TRB to indicate it's the last TRB in the chain.
 		 */
-		if (num_trbs > last_trb_num) {
+		if (i < last_trb_num) {
 			field |= TRB_CHAIN;
-		} else if (num_trbs == last_trb_num) {
-			td->last_trb = ep_ring->enqueue;
+		} else {
 			field |= TRB_IOC;
-		} else if (zero_length_needed && num_trbs == 1) {
-			trb_buff_len = 0;
-			urb_priv->td[1]->last_trb = ep_ring->enqueue;
-			field |= TRB_IOC;
+			if (i == last_trb_num)
+				td->last_trb = ep_ring->enqueue;
+			else if (zero_length_needed) {
+				trb_buff_len = 0;
+				urb_priv->td[1]->last_trb = ep_ring->enqueue;
+			}
 		}
 
 		/* Only set interrupt on short packet for IN endpoints */
@@ -3376,15 +3218,15 @@
 			field |= TRB_ISP;
 
 		/* Set the TRB length, TD size, and interrupter fields. */
-		remainder = xhci_td_remainder(xhci, running_total, trb_buff_len,
-					   urb->transfer_buffer_length,
-					   urb, num_trbs - 1);
+		remainder = xhci_td_remainder(xhci, running_total,
+							trb_buff_len, full_len,
+							urb, num_trbs - i - 1);
 
 		length_field = TRB_LEN(trb_buff_len) |
 			TRB_TD_SIZE(remainder) |
 			TRB_INTR_TARGET(0);
 
-		if (num_trbs > 1)
+		if (i < num_trbs - 1)
 			more_trbs_coming = true;
 		else
 			more_trbs_coming = false;
@@ -3392,18 +3234,24 @@
 				lower_32_bits(addr),
 				upper_32_bits(addr),
 				length_field,
-				field | TRB_TYPE(TRB_NORMAL));
-		--num_trbs;
+				field);
+
 		running_total += trb_buff_len;
-
-		/* Calculate length for next transfer */
 		addr += trb_buff_len;
-		trb_buff_len = urb->transfer_buffer_length - running_total;
-		if (trb_buff_len > TRB_MAX_BUFF_SIZE)
-			trb_buff_len = TRB_MAX_BUFF_SIZE;
-	} while (num_trbs > 0);
+		block_len -= trb_buff_len;
 
-	check_trb_math(urb, num_trbs, running_total);
+		if (sg) {
+			if (block_len == 0) {
+				/* New sg entry */
+				--num_sgs;
+				if (num_sgs == 0)
+					break;
+				sg = sg_next(sg);
+			}
+		}
+	}
+
+	check_trb_math(urb, running_total);
 	giveback_first_trb(xhci, slot_id, ep_index, urb->stream_id,
 			start_cycle, start_trb);
 	return 0;
@@ -3532,23 +3380,6 @@
 	return 0;
 }
 
-static int count_isoc_trbs_needed(struct xhci_hcd *xhci,
-		struct urb *urb, int i)
-{
-	int num_trbs = 0;
-	u64 addr, td_len;
-
-	addr = (u64) (urb->transfer_dma + urb->iso_frame_desc[i].offset);
-	td_len = urb->iso_frame_desc[i].length;
-
-	num_trbs = DIV_ROUND_UP(td_len + (addr & (TRB_MAX_BUFF_SIZE - 1)),
-			TRB_MAX_BUFF_SIZE);
-	if (num_trbs == 0)
-		num_trbs++;
-
-	return num_trbs;
-}
-
 /*
  * The transfer burst count field of the isochronous TRB defines the number of
  * bursts that are required to move all packets in this TD.  Only SuperSpeed
@@ -3746,7 +3577,7 @@
 		last_burst_pkt_count = xhci_get_last_burst_packet_count(xhci,
 							urb, total_pkt_count);
 
-		trbs_per_td = count_isoc_trbs_needed(xhci, urb, i);
+		trbs_per_td = count_isoc_trbs_needed(urb, i);
 
 		ret = prepare_transfer(xhci, xhci->devs[slot_id], ep_index,
 				urb->stream_id, trbs_per_td, urb, i, mem_flags);
@@ -3807,8 +3638,7 @@
 					field |= TRB_BEI;
 			}
 			/* Calculate TRB length */
-			trb_buff_len = TRB_MAX_BUFF_SIZE -
-				(addr & ((1 << TRB_MAX_BUFF_SHIFT) - 1));
+			trb_buff_len = TRB_BUFF_LEN_UP_TO_BOUNDARY(addr);
 			if (trb_buff_len > td_remain_len)
 				trb_buff_len = td_remain_len;
 
@@ -3912,7 +3742,7 @@
 	num_trbs = 0;
 	num_tds = urb->number_of_packets;
 	for (i = 0; i < num_tds; i++)
-		num_trbs += count_isoc_trbs_needed(xhci, urb, i);
+		num_trbs += count_isoc_trbs_needed(urb, i);
 
 	/* Check the ring to guarantee there is enough room for the whole urb.
 	 * Do not insert any td of the urb to the ring if the check failed.
diff --git a/drivers/usb/host/xhci.h b/drivers/usb/host/xhci.h
index 6c629c9..8fd35a65 100644
--- a/drivers/usb/host/xhci.h
+++ b/drivers/usb/host/xhci.h
@@ -1338,6 +1338,9 @@
 /* TRB buffer pointers can't cross 64KB boundaries */
 #define TRB_MAX_BUFF_SHIFT		16
 #define TRB_MAX_BUFF_SIZE	(1 << TRB_MAX_BUFF_SHIFT)
+/* How much data is left before the 64KB boundary? */
+#define TRB_BUFF_LEN_UP_TO_BOUNDARY(addr)	(TRB_MAX_BUFF_SIZE - \
+					(addr & (TRB_MAX_BUFF_SIZE - 1)))
 
 struct xhci_segment {
 	union xhci_trb		*trbs;