vmscan: per memory cgroup slab shrinkers

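This patch adds the SHRINKER_MEMCG_AWARE flag.  If a shrinker has this
flag set, it will be called per memory cgroup.  The memory cgroup to
scan objects from is passed in shrink_control->memcg.  If the memory
cgroup is NULL, a memcg-aware shrinker is supposed to scan objects from
the global list.  Unaware shrinkers are only called on global pressure
with memcg=NULL.

To support this, shrink_node_slabs() becomes the static shrink_slab(),
which additionally takes the target memcg and returns 0 right away for
a memcg for which memcg_kmem_is_active() is false.  shake_page() now
goes through the new drop_slab_node() helper, which calls shrink_slab()
for memcg=NULL and then for each memory cgroup returned by
mem_cgroup_iter(), repeating while a full pass frees more than 10
objects.  drop_slab() does the same for every online node.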

Signed-off-by: Vladimir Davydov <vdavydov@parallels.com>
Cc: Dave Chinner <david@fromorbit.com>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: Michal Hocko <mhocko@suse.cz>
Cc: Greg Thelen <gthelen@google.com>
Cc: Glauber Costa <glommer@gmail.com>
Cc: Alexander Viro <viro@zeniv.linux.org.uk>
Cc: Christoph Lameter <cl@linux.com>
Cc: Pekka Enberg <penberg@kernel.org>
Cc: David Rientjes <rientjes@google.com>
Cc: Joonsoo Kim <iamjoonsoo.kim@lge.com>
Cc: Tejun Heo <tj@kernel.org>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 095c1f9..3c2a1a8 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -352,7 +352,7 @@
 };
 
 #ifdef CONFIG_MEMCG_KMEM
-static bool memcg_kmem_is_active(struct mem_cgroup *memcg)
+bool memcg_kmem_is_active(struct mem_cgroup *memcg)
 {
 	return memcg->kmemcg_id >= 0;
 }
diff --git a/mm/memory-failure.c b/mm/memory-failure.c
index feb803b..1a735fa 100644
--- a/mm/memory-failure.c
+++ b/mm/memory-failure.c
@@ -242,15 +242,8 @@
 	 * Only call shrink_node_slabs here (which would also shrink
 	 * other caches) if access is not potentially fatal.
 	 */
-	if (access) {
-		int nr;
-		int nid = page_to_nid(p);
-		do {
-			nr = shrink_node_slabs(GFP_KERNEL, nid, 1000, 1000);
-			if (page_count(p) == 1)
-				break;
-		} while (nr > 10);
-	}
+	if (access)
+		drop_slab_node(page_to_nid(p));
 }
 EXPORT_SYMBOL_GPL(shake_page);
 
diff --git a/mm/vmscan.c b/mm/vmscan.c
index 8e645ee..803886b 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -232,10 +232,10 @@
 
 #define SHRINK_BATCH 128
 
-static unsigned long shrink_slabs(struct shrink_control *shrinkctl,
-				  struct shrinker *shrinker,
-				  unsigned long nr_scanned,
-				  unsigned long nr_eligible)
+static unsigned long do_shrink_slab(struct shrink_control *shrinkctl,
+				    struct shrinker *shrinker,
+				    unsigned long nr_scanned,
+				    unsigned long nr_eligible)
 {
 	unsigned long freed = 0;
 	unsigned long long delta;
@@ -344,9 +344,10 @@
 }
 
 /**
- * shrink_node_slabs - shrink slab caches of a given node
+ * shrink_slab - shrink slab caches
  * @gfp_mask: allocation context
  * @nid: node whose slab caches to target
+ * @memcg: memory cgroup whose slab caches to target
  * @nr_scanned: pressure numerator
  * @nr_eligible: pressure denominator
  *
@@ -355,6 +356,12 @@
  * @nid is passed along to shrinkers with SHRINKER_NUMA_AWARE set,
  * unaware shrinkers will receive a node id of 0 instead.
  *
+ * @memcg specifies the memory cgroup to target. If it is not NULL,
+ * only shrinkers with SHRINKER_MEMCG_AWARE set will be called to scan
+ * objects from the memory cgroup specified. Otherwise all shrinkers
+ * are called, and memcg aware shrinkers are supposed to scan the
+ * global list then.
+ *
  * @nr_scanned and @nr_eligible form a ratio that indicate how much of
  * the available objects should be scanned.  Page reclaim for example
  * passes the number of pages scanned and the number of pages on the
@@ -365,13 +372,17 @@
  *
  * Returns the number of reclaimed slab objects.
  */
-unsigned long shrink_node_slabs(gfp_t gfp_mask, int nid,
-				unsigned long nr_scanned,
-				unsigned long nr_eligible)
+static unsigned long shrink_slab(gfp_t gfp_mask, int nid,
+				 struct mem_cgroup *memcg,
+				 unsigned long nr_scanned,
+				 unsigned long nr_eligible)
 {
 	struct shrinker *shrinker;
 	unsigned long freed = 0;
 
+	if (memcg && !memcg_kmem_is_active(memcg))
+		return 0;
+
 	if (nr_scanned == 0)
 		nr_scanned = SWAP_CLUSTER_MAX;
 
@@ -390,12 +401,16 @@
 		struct shrink_control sc = {
 			.gfp_mask = gfp_mask,
 			.nid = nid,
+			.memcg = memcg,
 		};
 
+		if (memcg && !(shrinker->flags & SHRINKER_MEMCG_AWARE))
+			continue;
+
 		if (!(shrinker->flags & SHRINKER_NUMA_AWARE))
 			sc.nid = 0;
 
-		freed += shrink_slabs(&sc, shrinker, nr_scanned, nr_eligible);
+		freed += do_shrink_slab(&sc, shrinker, nr_scanned, nr_eligible);
 	}
 
 	up_read(&shrinker_rwsem);
@@ -404,6 +419,29 @@
 	return freed;
 }
 
+void drop_slab_node(int nid)
+{
+	unsigned long freed;
+
+	do {
+		struct mem_cgroup *memcg = NULL;
+
+		freed = 0;
+		do {
+			freed += shrink_slab(GFP_KERNEL, nid, memcg,
+					     1000, 1000);
+		} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)) != NULL);
+	} while (freed > 10);
+}
+
+void drop_slab(void)
+{
+	int nid;
+
+	for_each_online_node(nid)
+		drop_slab_node(nid);
+}
+
 static inline int is_page_cache_freeable(struct page *page)
 {
 	/*
@@ -2276,6 +2314,7 @@
 static bool shrink_zone(struct zone *zone, struct scan_control *sc,
 			bool is_classzone)
 {
+	struct reclaim_state *reclaim_state = current->reclaim_state;
 	unsigned long nr_reclaimed, nr_scanned;
 	bool reclaimable = false;
 
@@ -2294,6 +2333,7 @@
 		memcg = mem_cgroup_iter(root, NULL, &reclaim);
 		do {
 			unsigned long lru_pages;
+			unsigned long scanned;
 			struct lruvec *lruvec;
 			int swappiness;
 
@@ -2305,10 +2345,16 @@
 
 			lruvec = mem_cgroup_zone_lruvec(zone, memcg);
 			swappiness = mem_cgroup_swappiness(memcg);
+			scanned = sc->nr_scanned;
 
 			shrink_lruvec(lruvec, swappiness, sc, &lru_pages);
 			zone_lru_pages += lru_pages;
 
+			if (memcg && is_classzone)
+				shrink_slab(sc->gfp_mask, zone_to_nid(zone),
+					    memcg, sc->nr_scanned - scanned,
+					    lru_pages);
+
 			/*
 			 * Direct reclaim and kswapd have to scan all memory
 			 * cgroups to fulfill the overall scan target for the
@@ -2330,19 +2376,14 @@
 		 * Shrink the slab caches in the same proportion that
 		 * the eligible LRU pages were scanned.
 		 */
-		if (global_reclaim(sc) && is_classzone) {
-			struct reclaim_state *reclaim_state;
+		if (global_reclaim(sc) && is_classzone)
+			shrink_slab(sc->gfp_mask, zone_to_nid(zone), NULL,
+				    sc->nr_scanned - nr_scanned,
+				    zone_lru_pages);
 
-			shrink_node_slabs(sc->gfp_mask, zone_to_nid(zone),
-					  sc->nr_scanned - nr_scanned,
-					  zone_lru_pages);
-
-			reclaim_state = current->reclaim_state;
-			if (reclaim_state) {
-				sc->nr_reclaimed +=
-					reclaim_state->reclaimed_slab;
-				reclaim_state->reclaimed_slab = 0;
-			}
+		if (reclaim_state) {
+			sc->nr_reclaimed += reclaim_state->reclaimed_slab;
+			reclaim_state->reclaimed_slab = 0;
 		}
 
 		vmpressure(sc->gfp_mask, sc->target_mem_cgroup,