Merge "Create new mount directory /mnt/runtime/full."
diff --git a/Checkpoint.cpp b/Checkpoint.cpp
index 28855e6..7586a6c 100644
--- a/Checkpoint.cpp
+++ b/Checkpoint.cpp
@@ -229,6 +229,7 @@
     uint32_t magic;
     uint32_t count;
     uint32_t sequence;
+    uint64_t sector0;
     struct log_entry entries[];
 } __attribute__((packed));
 
@@ -289,62 +290,119 @@
 
 }  // namespace
 
+// Reads the block at `sector` as it would look once the entries in `logs` were restored.
+static void read(std::fstream& device, std::vector<log_entry> const& logs, sector_t sector,
+                 char* buffer) {
+    for (auto entry = logs.crbegin(); entry != logs.crend(); ++entry)  // last recorded first
+        if (sector >= entry->source && (sector - entry->source) * kSectorSize < entry->size)
+            sector = sector - entry->source + entry->dest;  // follow to the logged copy
+    device.seekg(sector * kSectorSize);
+    device.read(buffer, kBlockSize);
+}
+
+// Returns `size` bytes starting at `sector`. When validating, the bytes are assembled one
+// block at a time through the log remapping (crude: all of `logs` is rescanned per block);
+// otherwise they are read straight from the device.
+static std::vector<char> read(std::fstream& device, std::vector<log_entry> const& logs,
+                              bool validating, sector_t sector, uint32_t size) {
+    if (!validating) {
+        std::vector<char> buffer(size);
+        device.seekg(sector * kSectorSize);
+        device.read(buffer.data(), size);  // data(), not &buffer[0]: size may be 0
+        return buffer;
+    }
+    // Allocate whole blocks so a trailing partial block can't overrun the buffer, then trim.
+    std::vector<char> buffer((size_t{size} + kBlockSize - 1) / kBlockSize * kBlockSize);
+    for (uint32_t i = 0; i < size; i += kBlockSize, sector += kBlockSize / kSectorSize)
+        read(device, logs, sector, buffer.data() + i);
+    buffer.resize(size);
+    return buffer;
+}
+
 Status cp_restoreCheckpoint(const std::string& blockDevice) {
-    LOG(INFO) << "Restoring checkpoint on " << blockDevice;
-    std::fstream device(blockDevice, std::ios::binary | std::ios::in | std::ios::out);
-    if (!device) {
-        PLOG(ERROR) << "Cannot open " << blockDevice;
-        return Status::fromExceptionCode(errno, ("Cannot open " + blockDevice).c_str());
-    }
-    alignas(alignof(log_sector)) char ls_buffer[kBlockSize];
-    device.read(ls_buffer, kBlockSize);
-    log_sector& ls = *reinterpret_cast<log_sector*>(ls_buffer);
-    if (ls.magic != kMagic) {
-        LOG(ERROR) << "No magic";
-        return Status::fromExceptionCode(EINVAL, "No magic");
-    }
+    bool validating = true;
+    std::string action = "Validating";
 
-    LOG(INFO) << "Restoring " << ls.sequence << " log sectors";
+    for (;;) {
+        std::vector<log_entry> logs;
+        Status status = Status::ok();
 
-    for (int sequence = ls.sequence; sequence >= 0; sequence--) {
-        device.seekg(0);
-        device.read(ls_buffer, kBlockSize);
-        ls = *reinterpret_cast<log_sector*>(ls_buffer);
+        LOG(INFO) << action << " checkpoint on " << blockDevice;
+        std::fstream device(blockDevice, std::ios::binary | std::ios::in | std::ios::out);
+        if (!device) {
+            PLOG(ERROR) << "Cannot open " << blockDevice;
+            return Status::fromExceptionCode(errno, ("Cannot open " + blockDevice).c_str());
+        }
+        auto buffer = read(device, logs, validating, 0, kBlockSize);
+        log_sector& ls = *reinterpret_cast<log_sector*>(&buffer[0]);
         if (ls.magic != kMagic) {
-            LOG(ERROR) << "No magic!";
+            LOG(ERROR) << "No magic";
             return Status::fromExceptionCode(EINVAL, "No magic");
         }
 
-        if ((int)ls.sequence != sequence) {
-            LOG(ERROR) << "Expecting log sector " << sequence << " but got " << ls.sequence;
-            return Status::fromExceptionCode(
-                EINVAL, ("Expecting log sector " + std::to_string(sequence) + " but got " +
-                         std::to_string(ls.sequence))
-                            .c_str());
-        }
+        LOG(INFO) << action << " " << ls.sequence << " log sectors";
 
-        LOG(INFO) << "Restoring from log sector " << ls.sequence;
-
-        for (log_entry* le = &ls.entries[ls.count - 1]; le >= ls.entries; --le) {
-            LOG(INFO) << "Restoring " << le->size << " bytes from sector " << le->dest << " to "
-                      << le->source << " with checksum " << std::hex << le->checksum;
-            std::vector<char> buffer(le->size);
-            device.seekg(le->dest * kSectorSize);
-            device.read(&buffer[0], le->size);
-
-            uint32_t checksum = le->source / (kBlockSize / kSectorSize);
-            for (size_t i = 0; i < le->size; i += kBlockSize) {
-                crc32(&buffer[i], kBlockSize, &checksum);
+        for (int sequence = ls.sequence; sequence >= 0 && status.isOk(); sequence--) {
+            auto buffer = read(device, logs, validating, 0, kBlockSize);
+            log_sector& ls = *reinterpret_cast<log_sector*>(&buffer[0]);
+            if (ls.magic != kMagic) {
+                LOG(ERROR) << "No magic!";
+                status = Status::fromExceptionCode(EINVAL, "No magic");
+                break;
             }
 
-            if (le->checksum && checksum != le->checksum) {
-                LOG(ERROR) << "Checksums don't match " << std::hex << checksum;
-                return Status::fromExceptionCode(EINVAL, "Checksums don't match");
+            if ((int)ls.sequence != sequence) {
+                LOG(ERROR) << "Expecting log sector " << sequence << " but got " << ls.sequence;
+                status = Status::fromExceptionCode(
+                    EINVAL, ("Expecting log sector " + std::to_string(sequence) + " but got " +
+                             std::to_string(ls.sequence))
+                                .c_str());
+                break;
             }
 
-            device.seekg(le->source * kSectorSize);
-            device.write(&buffer[0], le->size);
+            LOG(INFO) << action << " from log sector " << ls.sequence;
+
+            for (log_entry* le = &ls.entries[ls.count - 1]; le >= ls.entries; --le) {
+                LOG(INFO) << action << " " << le->size << " bytes from sector " << le->dest
+                          << " to " << le->source << " with checksum " << std::hex << le->checksum;
+                auto buffer = read(device, logs, validating, le->dest, le->size);
+                uint32_t checksum = le->source / (kBlockSize / kSectorSize);
+                for (size_t i = 0; i < le->size; i += kBlockSize) {
+                    crc32(&buffer[i], kBlockSize, &checksum);
+                }
+
+                if (le->checksum && checksum != le->checksum) {
+                    LOG(ERROR) << "Checksums don't match " << std::hex << checksum;
+                    status = Status::fromExceptionCode(EINVAL, "Checksums don't match");
+                    break;
+                }
+
+                logs.push_back(*le);
+
+                if (!validating) {
+                    device.seekg(le->source * kSectorSize);
+                    device.write(&buffer[0], le->size);
+                }
+            }
         }
+
+        if (!status.isOk()) {
+            if (!validating) {
+                LOG(ERROR) << "Checkpoint restore failed even though checkpoint validation passed";
+                return status;
+            }
+
+            LOG(WARNING) << "Checkpoint validation failed - attempting to roll forward";
+            auto buffer = read(device, logs, false, ls.sector0, kBlockSize);
+            device.seekg(0);
+            device.write(&buffer[0], kBlockSize);
+            return Status::ok();
+        }
+
+        if (!validating) break;
+
+        validating = false;
+        action = "Restoring";
     }
 
     return Status::ok();
diff --git a/Utils.cpp b/Utils.cpp
index 5e12194..4464afc 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -753,9 +753,52 @@
     return android::base::GetBoolProperty("ro.kernel.qemu", false);
 }
 
-status_t UnmountTree(const std::string& prefix) {
-    if (umount2(prefix.c_str(), MNT_DETACH)) {
-        PLOG(ERROR) << "Failed to unmount " << prefix;
+// Prepends to `mountPoints` every /proc/mounts entry at or below `prefix`, each with a trailing
+// '/', so the list ends up in reverse mount-table order. Returns OK or a negative errno.
+static status_t findMountPointsWithPrefix(const std::string& prefix,
+                                          std::list<std::string>& mountPoints) {
+    if (prefix.empty()) return -EINVAL;  // would otherwise become "/" and match every mount
+    // Append '/' so that the prefix /foo/bar does not match /foo/barbaz.
+    std::string prefixWithSlash(prefix);
+    if (prefixWithSlash.back() != '/') prefixWithSlash += '/';
+
+    std::unique_ptr<FILE, int (*)(FILE*)> mnts(setmntent("/proc/mounts", "re"), endmntent);
+    if (!mnts) {
+        PLOG(ERROR) << "Unable to open /proc/mounts";
+        return -errno;
+    }
+
+    // Mounts can be stacked on each other, so collect them in reverse order: unmounting the
+    // later /proc/mounts entries first gives callers the best chance of success.
+    struct mntent* mnt;  // getmntent returns a thread local, so it's safe.
+    while ((mnt = getmntent(mnts.get())) != nullptr) {
+        auto mountPoint = std::string(mnt->mnt_dir) + "/";
+        if (android::base::StartsWith(mountPoint, prefixWithSlash)) {
+            mountPoints.push_front(mountPoint);
+        }
+    }
+    return OK;
+}
+
+// Unmounts every mount point at or below `prefix`; `prefix` itself need not be a mount point.
+// Each unmount is lazy (MNT_DETACH) and failures are logged, then skipped. Returns the lookup's
+// error if it failed, else -errno of the last failed unmount, else OK.
+status_t UnmountTreeWithPrefix(const std::string& prefix) {
+    std::list<std::string> mountPoints;
+    status_t res = findMountPointsWithPrefix(prefix, mountPoints);
+    if (res < 0) return res;
+
+    for (const std::string& mountPoint : mountPoints) {
+        if (umount2(mountPoint.c_str(), MNT_DETACH) == 0) continue;
+        PLOG(ERROR) << "Failed to unmount " << mountPoint;
+        res = -errno;
+    }
+    return res;
+}
+
+status_t UnmountTree(const std::string& mountPoint) {
+    if (umount2(mountPoint.c_str(), MNT_DETACH)) {
+        PLOG(ERROR) << "Failed to unmount " << mountPoint;
         return -errno;
     }
     return OK;
diff --git a/Utils.h b/Utils.h
index 0b35a7b..48a57d9 100644
--- a/Utils.h
+++ b/Utils.h
@@ -127,7 +127,8 @@
 /* Checks if Android is running in QEMU */
 bool IsRunningInEmulator();
 
-status_t UnmountTree(const std::string& prefix);
+status_t UnmountTreeWithPrefix(const std::string& prefix);
+status_t UnmountTree(const std::string& mountPoint);
 
 status_t WaitForFile(const char* filename, std::chrono::nanoseconds timeout);
 
diff --git a/vold_prepare_subdirs.cpp b/vold_prepare_subdirs.cpp
index 8c3df30..a7c5e3d 100644
--- a/vold_prepare_subdirs.cpp
+++ b/vold_prepare_subdirs.cpp
@@ -128,6 +128,7 @@
             auto misc_de_path = android::vold::BuildDataMiscDePath(user_id);
             if (!prepare_dir(sehandle, 0700, 0, 0, misc_de_path + "/vold")) return false;
             if (!prepare_dir(sehandle, 0700, 0, 0, misc_de_path + "/storaged")) return false;
+            if (!prepare_dir(sehandle, 0700, 0, 0, misc_de_path + "/rollback")) return false;
 
             auto vendor_de_path = android::vold::BuildDataVendorDePath(user_id);
             if (!prepare_dir(sehandle, 0700, AID_SYSTEM, AID_SYSTEM, vendor_de_path + "/fpdata")) {
@@ -138,6 +139,16 @@
             auto misc_ce_path = android::vold::BuildDataMiscCePath(user_id);
             if (!prepare_dir(sehandle, 0700, 0, 0, misc_ce_path + "/vold")) return false;
             if (!prepare_dir(sehandle, 0700, 0, 0, misc_ce_path + "/storaged")) return false;
+            if (!prepare_dir(sehandle, 0700, 0, 0, misc_ce_path + "/rollback")) return false;
+
+            auto system_ce_path = android::vold::BuildDataSystemCePath(user_id);
+            if (!prepare_dir(sehandle, 0700, AID_SYSTEM, AID_SYSTEM, system_ce_path + "/backup")) {
+                return false;
+            }
+            if (!prepare_dir(sehandle, 0700, AID_SYSTEM, AID_SYSTEM,
+                             system_ce_path + "/backup_stage")) {
+                return false;
+            }
         }
     }
     return true;