virtio_net: Use temporary storage for accounting rx stats

The purpose is to keep receive_buf arguments simple when more per-queue
counter items are added later.
Also, XDP_TX-related sq counters will be updated in the following changes,
so create a container struct virtnet_rx_stats which will include both
rq and sq statistics. For now it only covers rq stats.

Signed-off-by: Toshiaki Makita <makita.toshiaki@lab.ntt.co.jp>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/drivers/net/virtio_net.c b/drivers/net/virtio_net.c
index abbd3bc..d03bfc4 100644
--- a/drivers/net/virtio_net.c
+++ b/drivers/net/virtio_net.c
@@ -84,14 +84,22 @@ struct virtnet_sq_stats {
 	u64 bytes;
 };
 
-struct virtnet_rq_stats {
-	struct u64_stats_sync syncp;
+struct virtnet_rq_stat_items {
 	u64 packets;
 	u64 bytes;
 };
 
+struct virtnet_rq_stats {
+	struct u64_stats_sync syncp;
+	struct virtnet_rq_stat_items items;
+};
+
+struct virtnet_rx_stats {
+	struct virtnet_rq_stat_items rx;
+};
+
 #define VIRTNET_SQ_STAT(m)	offsetof(struct virtnet_sq_stats, m)
-#define VIRTNET_RQ_STAT(m)	offsetof(struct virtnet_rq_stats, m)
+#define VIRTNET_RQ_STAT(m)	offsetof(struct virtnet_rq_stat_items, m)
 
 static const struct virtnet_stat_desc virtnet_sq_stats_desc[] = {
 	{ "packets",	VIRTNET_SQ_STAT(packets) },
@@ -587,7 +595,7 @@ static struct sk_buff *receive_small(struct net_device *dev,
 				     void *buf, void *ctx,
 				     unsigned int len,
 				     unsigned int *xdp_xmit,
-				     unsigned int *rbytes)
+				     struct virtnet_rx_stats *stats)
 {
 	struct sk_buff *skb;
 	struct bpf_prog *xdp_prog;
@@ -602,7 +610,7 @@ static struct sk_buff *receive_small(struct net_device *dev,
 	int err;
 
 	len -= vi->hdr_len;
-	*rbytes += len;
+	stats->rx.bytes += len;
 
 	rcu_read_lock();
 	xdp_prog = rcu_dereference(rq->xdp_prog);
@@ -708,12 +716,12 @@ static struct sk_buff *receive_big(struct net_device *dev,
 				   struct receive_queue *rq,
 				   void *buf,
 				   unsigned int len,
-				   unsigned int *rbytes)
+				   struct virtnet_rx_stats *stats)
 {
 	struct page *page = buf;
 	struct sk_buff *skb = page_to_skb(vi, rq, page, 0, len, PAGE_SIZE);
 
-	*rbytes += len - vi->hdr_len;
+	stats->rx.bytes += len - vi->hdr_len;
 	if (unlikely(!skb))
 		goto err;
 
@@ -732,7 +740,7 @@ static struct sk_buff *receive_mergeable(struct net_device *dev,
 					 void *ctx,
 					 unsigned int len,
 					 unsigned int *xdp_xmit,
-					 unsigned int *rbytes)
+					 struct virtnet_rx_stats *stats)
 {
 	struct virtio_net_hdr_mrg_rxbuf *hdr = buf;
 	u16 num_buf = virtio16_to_cpu(vi->vdev, hdr->num_buffers);
@@ -745,7 +753,7 @@ static struct sk_buff *receive_mergeable(struct net_device *dev,
 	int err;
 
 	head_skb = NULL;
-	*rbytes += len - vi->hdr_len;
+	stats->rx.bytes += len - vi->hdr_len;
 
 	rcu_read_lock();
 	xdp_prog = rcu_dereference(rq->xdp_prog);
@@ -883,7 +891,7 @@ static struct sk_buff *receive_mergeable(struct net_device *dev,
 			goto err_buf;
 		}
 
-		*rbytes += len;
+		stats->rx.bytes += len;
 		page = virt_to_head_page(buf);
 
 		truesize = mergeable_ctx_to_truesize(ctx);
@@ -939,7 +947,7 @@ static struct sk_buff *receive_mergeable(struct net_device *dev,
 			dev->stats.rx_length_errors++;
 			break;
 		}
-		*rbytes += len;
+		stats->rx.bytes += len;
 		page = virt_to_head_page(buf);
 		put_page(page);
 	}
@@ -952,7 +960,8 @@ static struct sk_buff *receive_mergeable(struct net_device *dev,
 
 static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq,
 			void *buf, unsigned int len, void **ctx,
-			unsigned int *xdp_xmit, unsigned int *rbytes)
+			unsigned int *xdp_xmit,
+			struct virtnet_rx_stats *stats)
 {
 	struct net_device *dev = vi->dev;
 	struct sk_buff *skb;
@@ -973,11 +982,11 @@ static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq,
 
 	if (vi->mergeable_rx_bufs)
 		skb = receive_mergeable(dev, vi, rq, buf, ctx, len, xdp_xmit,
-					rbytes);
+					stats);
 	else if (vi->big_packets)
-		skb = receive_big(dev, vi, rq, buf, len, rbytes);
+		skb = receive_big(dev, vi, rq, buf, len, stats);
 	else
-		skb = receive_small(dev, vi, rq, buf, ctx, len, xdp_xmit, rbytes);
+		skb = receive_small(dev, vi, rq, buf, ctx, len, xdp_xmit, stats);
 
 	if (unlikely(!skb))
 		return;
@@ -1246,22 +1255,24 @@ static int virtnet_receive(struct receive_queue *rq, int budget,
 			   unsigned int *xdp_xmit)
 {
 	struct virtnet_info *vi = rq->vq->vdev->priv;
-	unsigned int len, received = 0, bytes = 0;
+	struct virtnet_rx_stats stats = {};
+	unsigned int len;
 	void *buf;
+	int i;
 
 	if (!vi->big_packets || vi->mergeable_rx_bufs) {
 		void *ctx;
 
-		while (received < budget &&
+		while (stats.rx.packets < budget &&
 		       (buf = virtqueue_get_buf_ctx(rq->vq, &len, &ctx))) {
-			receive_buf(vi, rq, buf, len, ctx, xdp_xmit, &bytes);
-			received++;
+			receive_buf(vi, rq, buf, len, ctx, xdp_xmit, &stats);
+			stats.rx.packets++;
 		}
 	} else {
-		while (received < budget &&
+		while (stats.rx.packets < budget &&
 		       (buf = virtqueue_get_buf(rq->vq, &len)) != NULL) {
-			receive_buf(vi, rq, buf, len, NULL, xdp_xmit, &bytes);
-			received++;
+			receive_buf(vi, rq, buf, len, NULL, xdp_xmit, &stats);
+			stats.rx.packets++;
 		}
 	}
 
@@ -1271,11 +1282,16 @@ static int virtnet_receive(struct receive_queue *rq, int budget,
 	}
 
 	u64_stats_update_begin(&rq->stats.syncp);
-	rq->stats.bytes += bytes;
-	rq->stats.packets += received;
+	for (i = 0; i < VIRTNET_RQ_STATS_LEN; i++) {
+		size_t offset = virtnet_rq_stats_desc[i].offset;
+		u64 *item;
+
+		item = (u64 *)((u8 *)&rq->stats.items + offset);
+		*item += *(u64 *)((u8 *)&stats.rx + offset);
+	}
 	u64_stats_update_end(&rq->stats.syncp);
 
-	return received;
+	return stats.rx.packets;
 }
 
 static void free_old_xmit_skbs(struct send_queue *sq)
@@ -1628,8 +1644,8 @@ static void virtnet_stats(struct net_device *dev,
 
 		do {
 			start = u64_stats_fetch_begin_irq(&rq->stats.syncp);
-			rpackets = rq->stats.packets;
-			rbytes   = rq->stats.bytes;
+			rpackets = rq->stats.items.packets;
+			rbytes   = rq->stats.items.bytes;
 		} while (u64_stats_fetch_retry_irq(&rq->stats.syncp, start));
 
 		tot->rx_packets += rpackets;
@@ -2019,7 +2035,7 @@ static void virtnet_get_ethtool_stats(struct net_device *dev,
 	for (i = 0; i < vi->curr_queue_pairs; i++) {
 		struct receive_queue *rq = &vi->rq[i];
 
-		stats_base = (u8 *)&rq->stats;
+		stats_base = (u8 *)&rq->stats.items;
 		do {
 			start = u64_stats_fetch_begin_irq(&rq->stats.syncp);
 			for (j = 0; j < VIRTNET_RQ_STATS_LEN; j++) {