s390/mm,gmap: segment mapping race

The gmap_map_segment function creates a special invalid segment table
entry with the address of the requested target location in the process
address space. The first access will create the connection between the
gmap segment table and the target page table of the main process.
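If two threads do this concurrently, both will walk the page tables and
allocate a gmap_rmap structure for the same segment table entry.
To avoid the race, recheck the segment table entry after taking the page
table lock and free the surplus gmap_rmap if another thread won.

While at it, move the connect logic out of __gmap_fault into the new
gmap_connect_pgtable function and rename gmap_unmap_notifier to
gmap_disconnect_pgtable.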

Signed-off-by: Martin Schwidefsky <schwidefsky@de.ibm.com>
diff --git a/arch/s390/mm/pgtable.c b/arch/s390/mm/pgtable.c
index 2accf71..bd954e9 100644
--- a/arch/s390/mm/pgtable.c
+++ b/arch/s390/mm/pgtable.c
@@ -454,12 +454,11 @@
 }
 EXPORT_SYMBOL_GPL(gmap_translate);
 
-/*
- * this function is assumed to be called with mmap_sem held
- */
-unsigned long __gmap_fault(unsigned long address, struct gmap *gmap)
+static int gmap_connect_pgtable(unsigned long segment,
+				unsigned long *segment_ptr,
+				struct gmap *gmap)
 {
-	unsigned long *segment_ptr, vmaddr, segment;
+	unsigned long vmaddr;
 	struct vm_area_struct *vma;
 	struct gmap_pgtable *mp;
 	struct gmap_rmap *rmap;
@@ -469,48 +468,94 @@
 	pud_t *pud;
 	pmd_t *pmd;
 
+	mm = gmap->mm;
+	vmaddr = segment & _SEGMENT_ENTRY_ORIGIN; /* parent mm address */
+	vma = find_vma(mm, vmaddr);
+	if (!vma || vma->vm_start > vmaddr)
+		return -EFAULT;
+	/* Walk the parent mm page table, allocating missing levels */
+	pgd = pgd_offset(mm, vmaddr);
+	pud = pud_alloc(mm, pgd, vmaddr);
+	if (!pud)
+		return -ENOMEM;
+	pmd = pmd_alloc(mm, pud, vmaddr);
+	if (!pmd)
+		return -ENOMEM;
+	if (!pmd_present(*pmd) &&
+	    __pte_alloc(mm, vma, pmd, vmaddr))
+		return -ENOMEM;
+	/* pmd now points to a valid segment table entry. */
+	rmap = kmalloc(sizeof(*rmap), GFP_KERNEL|__GFP_REPEAT);
+	if (!rmap)
+		return -ENOMEM;
+	/* Link gmap segment table entry location to page table. */
+	page = pmd_page(*pmd);
+	mp = (struct gmap_pgtable *) page->index;
+	rmap->entry = segment_ptr;
+	spin_lock(&mm->page_table_lock); /* vs. racing connect/disconnect */
+	if (*segment_ptr == segment) { /* still as read: we won the race */
+		list_add(&rmap->list, &mp->mapper);
+		/* Set gmap segment table entry to page table. */
+		*segment_ptr = pmd_val(*pmd) & PAGE_MASK;
+		rmap = NULL; /* now owned by mp->mapper */
+	}
+	spin_unlock(&mm->page_table_lock);
+	kfree(rmap); /* lost the race: drop unused rmap */
+	return 0; /* also on lost race; caller re-reads */
+}
+
+static void gmap_disconnect_pgtable(struct mm_struct *mm, unsigned long *table)
+{
+	struct gmap_rmap *rmap, *next;
+	struct gmap_pgtable *mp;
+	struct page *page;
+	int flush; /* set once any gmap entry was reset */
+
+	flush = 0;
+	spin_lock(&mm->page_table_lock); /* vs. gmap_connect_pgtable */
+	page = pfn_to_page(__pa(table) >> PAGE_SHIFT);
+	mp = (struct gmap_pgtable *) page->index;
+	list_for_each_entry_safe(rmap, next, &mp->mapper, list) {
+		*rmap->entry = /* unconnected: next fault reconnects */
+			_SEGMENT_ENTRY_INV | _SEGMENT_ENTRY_RO | mp->vmaddr;
+		list_del(&rmap->list);
+		kfree(rmap);
+		flush = 1;
+	}
+	spin_unlock(&mm->page_table_lock);
+	if (flush) /* only if a gmap entry was reset */
+		__tlb_flush_global();
+}
+
+/*
+ * gmap address -> parent mm address or -errno; assumes mmap_sem is held.
+ */
+unsigned long __gmap_fault(unsigned long address, struct gmap *gmap)
+{
+	unsigned long *segment_ptr, segment;
+	struct gmap_pgtable *mp;
+	struct page *page;
+	int rc; /* result of gmap_connect_pgtable() */
+
 	current->thread.gmap_addr = address;
 	segment_ptr = gmap_table_walk(address, gmap);
 	if (IS_ERR(segment_ptr))
 		return -EFAULT;
 	/* Convert the gmap address to an mm address. */
-	segment = *segment_ptr;
-	if (!(segment & _SEGMENT_ENTRY_INV)) {
-		page = pfn_to_page(segment >> PAGE_SHIFT);
-		mp = (struct gmap_pgtable *) page->index;
-		return mp->vmaddr | (address & ~PMD_MASK);
-	} else if (segment & _SEGMENT_ENTRY_RO) {
-		mm = gmap->mm;
-		vmaddr = segment & _SEGMENT_ENTRY_ORIGIN;
-		vma = find_vma(mm, vmaddr);
-		if (!vma || vma->vm_start > vmaddr)
-			return -EFAULT;
-
-		/* Walk the parent mm page table */
-		pgd = pgd_offset(mm, vmaddr);
-		pud = pud_alloc(mm, pgd, vmaddr);
-		if (!pud)
-			return -ENOMEM;
-		pmd = pmd_alloc(mm, pud, vmaddr);
-		if (!pmd)
-			return -ENOMEM;
-		if (!pmd_present(*pmd) &&
-		    __pte_alloc(mm, vma, pmd, vmaddr))
-			return -ENOMEM;
-		/* pmd now points to a valid segment table entry. */
-		rmap = kmalloc(sizeof(*rmap), GFP_KERNEL|__GFP_REPEAT);
-		if (!rmap)
-			return -ENOMEM;
-		/* Link gmap segment table entry location to page table. */
-		page = pmd_page(*pmd);
-		mp = (struct gmap_pgtable *) page->index;
-		rmap->entry = segment_ptr;
-		spin_lock(&mm->page_table_lock);
-		list_add(&rmap->list, &mp->mapper);
-		spin_unlock(&mm->page_table_lock);
-		/* Set gmap segment table entry to page table. */
-		*segment_ptr = pmd_val(*pmd) & PAGE_MASK;
-		return vmaddr | (address & ~PMD_MASK);
+	while (1) { /* re-read the entry after each connect attempt */
+		segment = *segment_ptr;
+		if (!(segment & _SEGMENT_ENTRY_INV)) {
+			/* Connected to a parent mm page table */
+			page = pfn_to_page(segment >> PAGE_SHIFT);
+			mp = (struct gmap_pgtable *) page->index;
+			return mp->vmaddr | (address & ~PMD_MASK);
+		}
+		if (!(segment & _SEGMENT_ENTRY_RO))
+			/* Nothing mapped in the gmap address space. */
+			break;
+		rc = gmap_connect_pgtable(segment, segment_ptr, gmap);
+		if (rc) /* 0 even if a racing thread connected it */
+			return rc; /* -EFAULT or -ENOMEM */
 	}
 	return -EFAULT;
 }
@@ -574,29 +619,6 @@
 }
 EXPORT_SYMBOL_GPL(gmap_discard);
 
-void gmap_unmap_notifier(struct mm_struct *mm, unsigned long *table)
-{
-	struct gmap_rmap *rmap, *next;
-	struct gmap_pgtable *mp;
-	struct page *page;
-	int flush;
-
-	flush = 0;
-	spin_lock(&mm->page_table_lock);
-	page = pfn_to_page(__pa(table) >> PAGE_SHIFT);
-	mp = (struct gmap_pgtable *) page->index;
-	list_for_each_entry_safe(rmap, next, &mp->mapper, list) {
-		*rmap->entry =
-			_SEGMENT_ENTRY_INV | _SEGMENT_ENTRY_RO | mp->vmaddr;
-		list_del(&rmap->list);
-		kfree(rmap);
-		flush = 1;
-	}
-	spin_unlock(&mm->page_table_lock);
-	if (flush)
-		__tlb_flush_global();
-}
-
 static inline unsigned long *page_table_alloc_pgste(struct mm_struct *mm,
 						    unsigned long vmaddr)
 {
@@ -649,8 +671,8 @@
 {
 }
 
-static inline void gmap_unmap_notifier(struct mm_struct *mm,
-					  unsigned long *table)
+static inline void gmap_disconnect_pgtable(struct mm_struct *mm,
+					   unsigned long *table)
 {
 }
 
@@ -716,7 +738,7 @@
 	unsigned int bit, mask;
 
 	if (mm_has_pgste(mm)) {
-		gmap_unmap_notifier(mm, table);
+		gmap_disconnect_pgtable(mm, table);
 		return page_table_free_pgste(table);
 	}
 	/* Free 1K/2K page table fragment of a 4K page */
@@ -759,7 +781,7 @@
 
 	mm = tlb->mm;
 	if (mm_has_pgste(mm)) {
-		gmap_unmap_notifier(mm, table);
+		gmap_disconnect_pgtable(mm, table);
 		table = (unsigned long *) (__pa(table) | FRAG_MASK);
 		tlb_remove_table(tlb, table);
 		return;