iommu/rockchip: Use DMA API to manage coherency

Use the DMA API instead of architecture-internal functions such as
__cpuc_flush_dcache_area() and outer_flush_range().

The biggest difficulty here is that the dma_map_*() and dma_sync_*() calls
require a struct device, while there is no real 1:1 relation between an
IOMMU domain and any particular device. To overcome this, a simple platform
device is registered for each allocated IOMMU domain.

With this patch, this driver can be used on both ARM and ARM64
platforms, such as RK3288 and RK3399 respectively.

Signed-off-by: Shunqian Zheng <zhengsq@rock-chips.com>
Signed-off-by: Tomasz Figa <tfiga@chromium.org>
Signed-off-by: Joerg Roedel <jroedel@suse.de>
diff --git a/drivers/iommu/rockchip-iommu.c b/drivers/iommu/rockchip-iommu.c
index 8a5bac7..712ed75 100644
--- a/drivers/iommu/rockchip-iommu.c
+++ b/drivers/iommu/rockchip-iommu.c
@@ -4,11 +4,10 @@
  * published by the Free Software Foundation.
  */
 
-#include <asm/cacheflush.h>
-#include <asm/pgtable.h>
 #include <linux/compiler.h>
 #include <linux/delay.h>
 #include <linux/device.h>
+#include <linux/dma-iommu.h>
 #include <linux/errno.h>
 #include <linux/interrupt.h>
 #include <linux/io.h>
@@ -77,7 +76,9 @@
 
 struct rk_iommu_domain {
 	struct list_head iommus;
+	struct platform_device *pdev;
 	u32 *dt; /* page directory table */
+	dma_addr_t dt_dma;
 	spinlock_t iommus_lock; /* lock for iommus list */
 	spinlock_t dt_lock; /* lock for modifying page directory table */
 
@@ -93,14 +94,12 @@
 	struct iommu_domain *domain; /* domain to which iommu is attached */
 };
 
-static inline void rk_table_flush(u32 *va, unsigned int count)
+static inline void rk_table_flush(struct rk_iommu_domain *dom, dma_addr_t dma,
+				  unsigned int count)
 {
-	phys_addr_t pa_start = virt_to_phys(va);
-	phys_addr_t pa_end = virt_to_phys(va + count);
-	size_t size = pa_end - pa_start;
+	size_t size = count * sizeof(u32); /* count of u32 entry */
 
-	__cpuc_flush_dcache_area(va, size);
-	outer_flush_range(pa_start, pa_end);
+	dma_sync_single_for_device(&dom->pdev->dev, dma, size, DMA_TO_DEVICE);
 }
 
 static struct rk_iommu_domain *to_rk_domain(struct iommu_domain *dom)
@@ -183,10 +182,9 @@
 	return dte & RK_DTE_PT_VALID;
 }
 
-static u32 rk_mk_dte(u32 *pt)
+static inline u32 rk_mk_dte(dma_addr_t pt_dma)
 {
-	phys_addr_t pt_phys = virt_to_phys(pt);
-	return (pt_phys & RK_DTE_PT_ADDRESS_MASK) | RK_DTE_PT_VALID;
+	return (pt_dma & RK_DTE_PT_ADDRESS_MASK) | RK_DTE_PT_VALID;
 }
 
 /*
@@ -603,13 +601,16 @@
 static u32 *rk_dte_get_page_table(struct rk_iommu_domain *rk_domain,
 				  dma_addr_t iova)
 {
+	struct device *dev = &rk_domain->pdev->dev;
 	u32 *page_table, *dte_addr;
-	u32 dte;
+	u32 dte_index, dte;
 	phys_addr_t pt_phys;
+	dma_addr_t pt_dma;
 
 	assert_spin_locked(&rk_domain->dt_lock);
 
-	dte_addr = &rk_domain->dt[rk_iova_dte_index(iova)];
+	dte_index = rk_iova_dte_index(iova);
+	dte_addr = &rk_domain->dt[dte_index];
 	dte = *dte_addr;
 	if (rk_dte_is_pt_valid(dte))
 		goto done;
@@ -618,19 +619,27 @@
 	if (!page_table)
 		return ERR_PTR(-ENOMEM);
 
-	dte = rk_mk_dte(page_table);
+	pt_dma = dma_map_single(dev, page_table, SPAGE_SIZE, DMA_TO_DEVICE);
+	if (dma_mapping_error(dev, pt_dma)) {
+		dev_err(dev, "DMA mapping error while allocating page table\n");
+		free_page((unsigned long)page_table);
+		return ERR_PTR(-ENOMEM);
+	}
+
+	dte = rk_mk_dte(pt_dma);
 	*dte_addr = dte;
 
-	rk_table_flush(page_table, NUM_PT_ENTRIES);
-	rk_table_flush(dte_addr, 1);
-
+	rk_table_flush(rk_domain, pt_dma, NUM_PT_ENTRIES);
+	rk_table_flush(rk_domain,
+		       rk_domain->dt_dma + dte_index * sizeof(u32), 1);
 done:
 	pt_phys = rk_dte_pt_address(dte);
 	return (u32 *)phys_to_virt(pt_phys);
 }
 
 static size_t rk_iommu_unmap_iova(struct rk_iommu_domain *rk_domain,
-				  u32 *pte_addr, dma_addr_t iova, size_t size)
+				  u32 *pte_addr, dma_addr_t pte_dma,
+				  size_t size)
 {
 	unsigned int pte_count;
 	unsigned int pte_total = size / SPAGE_SIZE;
@@ -645,14 +654,14 @@
 		pte_addr[pte_count] = rk_mk_pte_invalid(pte);
 	}
 
-	rk_table_flush(pte_addr, pte_count);
+	rk_table_flush(rk_domain, pte_dma, pte_count);
 
 	return pte_count * SPAGE_SIZE;
 }
 
 static int rk_iommu_map_iova(struct rk_iommu_domain *rk_domain, u32 *pte_addr,
-			     dma_addr_t iova, phys_addr_t paddr, size_t size,
-			     int prot)
+			     dma_addr_t pte_dma, dma_addr_t iova,
+			     phys_addr_t paddr, size_t size, int prot)
 {
 	unsigned int pte_count;
 	unsigned int pte_total = size / SPAGE_SIZE;
@@ -671,7 +680,7 @@
 		paddr += SPAGE_SIZE;
 	}
 
-	rk_table_flush(pte_addr, pte_count);
+	rk_table_flush(rk_domain, pte_dma, pte_total);
 
 	/*
 	 * Zap the first and last iova to evict from iotlb any previously
@@ -684,7 +693,8 @@
 	return 0;
 unwind:
 	/* Unmap the range of iovas that we just mapped */
-	rk_iommu_unmap_iova(rk_domain, pte_addr, iova, pte_count * SPAGE_SIZE);
+	rk_iommu_unmap_iova(rk_domain, pte_addr, pte_dma,
+			    pte_count * SPAGE_SIZE);
 
 	iova += pte_count * SPAGE_SIZE;
 	page_phys = rk_pte_page_address(pte_addr[pte_count]);
@@ -699,8 +709,9 @@
 {
 	struct rk_iommu_domain *rk_domain = to_rk_domain(domain);
 	unsigned long flags;
-	dma_addr_t iova = (dma_addr_t)_iova;
+	dma_addr_t pte_dma, iova = (dma_addr_t)_iova;
 	u32 *page_table, *pte_addr;
+	u32 dte_index, pte_index;
 	int ret;
 
 	spin_lock_irqsave(&rk_domain->dt_lock, flags);
@@ -718,8 +729,13 @@
 		return PTR_ERR(page_table);
 	}
 
-	pte_addr = &page_table[rk_iova_pte_index(iova)];
-	ret = rk_iommu_map_iova(rk_domain, pte_addr, iova, paddr, size, prot);
+	dte_index = rk_domain->dt[rk_iova_dte_index(iova)];
+	pte_index = rk_iova_pte_index(iova);
+	pte_addr = &page_table[pte_index];
+	pte_dma = rk_dte_pt_address(dte_index) + pte_index * sizeof(u32);
+	ret = rk_iommu_map_iova(rk_domain, pte_addr, pte_dma, iova,
+				paddr, size, prot);
+
 	spin_unlock_irqrestore(&rk_domain->dt_lock, flags);
 
 	return ret;
@@ -730,7 +746,7 @@
 {
 	struct rk_iommu_domain *rk_domain = to_rk_domain(domain);
 	unsigned long flags;
-	dma_addr_t iova = (dma_addr_t)_iova;
+	dma_addr_t pte_dma, iova = (dma_addr_t)_iova;
 	phys_addr_t pt_phys;
 	u32 dte;
 	u32 *pte_addr;
@@ -754,7 +770,8 @@
 
 	pt_phys = rk_dte_pt_address(dte);
 	pte_addr = (u32 *)phys_to_virt(pt_phys) + rk_iova_pte_index(iova);
-	unmap_size = rk_iommu_unmap_iova(rk_domain, pte_addr, iova, size);
+	pte_dma = pt_phys + rk_iova_pte_index(iova) * sizeof(u32);
+	unmap_size = rk_iommu_unmap_iova(rk_domain, pte_addr, pte_dma, size);
 
 	spin_unlock_irqrestore(&rk_domain->dt_lock, flags);
 
@@ -787,7 +804,6 @@
 	struct rk_iommu_domain *rk_domain = to_rk_domain(domain);
 	unsigned long flags;
 	int ret, i;
-	phys_addr_t dte_addr;
 
 	/*
 	 * Allow 'virtual devices' (e.g., drm) to attach to domain.
@@ -812,9 +828,9 @@
 	if (ret)
 		return ret;
 
-	dte_addr = virt_to_phys(rk_domain->dt);
 	for (i = 0; i < iommu->num_mmu; i++) {
-		rk_iommu_write(iommu->bases[i], RK_MMU_DTE_ADDR, dte_addr);
+		rk_iommu_write(iommu->bases[i], RK_MMU_DTE_ADDR,
+			       rk_domain->dt_dma);
 		rk_iommu_base_command(iommu->bases[i], RK_MMU_CMD_ZAP_CACHE);
 		rk_iommu_write(iommu->bases[i], RK_MMU_INT_MASK, RK_MMU_IRQ_MASK);
 	}
@@ -870,14 +886,30 @@
 static struct iommu_domain *rk_iommu_domain_alloc(unsigned type)
 {
 	struct rk_iommu_domain *rk_domain;
+	struct platform_device *pdev;
+	struct device *iommu_dev;
 
 	if (type != IOMMU_DOMAIN_UNMANAGED)
 		return NULL;
 
-	rk_domain = kzalloc(sizeof(*rk_domain), GFP_KERNEL);
-	if (!rk_domain)
+	/* Register a pdev per domain, so the DMA API can use its struct device
+	 * even when a virtual master doesn't have an IOMMU slave.
+	 */
+	pdev = platform_device_register_simple("rk_iommu_domain",
+					       PLATFORM_DEVID_AUTO, NULL, 0);
+	if (IS_ERR(pdev))
 		return NULL;
 
+	rk_domain = devm_kzalloc(&pdev->dev, sizeof(*rk_domain), GFP_KERNEL);
+	if (!rk_domain)
+		goto err_unreg_pdev;
+
+	rk_domain->pdev = pdev;
+
+	/* To init the iovad which is required by iommu_dma_init_domain() */
+	if (iommu_get_dma_cookie(&rk_domain->domain))
+		goto err_unreg_pdev;
+
 	/*
 	 * rk32xx iommus use a 2 level pagetable.
 	 * Each level1 (dt) and level2 (pt) table has 1024 4-byte entries.
@@ -885,9 +917,17 @@
 	 */
 	rk_domain->dt = (u32 *)get_zeroed_page(GFP_KERNEL | GFP_DMA32);
 	if (!rk_domain->dt)
-		goto err_dt;
+		goto err_put_cookie;
 
-	rk_table_flush(rk_domain->dt, NUM_DT_ENTRIES);
+	iommu_dev = &pdev->dev;
+	rk_domain->dt_dma = dma_map_single(iommu_dev, rk_domain->dt,
+					   SPAGE_SIZE, DMA_TO_DEVICE);
+	if (dma_mapping_error(iommu_dev, rk_domain->dt_dma)) {
+		dev_err(iommu_dev, "DMA map error for DT\n");
+		goto err_free_dt;
+	}
+
+	rk_table_flush(rk_domain, rk_domain->dt_dma, NUM_DT_ENTRIES);
 
 	spin_lock_init(&rk_domain->iommus_lock);
 	spin_lock_init(&rk_domain->dt_lock);
@@ -895,8 +935,13 @@
 
 	return &rk_domain->domain;
 
-err_dt:
-	kfree(rk_domain);
+err_free_dt:
+	free_page((unsigned long)rk_domain->dt);
+err_put_cookie:
+	iommu_put_dma_cookie(&rk_domain->domain);
+err_unreg_pdev:
+	platform_device_unregister(pdev);
+
 	return NULL;
 }
 
@@ -912,12 +957,19 @@
 		if (rk_dte_is_pt_valid(dte)) {
 			phys_addr_t pt_phys = rk_dte_pt_address(dte);
 			u32 *page_table = phys_to_virt(pt_phys);
+			dma_unmap_single(&rk_domain->pdev->dev, pt_phys,
+					 SPAGE_SIZE, DMA_TO_DEVICE);
 			free_page((unsigned long)page_table);
 		}
 	}
 
+	dma_unmap_single(&rk_domain->pdev->dev, rk_domain->dt_dma,
+			 SPAGE_SIZE, DMA_TO_DEVICE);
 	free_page((unsigned long)rk_domain->dt);
-	kfree(rk_domain);
+
+	iommu_put_dma_cookie(&rk_domain->domain);
+
+	platform_device_unregister(rk_domain->pdev);
 }
 
 static bool rk_iommu_is_dev_iommu_master(struct device *dev)
@@ -1029,6 +1081,30 @@
 	.pgsize_bitmap = RK_IOMMU_PGSIZE_BITMAP,
 };
 
+static int rk_iommu_domain_probe(struct platform_device *pdev)
+{
+	struct device *dev = &pdev->dev;
+
+	dev->dma_parms = devm_kzalloc(dev, sizeof(*dev->dma_parms), GFP_KERNEL);
+	if (!dev->dma_parms)
+		return -ENOMEM;
+
+	/* Set dma_ops for dev, otherwise it would be dummy_dma_ops */
+	arch_setup_dma_ops(dev, 0, DMA_BIT_MASK(32), NULL, false);
+
+	dma_set_max_seg_size(dev, DMA_BIT_MASK(32));
+	dma_coerce_mask_and_coherent(dev, DMA_BIT_MASK(32));
+
+	return 0;
+}
+
+static struct platform_driver rk_iommu_domain_driver = {
+	.probe = rk_iommu_domain_probe,
+	.driver = {
+		   .name = "rk_iommu_domain",
+	},
+};
+
 static int rk_iommu_probe(struct platform_device *pdev)
 {
 	struct device *dev = &pdev->dev;
@@ -1106,11 +1182,19 @@
 	if (ret)
 		return ret;
 
-	return platform_driver_register(&rk_iommu_driver);
+	ret = platform_driver_register(&rk_iommu_domain_driver);
+	if (ret)
+		return ret;
+
+	ret = platform_driver_register(&rk_iommu_driver);
+	if (ret)
+		platform_driver_unregister(&rk_iommu_domain_driver);
+	return ret;
 }
 static void __exit rk_iommu_exit(void)
 {
 	platform_driver_unregister(&rk_iommu_driver);
+	platform_driver_unregister(&rk_iommu_domain_driver);
 }
 
 subsys_initcall(rk_iommu_init);