Add ReadOrAgain and WriteOrAgain methods to FuseMessage.
These methods return kAgain if the operation cannot be completed without
blocking the current thread.
The CL also introduces a new helper function, SetupMessageSockets, so that
FuseMessages are always transferred via sockets that preserve message
boundaries.
Bug: 34903085
Test: libappfuse_test
Change-Id: I34544372cc1b0c7bc9622e581ae16c018a123fa9
diff --git a/libappfuse/FuseBuffer.cc b/libappfuse/FuseBuffer.cc
index 8fb2dbc..13cfc88 100644
--- a/libappfuse/FuseBuffer.cc
+++ b/libappfuse/FuseBuffer.cc
@@ -23,77 +23,132 @@
#include <algorithm>
#include <type_traits>
+#include <sys/socket.h>
+
#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
namespace android {
namespace fuse {
-
-static_assert(
- std::is_standard_layout<FuseBuffer>::value,
- "FuseBuffer must be standard layout union.");
+namespace {
+
+// Validates the length field of |self|'s FUSE header.
+//
+// Returns true when header.len is at least sizeof(header) and at most
+// sizeof(T). Otherwise logs an error that mentions |name| (the calling
+// operation, e.g. "Read" or "Write") and returns false.
template <typename T>
-bool FuseMessage<T>::CheckHeaderLength(const char* name) const {
- const auto& header = static_cast<const T*>(this)->header;
- if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
+bool CheckHeaderLength(const FuseMessage<T>* self, const char* name) {
+    const auto& header = static_cast<const T*>(self)->header;
+    if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
+        return true;
+    } else {
+        LOG(ERROR) << "Invalid header length is found in " << name << ": " << header.len;
+        return false;
+    }
+}
+
+// Reads one complete FUSE message from |fd| into |self| with a single call.
+//
+// When |sockflag| is non-zero the data is received with recv(2) using those
+// flags (e.g. MSG_DONTWAIT); otherwise read(2) is used. The call must return
+// exactly header.len bytes, so |fd| is expected to preserve message
+// boundaries (see SetupMessageSockets).
+//
+// Returns:
+//   kSuccess - a header-valid message of exactly header.len bytes was read.
+//   kAgain   - the call failed with EAGAIN.
+//   kFailure - EOF, any other I/O error, or a malformed/short message.
+template <typename T>
+ResultOrAgain ReadInternal(FuseMessage<T>* self, int fd, int sockflag) {
+    char* const buf = reinterpret_cast<char*>(self);
+    const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(recv(fd, buf, sizeof(T), sockflag))
+                                    : TEMP_FAILURE_RETRY(read(fd, buf, sizeof(T)));
+
+    switch (result) {
+        case 0:
+            // Expected EOF.
+            return ResultOrAgain::kFailure;
+        case -1:
+            if (errno == EAGAIN) {
+                return ResultOrAgain::kAgain;
+            }
+            PLOG(ERROR) << "Failed to read a FUSE message";
+            return ResultOrAgain::kFailure;
+    }
+
+    // The header must be fully present before header.len can be trusted.
+    const auto& header = static_cast<const T*>(self)->header;
+    if (result < static_cast<ssize_t>(sizeof(header))) {
+        LOG(ERROR) << "Read bytes " << result << " are shorter than header size " << sizeof(header);
+        return ResultOrAgain::kFailure;
+    }
+
+    if (!CheckHeaderLength<T>(self, "Read")) {
+        return ResultOrAgain::kFailure;
+    }
+
+    // One read must deliver exactly one message; there is no follow-up read
+    // to complete a partial message.
+    if (static_cast<uint32_t>(result) != header.len) {
+        LOG(ERROR) << "Read bytes " << result << " are different from header.len " << header.len;
+        return ResultOrAgain::kFailure;
+    }
+
+    return ResultOrAgain::kSuccess;
+}
+
+// Writes the header.len bytes of |self| to |fd| with a single call.
+//
+// When |sockflag| is non-zero the data is sent with send(2) using those flags
+// (e.g. MSG_DONTWAIT); otherwise write(2) is used.
+//
+// Returns:
+//   kSuccess - all header.len bytes were written.
+//   kAgain   - the call failed with EAGAIN.
+//   kFailure - invalid header length, any other I/O error, or a short write.
+template <typename T>
+ResultOrAgain WriteInternal(const FuseMessage<T>* self, int fd, int sockflag) {
+    if (!CheckHeaderLength<T>(self, "Write")) {
+        return ResultOrAgain::kFailure;
+    }
+
+    const char* const buf = reinterpret_cast<const char*>(self);
+    const auto& header = static_cast<const T*>(self)->header;
+    // send(2)/write(2) return ssize_t; keep the full type rather than narrowing to int.
+    const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(send(fd, buf, header.len, sockflag))
+                                    : TEMP_FAILURE_RETRY(write(fd, buf, header.len));
+
+    if (result == -1) {
+        if (errno == EAGAIN) {
+            return ResultOrAgain::kAgain;
+        }
+        PLOG(ERROR) << "Failed to write a FUSE message";
+        return ResultOrAgain::kFailure;
+    }
+
+    // A call may transfer fewer bytes than requested. Report that as a failure
+    // like every other error path, instead of aborting the process via CHECK.
+    if (static_cast<uint32_t>(result) != header.len) {
+        LOG(ERROR) << "Written bytes " << result << " are different from header.len "
+                   << header.len;
+        return ResultOrAgain::kFailure;
+    }
+    return ResultOrAgain::kSuccess;
+}
+}  // namespace
+
+static_assert(std::is_standard_layout<FuseBuffer>::value,
+ "FuseBuffer must be standard layout union.");
+
+// Creates a connected pair of AF_UNIX / SOCK_SEQPACKET sockets for exchanging
+// FuseMessages. A message-oriented socket type is used so that message
+// boundaries are preserved: ReadInternal expects a single read to return
+// exactly header.len bytes.
+//
+// On success, moves both ends into (*result)[0] and (*result)[1] and returns
+// true. On failure, logs the error and returns false without touching
+// |*result|. On the setsockopt failure path, the local unique_fd array still
+// owns both descriptors.
+bool SetupMessageSockets(base::unique_fd (*result)[2]) {
+ base::unique_fd fds[2];
+ {
+ int raw_fds[2];
+ if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, raw_fds) == -1) {
+ PLOG(ERROR) << "Failed to create sockets for proxy";
+ return false;
+ }
+ // Hand the raw descriptors to unique_fd owners right away.
+ fds[0].reset(raw_fds[0]);
+ fds[1].reset(raw_fds[1]);
+ }
+
+ // Set SO_SNDBUF on both ends to sizeof(FuseBuffer).
+ // NOTE(review): presumably this bounds how much unsent data can be queued
+ // before WriteOrAgain reports kAgain. Confirm the intent.
+ constexpr int kMaxMessageSize = sizeof(FuseBuffer);
+ if (setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0 ||
+ setsockopt(fds[1], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0) {
+ PLOG(ERROR) << "Failed to update buffer size for socket";
+ return false;
+ }
+
+ (*result)[0] = std::move(fds[0]);
+ (*result)[1] = std::move(fds[1]);
return true;
- } else {
- LOG(ERROR) << "Invalid header length is found in " << name << ": " <<
- header.len;
- return false;
- }
}
+// Reads one FUSE message from |fd| with read(2) (no socket flags).
+// Returns true only if ReadInternal reports kSuccess; kAgain, EOF, I/O errors,
+// and malformed messages all yield false. Unlike the removed implementation,
+// a short read is no longer completed with ReadFully: the whole message must
+// arrive in one read.
template <typename T>
bool FuseMessage<T>::Read(int fd) {
- char* const buf = reinterpret_cast<char*>(this);
- const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, buf, sizeof(T)));
- if (result < 0) {
- PLOG(ERROR) << "Failed to read a FUSE message";
- return false;
- }
+ return ReadInternal(this, fd, 0) == ResultOrAgain::kSuccess;
+}
- const auto& header = static_cast<const T*>(this)->header;
- if (result < static_cast<ssize_t>(sizeof(header))) {
- LOG(ERROR) << "Read bytes " << result << " are shorter than header size " <<
- sizeof(header);
- return false;
- }
-
- if (!CheckHeaderLength("Read")) {
- return false;
- }
-
- if (static_cast<uint32_t>(result) > header.len) {
- LOG(ERROR) << "Read bytes " << result << " are longer than header.len " <<
- header.len;
- return false;
- }
-
- if (!base::ReadFully(fd, buf + result, header.len - result)) {
- PLOG(ERROR) << "ReadFully failed";
- return false;
- }
-
- return true;
+// Reads one FUSE message from |fd| with recv(2) and MSG_DONTWAIT, so |fd| is
+// expected to be a socket (e.g. one created by SetupMessageSockets).
+// Returns kAgain instead of blocking when no message is available yet,
+// kSuccess on a complete valid message, and kFailure otherwise.
+template <typename T>
+ResultOrAgain FuseMessage<T>::ReadOrAgain(int fd) {
+ return ReadInternal(this, fd, MSG_DONTWAIT);
}
+// Writes this message (header.len bytes) to |fd| with write(2) (no socket
+// flags). Returns true only if WriteInternal reports kSuccess; an invalid
+// header length, kAgain, and I/O errors all yield false.
template <typename T>
bool FuseMessage<T>::Write(int fd) const {
- if (!CheckHeaderLength("Write")) {
- return false;
- }
+ return WriteInternal(this, fd, 0) == ResultOrAgain::kSuccess;
+}
- const char* const buf = reinterpret_cast<const char*>(this);
- const auto& header = static_cast<const T*>(this)->header;
- if (!base::WriteFully(fd, buf, header.len)) {
- PLOG(ERROR) << "WriteFully failed";
- return false;
- }
-
- return true;
+// Writes this message to |fd| with send(2) and MSG_DONTWAIT, so |fd| is
+// expected to be a socket (e.g. one created by SetupMessageSockets).
+// Returns kAgain instead of blocking when the message cannot be sent yet,
+// kSuccess when the whole message was sent, and kFailure on other errors.
+template <typename T>
+ResultOrAgain FuseMessage<T>::WriteOrAgain(int fd) const {
+ return WriteInternal(this, fd, MSG_DONTWAIT);
}
template class FuseMessage<FuseRequest>;