binder: TLS checks trigger properly.
Previously, libbinder_tls ignored the result of
isTriggeredPolled() and always returned OK, regardless
of whether the shutdown trigger had been triggered, causing
the program to continue when it shouldn't. Return the status
properly, like FdTrigger::triggerablePoll does:
- If poll() reports an error, return that error code
- If shutdown, return -ECANCELED (new in this CL for TLS)
- Otherwise return OK
Refactor RpcTransportTest so that we can add a new test
to check that triggerablePoll() returns -ECANCELED in the
above case.
Test: binderRpcTest
Fixes: 199309623
Change-Id: Ia545ba71cc10be5c46f722a5d3e699f89e1bc70c
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 880b9ce..2fd63a3 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -53,6 +53,7 @@
#include "RpcCertificateVerifierSimple.h"
using namespace std::chrono_literals;
+using namespace std::placeholders;
using testing::AssertionFailure;
using testing::AssertionResult;
using testing::AssertionSuccess;
@@ -1536,17 +1537,17 @@
ASSERT_TRUE(acceptedFd.ok());
auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get());
if (serverTransport == nullptr) return; // handshake failed
- std::string message(kMessage);
- ASSERT_EQ(OK,
- serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(),
- message.size()));
+ ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get()));
}
void shutdownAndWait() {
- mFdTrigger->trigger();
- if (mThread != nullptr) {
- mThread->join();
- mThread = nullptr;
- }
+ // Trigger shutdown first, then wait for the server thread (if any) to exit.
+ shutdown();
+ join();
+ }
+ // Only triggers |mFdTrigger|; does not wait for the server thread to exit.
+ void shutdown() {
+ mFdTrigger->trigger();
+ }
+
+ // Replaces the callback invoked after a connection is accepted and its handshake succeeds.
+ // Call this before start(): |mPostConnect| is read by the server code without any locking.
+ void setPostConnect(
+ std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> fn) {
+ mPostConnect = std::move(fn);
}
private:
@@ -1558,6 +1559,26 @@
std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
std::make_shared<RpcCertificateVerifierSimple>();
bool mSetup = false;
+ // Callback run once a connection has been accepted and its handshake has succeeded.
+ // Defaults to |defaultPostConnect|, which writes |kMessage| to the client.
+ std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> mPostConnect{
+ &Server::defaultPostConnect};
+
+ // Joins the server thread if one is running and drops it; calling again is a no-op.
+ void join() {
+ if (mThread == nullptr) return;
+ mThread->join();
+ mThread = nullptr;
+ }
+
+ // Default post-connect action: writes |kMessage| to the peer over |serverTransport|,
+ // reporting the write status on failure.
+ static AssertionResult defaultPostConnect(RpcTransport* serverTransport,
+ FdTrigger* fdTrigger) {
+ std::string payload(kMessage);
+ status_t writeStatus =
+ serverTransport->interruptableWriteFully(fdTrigger, payload.data(), payload.size());
+ if (writeStatus == OK) return AssertionSuccess();
+ return AssertionFailure() << statusToString(writeStatus);
+ }
};
class Client {
@@ -1566,8 +1587,6 @@
Client(Client&&) = default;
[[nodiscard]] AssertionResult setUp() {
auto [socketType, rpcSecurity, certificateFormat] = GetParam();
- mFd = mConnectToServer();
- if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
mFdTrigger = FdTrigger::make();
mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
@@ -1577,24 +1596,35 @@
std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
return mCertVerifier;
}
+ // connect() to the server and perform the handshake. Returns an AssertionResult (rather
+ // than bool) so callers such as ASSERT_TRUE can report which step failed; a bool return
+ // silently discarded the "Cannot connect to server" message.
+ [[nodiscard]] AssertionResult setUpTransport() {
+ mFd = mConnectToServer();
+ if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
+ mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
+ if (mClientTransport == nullptr) return AssertionFailure() << "newTransport returns nullptr";
+ return AssertionSuccess();
+ }
+ // Reads exactly expectedMessage.size() bytes from the client transport and checks that they
+ // equal |expectedMessage|. Aborts if setUpTransport() was not called or failed.
+ AssertionResult readMessage(const std::string& expectedMessage = kMessage) {
+ LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed");
+ std::string received(expectedMessage.size(), '\0');
+ if (status_t status = mClientTransport->interruptableReadFully(mFdTrigger.get(),
+ received.data(),
+ received.size());
+ status != OK) {
+ return AssertionFailure() << statusToString(status);
+ }
+ if (received == expectedMessage) return AssertionSuccess();
+ return AssertionFailure() << "Expected " << expectedMessage << ", actual " << received;
+ }
void run(bool handshakeOk = true, bool readOk = true) {
- auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
- if (clientTransport == nullptr) {
+ if (!setUpTransport()) {
ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
return;
}
ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
- std::string expectedMessage(kMessage);
- std::string readMessage(expectedMessage.size(), '\0');
- status_t readStatus =
- clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
- readMessage.size());
- if (readOk) {
- ASSERT_EQ(OK, readStatus);
- ASSERT_EQ(readMessage, expectedMessage);
- } else {
- ASSERT_NE(OK, readStatus);
- }
+ // Keep the AssertionResult so that, if the outcome differs from |readOk|, its message
+ // (read status or content mismatch) is reported instead of just a truth value.
+ AssertionResult readResult = readMessage();
+ ASSERT_EQ(readOk, static_cast<bool>(readResult)) << readResult.message();
}
private:
@@ -1604,6 +1634,7 @@
std::unique_ptr<RpcTransportCtx> mCtx;
std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
std::make_shared<RpcCertificateVerifierSimple>();
+ std::unique_ptr<RpcTransport> mClientTransport;
};
// Make A trust B.
@@ -1729,6 +1760,68 @@
maliciousClient.run(true, readOk);
}
+TEST_P(RpcTransportTest, Trigger) {
+ std::string msg2 = ", world!";
+ // Write barrier shared with the server thread. The post-connect callback is stored in the
+ // fixture-owned Server and runs on the server thread, so it must not reference this test
+ // body's stack: if an ASSERT below returns early, the server thread may still be waiting on
+ // the barrier or writing. Owning the state through a shared_ptr (and copying msg2) that the
+ // callback captures by value keeps it alive for as long as the callback needs it.
+ struct WriteBarrier {
+ std::mutex mutex;
+ std::condition_variable cv;
+ bool shouldContinueWriting = false;
+ };
+ auto barrier = std::make_shared<WriteBarrier>();
+ auto serverPostConnect = [barrier, msg2](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
+ std::string message(kMessage);
+ auto status =
+ serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
+ if (status != OK) return AssertionFailure() << statusToString(status);
+
+ {
+ std::unique_lock<std::mutex> lock(barrier->mutex);
+ if (!barrier->cv.wait_for(lock, 3s, [&] { return barrier->shouldContinueWriting; })) {
+ return AssertionFailure() << "write barrier not cleared in time!";
+ }
+ }
+
+ status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size());
+ if (status != -ECANCELED)
+ return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "
+ "should return -ECANCELED, but it is "
+ << statusToString(status);
+ return AssertionSuccess();
+ };
+
+ auto server = mServers.emplace_back(std::make_unique<Server>()).get();
+ ASSERT_TRUE(server->setUp());
+
+ // Set up client
+ Client client(server->getConnectToServerFn());
+ ASSERT_TRUE(client.setUp());
+
+ // Exchange keys
+ ASSERT_EQ(OK, trust(&client, server));
+ ASSERT_EQ(OK, trust(server, &client));
+
+ server->setPostConnect(serverPostConnect);
+
+ // Start server
+ server->start();
+ // connect() to server and do handshake
+ ASSERT_TRUE(client.setUpTransport());
+ // Read the first message. This confirms that the server has finished the handshake and has
+ // started handling the client fd. The server thread should now be blocked on the write
+ // barrier in |serverPostConnect|.
+ ASSERT_TRUE(client.readMessage(kMessage));
+ // Trigger server shutdown after the server starts handling the client fd. This ensures that
+ // the second write happens on an FdTrigger that has already been shut down.
+ server->shutdown();
+ // Let the server thread continue and attempt the second write.
+ {
+ std::lock_guard<std::mutex> lock(barrier->mutex);
+ barrier->shouldContinueWriting = true;
+ }
+ barrier->cv.notify_all();
+ // From here the server thread unblocks and tries to write the second message, but shutdown
+ // has been triggered, so the write should fail with -ECANCELED (checked in
+ // |serverPostConnect|). On the client side, the second read is expected to fail (e.g. with
+ // DEAD_OBJECT).
+ ASSERT_FALSE(client.readMessage(msg2));
+}
+
std::vector<RpcCertificateFormat> testRpcCertificateFormats() {
return {
RpcCertificateFormat::PEM,