nvme-pci: use max of PRP or SGL for iod size

>From the initial implementation of NVMe SGL kernel support
commit a7a7cbe353a5 ("nvme-pci: add SGL support") with addition of the
commit 943e942e6266 ("nvme-pci: limit max IO size and segments to avoid
high order allocations") now there is only caller left for
nvme_pci_iod_alloc_size() which statically passes true for last
parameter that calculates allocation size based on SGL since we need
size of biggest command supported for mempool allocation.

This patch modifies the helper functions nvme_pci_iod_alloc_size() such
that it is now uses maximum of PRP and SGL size for iod allocation size
calculation.

Signed-off-by: Chaitanya Kulkarni <chaitanya.kulkarni@wdc.com>
Signed-off-by: Christoph Hellwig <hch@lst.de>
diff --git a/drivers/nvme/host/pci.c b/drivers/nvme/host/pci.c
index 0ab0dcf..96c1ae3 100644
--- a/drivers/nvme/host/pci.c
+++ b/drivers/nvme/host/pci.c
@@ -346,9 +346,9 @@ static bool nvme_dbbuf_update_and_check_event(u16 value, u32 *dbbuf_db,
  * as it only leads to a small amount of wasted memory for the lifetime of
  * the I/O.
  */
-static int nvme_npages(unsigned size, struct nvme_dev *dev)
+static int nvme_pci_npages_prp(void)
 {
-	unsigned nprps = DIV_ROUND_UP(size + NVME_CTRL_PAGE_SIZE,
+	unsigned nprps = DIV_ROUND_UP(NVME_MAX_KB_SZ + NVME_CTRL_PAGE_SIZE,
 				      NVME_CTRL_PAGE_SIZE);
 	return DIV_ROUND_UP(8 * nprps, PAGE_SIZE - 8);
 }
@@ -357,22 +357,18 @@ static int nvme_npages(unsigned size, struct nvme_dev *dev)
  * Calculates the number of pages needed for the SGL segments. For example a 4k
  * page can accommodate 256 SGL descriptors.
  */
-static int nvme_pci_npages_sgl(unsigned int num_seg)
+static int nvme_pci_npages_sgl(void)
 {
-	return DIV_ROUND_UP(num_seg * sizeof(struct nvme_sgl_desc), PAGE_SIZE);
+	return DIV_ROUND_UP(NVME_MAX_SEGS * sizeof(struct nvme_sgl_desc),
+			PAGE_SIZE);
 }
 
-static size_t nvme_pci_iod_alloc_size(struct nvme_dev *dev,
-		unsigned int size, unsigned int nseg, bool use_sgl)
+static size_t nvme_pci_iod_alloc_size(void)
 {
-	size_t alloc_size;
+	size_t npages = max(nvme_pci_npages_prp(), nvme_pci_npages_sgl());
 
-	if (use_sgl)
-		alloc_size = sizeof(__le64 *) * nvme_pci_npages_sgl(nseg);
-	else
-		alloc_size = sizeof(__le64 *) * nvme_npages(size, dev);
-
-	return alloc_size + sizeof(struct scatterlist) * nseg;
+	return sizeof(__le64 *) * npages +
+		sizeof(struct scatterlist) * NVME_MAX_SEGS;
 }
 
 static int nvme_admin_init_hctx(struct blk_mq_hw_ctx *hctx, void *data,
@@ -2811,8 +2807,7 @@ static int nvme_probe(struct pci_dev *pdev, const struct pci_device_id *id)
 	 * Double check that our mempool alloc size will cover the biggest
 	 * command we support.
 	 */
-	alloc_size = nvme_pci_iod_alloc_size(dev, NVME_MAX_KB_SZ,
-						NVME_MAX_SEGS, true);
+	alloc_size = nvme_pci_iod_alloc_size();
 	WARN_ON_ONCE(alloc_size > PAGE_SIZE);
 
 	dev->iod_mempool = mempool_create_node(1, mempool_kmalloc,