mm: allow full handling of COW faults in ->fault handlers

Patch series "dax: Clear dirty bits after flushing caches", v5.

This patchset clears dirty bits from the radix tree of DAX inodes once
the caches for the corresponding pfns have been flushed.  In principle,
these patches enable ->fault handlers to easily update PTEs and do other
work necessary to finish the fault without duplicating the functionality
present in the generic code.  I'd like to thank Kirill and Ross for
reviews of the series!

This patch (of 20):

To allow full handling of COW faults, add a memcg field to struct
vm_fault and a new ->fault() handler return value, VM_FAULT_DONE_COW,
meaning that the COW fault is fully handled and the memcg charge must
not be canceled.  This will allow us to remove knowledge about special
DAX locking from the generic fault code.

Link: http://lkml.kernel.org/r/1479460644-25076-9-git-send-email-jack@suse.cz
Signed-off-by: Jan Kara <jack@suse.cz>
Reviewed-by: Ross Zwisler <ross.zwisler@linux.intel.com>
Acked-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
Cc: Dan Williams <dan.j.williams@intel.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/memory.c b/mm/memory.c
index cf74f7c..02504cd 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -2844,9 +2844,8 @@ static int __do_fault(struct vm_fault *vmf)
 	int ret;
 
 	ret = vma->vm_ops->fault(vma, vmf);
-	if (unlikely(ret & (VM_FAULT_ERROR | VM_FAULT_NOPAGE | VM_FAULT_RETRY)))
-		return ret;
-	if (ret & VM_FAULT_DAX_LOCKED)
+	if (unlikely(ret & (VM_FAULT_ERROR | VM_FAULT_NOPAGE | VM_FAULT_RETRY |
+			    VM_FAULT_DAX_LOCKED | VM_FAULT_DONE_COW)))
 		return ret;
 
 	if (unlikely(PageHWPoison(vmf->page))) {
@@ -3226,7 +3225,6 @@ static int do_read_fault(struct vm_fault *vmf)
 static int do_cow_fault(struct vm_fault *vmf)
 {
 	struct vm_area_struct *vma = vmf->vma;
-	struct mem_cgroup *memcg;
 	int ret;
 
 	if (unlikely(anon_vma_prepare(vma)))
@@ -3237,7 +3235,7 @@ static int do_cow_fault(struct vm_fault *vmf)
 		return VM_FAULT_OOM;
 
 	if (mem_cgroup_try_charge(vmf->cow_page, vma->vm_mm, GFP_KERNEL,
-				&memcg, false)) {
+				&vmf->memcg, false)) {
 		put_page(vmf->cow_page);
 		return VM_FAULT_OOM;
 	}
@@ -3245,12 +3243,14 @@ static int do_cow_fault(struct vm_fault *vmf)
 	ret = __do_fault(vmf);
 	if (unlikely(ret & (VM_FAULT_ERROR | VM_FAULT_NOPAGE | VM_FAULT_RETRY)))
 		goto uncharge_out;
+	if (ret & VM_FAULT_DONE_COW)
+		return ret;
 
 	if (!(ret & VM_FAULT_DAX_LOCKED))
 		copy_user_highpage(vmf->cow_page, vmf->page, vmf->address, vma);
 	__SetPageUptodate(vmf->cow_page);
 
-	ret |= alloc_set_pte(vmf, memcg, vmf->cow_page);
+	ret |= alloc_set_pte(vmf, vmf->memcg, vmf->cow_page);
 	if (vmf->pte)
 		pte_unmap_unlock(vmf->pte, vmf->ptl);
 	if (!(ret & VM_FAULT_DAX_LOCKED)) {
@@ -3263,7 +3263,7 @@ static int do_cow_fault(struct vm_fault *vmf)
 		goto uncharge_out;
 	return ret;
 uncharge_out:
-	mem_cgroup_cancel_charge(vmf->cow_page, memcg, false);
+	mem_cgroup_cancel_charge(vmf->cow_page, vmf->memcg, false);
 	put_page(vmf->cow_page);
 	return ret;
 }