| /* |
| * Copyright (C) 2019 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except |
| * in compliance with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "adb/tls/tls_connection.h" |
| |
| #include <algorithm> |
| #include <vector> |
| |
| #include <android-base/logging.h> |
| #include <android-base/strings.h> |
| #include <openssl/err.h> |
| #include <openssl/ssl.h> |
| |
| using android::base::borrowed_fd; |
| |
| namespace adb { |
| namespace tls { |
| |
| namespace { |
| |
| static constexpr char kExportedKeyLabel[] = "adb-label"; |
| |
| class TlsConnectionImpl : public TlsConnection { |
| public: |
| explicit TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key, |
| borrowed_fd fd); |
| ~TlsConnectionImpl() override; |
| |
| bool AddTrustedCertificate(std::string_view cert) override; |
| void SetCertVerifyCallback(CertVerifyCb cb) override; |
| void SetCertificateCallback(SetCertCb cb) override; |
| void SetClientCAList(STACK_OF(X509_NAME) * ca_list) override; |
| std::vector<uint8_t> ExportKeyingMaterial(size_t length) override; |
| void EnableClientPostHandshakeCheck(bool enable) override; |
| TlsError DoHandshake() override; |
| std::vector<uint8_t> ReadFully(size_t size) override; |
| bool ReadFully(void* buf, size_t size) override; |
| bool WriteFully(std::string_view data) override; |
| |
| static bssl::UniquePtr<EVP_PKEY> EvpPkeyFromPEM(std::string_view pem); |
| static bssl::UniquePtr<CRYPTO_BUFFER> BufferFromPEM(std::string_view pem); |
| |
| private: |
| static int SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque); |
| static int SSLSetCertCb(SSL* ssl, void* opaque); |
| |
| static bssl::UniquePtr<X509> X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer); |
| static const char* SSLErrorString(); |
| void Invalidate(); |
| TlsError GetFailureReason(int err); |
| |
| Role role_; |
| bssl::UniquePtr<EVP_PKEY> priv_key_; |
| bssl::UniquePtr<CRYPTO_BUFFER> cert_; |
| |
| bssl::UniquePtr<STACK_OF(X509_NAME)> ca_list_; |
| bssl::UniquePtr<SSL_CTX> ssl_ctx_; |
| bssl::UniquePtr<SSL> ssl_; |
| std::vector<bssl::UniquePtr<X509>> known_certificates_; |
| bool client_verify_post_handshake_ = false; |
| |
| CertVerifyCb cert_verify_cb_; |
| SetCertCb set_cert_cb_; |
| borrowed_fd fd_; |
| }; // TlsConnectionImpl |
| |
| TlsConnectionImpl::TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key, |
| borrowed_fd fd) |
| : role_(role), fd_(fd) { |
| CHECK(!cert.empty() && !priv_key.empty()); |
| LOG(INFO) << "Initializing adbwifi TlsConnection"; |
| cert_ = BufferFromPEM(cert); |
| priv_key_ = EvpPkeyFromPEM(priv_key); |
| } |
| |
| TlsConnectionImpl::~TlsConnectionImpl() { |
| // shutdown the SSL connection |
| if (ssl_ != nullptr) { |
| SSL_shutdown(ssl_.get()); |
| } |
| } |
| |
| // static |
| const char* TlsConnectionImpl::SSLErrorString() { |
| auto sslerr = ERR_peek_last_error(); |
| return ERR_reason_error_string(sslerr); |
| } |
| |
| // static |
| bssl::UniquePtr<EVP_PKEY> TlsConnectionImpl::EvpPkeyFromPEM(std::string_view pem) { |
| bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size())); |
| return bssl::UniquePtr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); |
| } |
| |
| // static |
| bssl::UniquePtr<CRYPTO_BUFFER> TlsConnectionImpl::BufferFromPEM(std::string_view pem) { |
| bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size())); |
| char* name = nullptr; |
| char* header = nullptr; |
| uint8_t* data = nullptr; |
| long data_len = 0; |
| |
| if (!PEM_read_bio(bio.get(), &name, &header, &data, &data_len)) { |
| LOG(ERROR) << "Failed to read certificate"; |
| return nullptr; |
| } |
| OPENSSL_free(name); |
| OPENSSL_free(header); |
| |
| auto ret = bssl::UniquePtr<CRYPTO_BUFFER>(CRYPTO_BUFFER_new(data, data_len, nullptr)); |
| OPENSSL_free(data); |
| return ret; |
| } |
| |
| // static |
| bssl::UniquePtr<X509> TlsConnectionImpl::X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer) { |
| if (!buffer) { |
| return nullptr; |
| } |
| return bssl::UniquePtr<X509>(X509_parse_from_buffer(buffer.get())); |
| } |
| |
| // static |
| int TlsConnectionImpl::SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque) { |
| auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque); |
| return p->cert_verify_cb_(ctx); |
| } |
| |
| // static |
| int TlsConnectionImpl::SSLSetCertCb(SSL* ssl, void* opaque) { |
| auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque); |
| return p->set_cert_cb_(ssl); |
| } |
| |
| bool TlsConnectionImpl::AddTrustedCertificate(std::string_view cert) { |
| // Create X509 buffer from the certificate string |
| auto buf = X509FromBuffer(BufferFromPEM(cert)); |
| if (buf == nullptr) { |
| LOG(ERROR) << "Failed to create a X509 buffer for the certificate."; |
| return false; |
| } |
| known_certificates_.push_back(std::move(buf)); |
| return true; |
| } |
| |
| void TlsConnectionImpl::SetCertVerifyCallback(CertVerifyCb cb) { |
| cert_verify_cb_ = cb; |
| } |
| |
| void TlsConnectionImpl::SetCertificateCallback(SetCertCb cb) { |
| set_cert_cb_ = cb; |
| } |
| |
| void TlsConnectionImpl::SetClientCAList(STACK_OF(X509_NAME) * ca_list) { |
| CHECK(role_ == Role::Server); |
| ca_list_.reset(ca_list != nullptr ? SSL_dup_CA_list(ca_list) : nullptr); |
| } |
| |
| std::vector<uint8_t> TlsConnectionImpl::ExportKeyingMaterial(size_t length) { |
| if (ssl_.get() == nullptr) { |
| return {}; |
| } |
| |
| std::vector<uint8_t> out(length); |
| if (SSL_export_keying_material(ssl_.get(), out.data(), out.size(), kExportedKeyLabel, |
| sizeof(kExportedKeyLabel), nullptr, 0, false) == 0) { |
| return {}; |
| } |
| return out; |
| } |
| |
| void TlsConnectionImpl::EnableClientPostHandshakeCheck(bool enable) { |
| client_verify_post_handshake_ = enable; |
| } |
| |
| TlsConnection::TlsError TlsConnectionImpl::GetFailureReason(int err) { |
| switch (ERR_GET_REASON(err)) { |
| case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE: |
| case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE: |
| case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED: |
| case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED: |
| case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN: |
| case SSL_R_TLSV1_ALERT_ACCESS_DENIED: |
| case SSL_R_TLSV1_ALERT_UNKNOWN_CA: |
| case SSL_R_TLSV1_CERTIFICATE_REQUIRED: |
| return TlsError::PeerRejectedCertificate; |
| case SSL_R_CERTIFICATE_VERIFY_FAILED: |
| return TlsError::CertificateRejected; |
| default: |
| return TlsError::UnknownFailure; |
| } |
| } |
| |
| TlsConnection::TlsError TlsConnectionImpl::DoHandshake() { |
| int err = -1; |
| LOG(INFO) << "Starting adbwifi tls handshake"; |
| ssl_ctx_.reset(SSL_CTX_new(TLS_method())); |
| // TODO: Remove set_max_proto_version() once external/boringssl is updated |
| // past |
| // https://boringssl.googlesource.com/boringssl/+/58d56f4c59969a23e5f52014e2651c76fea2f877 |
| if (ssl_ctx_.get() == nullptr || |
| !SSL_CTX_set_min_proto_version(ssl_ctx_.get(), TLS1_3_VERSION) || |
| !SSL_CTX_set_max_proto_version(ssl_ctx_.get(), TLS1_3_VERSION)) { |
| LOG(ERROR) << "Failed to create SSL context"; |
| return TlsError::UnknownFailure; |
| } |
| |
| // Register user-supplied known certificates |
| for (auto const& cert : known_certificates_) { |
| if (X509_STORE_add_cert(SSL_CTX_get_cert_store(ssl_ctx_.get()), cert.get()) == 0) { |
| LOG(ERROR) << "Unable to add certificates into the X509_STORE"; |
| return TlsError::UnknownFailure; |
| } |
| } |
| |
| // Custom certificate verification |
| if (cert_verify_cb_) { |
| SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), SSLSetCertVerifyCb, this); |
| } |
| |
| // set select certificate callback, if any. |
| if (set_cert_cb_) { |
| SSL_CTX_set_cert_cb(ssl_ctx_.get(), SSLSetCertCb, this); |
| } |
| |
| // Server-allowed client CA list |
| if (ca_list_ != nullptr) { |
| bssl::UniquePtr<STACK_OF(X509_NAME)> names(SSL_dup_CA_list(ca_list_.get())); |
| SSL_CTX_set_client_CA_list(ssl_ctx_.get(), names.release()); |
| } |
| |
| // Register our certificate and private key. |
| std::vector<CRYPTO_BUFFER*> cert_chain = { |
| cert_.get(), |
| }; |
| if (!SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_chain.data(), cert_chain.size(), |
| priv_key_.get(), nullptr)) { |
| LOG(ERROR) << "Unable to register the certificate chain file and private key [" |
| << SSLErrorString() << "]"; |
| Invalidate(); |
| return TlsError::UnknownFailure; |
| } |
| |
| SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); |
| |
| // Okay! Let's try to do the handshake! |
| ssl_.reset(SSL_new(ssl_ctx_.get())); |
| if (!SSL_set_fd(ssl_.get(), fd_.get())) { |
| LOG(ERROR) << "SSL_set_fd failed. [" << SSLErrorString() << "]"; |
| return TlsError::UnknownFailure; |
| } |
| switch (role_) { |
| case Role::Server: |
| err = SSL_accept(ssl_.get()); |
| break; |
| case Role::Client: |
| err = SSL_connect(ssl_.get()); |
| break; |
| } |
| if (err != 1) { |
| LOG(ERROR) << "Handshake failed in SSL_accept/SSL_connect [" << SSLErrorString() << "]"; |
| auto sslerr = ERR_get_error(); |
| Invalidate(); |
| return GetFailureReason(sslerr); |
| } |
| |
| if (client_verify_post_handshake_ && role_ == Role::Client) { |
| uint8_t check; |
| // Try to peek one byte for any failures. This assumes on success that |
| // the server actually sends something. |
| err = SSL_peek(ssl_.get(), &check, 1); |
| if (err <= 0) { |
| LOG(ERROR) << "Post-handshake SSL_peek failed [" << SSLErrorString() << "]"; |
| auto sslerr = ERR_get_error(); |
| Invalidate(); |
| return GetFailureReason(sslerr); |
| } |
| } |
| |
| LOG(INFO) << "Handshake succeeded."; |
| return TlsError::Success; |
| } |
| |
| void TlsConnectionImpl::Invalidate() { |
| ssl_.reset(); |
| ssl_ctx_.reset(); |
| } |
| |
| std::vector<uint8_t> TlsConnectionImpl::ReadFully(size_t size) { |
| std::vector<uint8_t> buf(size); |
| if (!ReadFully(buf.data(), buf.size())) { |
| return {}; |
| } |
| |
| return buf; |
| } |
| |
| bool TlsConnectionImpl::ReadFully(void* buf, size_t size) { |
| CHECK_GT(size, 0U); |
| if (!ssl_) { |
| LOG(ERROR) << "Tried to read on a null SSL connection"; |
| return false; |
| } |
| |
| size_t offset = 0; |
| uint8_t* p8 = reinterpret_cast<uint8_t*>(buf); |
| while (size > 0) { |
| int bytes_read = |
| SSL_read(ssl_.get(), p8 + offset, std::min(static_cast<size_t>(INT_MAX), size)); |
| if (bytes_read <= 0) { |
| LOG(WARNING) << "SSL_read failed [" << SSLErrorString() << "]"; |
| return false; |
| } |
| size -= bytes_read; |
| offset += bytes_read; |
| } |
| return true; |
| } |
| |
| bool TlsConnectionImpl::WriteFully(std::string_view data) { |
| CHECK(!data.empty()); |
| if (!ssl_) { |
| LOG(ERROR) << "Tried to read on a null SSL connection"; |
| return false; |
| } |
| |
| while (!data.empty()) { |
| int bytes_out = SSL_write(ssl_.get(), data.data(), |
| std::min(static_cast<size_t>(INT_MAX), data.size())); |
| if (bytes_out <= 0) { |
| LOG(WARNING) << "SSL_write failed [" << SSLErrorString() << "]"; |
| return false; |
| } |
| data = data.substr(bytes_out); |
| } |
| return true; |
| } |
| } // namespace |
| |
| // static |
| std::unique_ptr<TlsConnection> TlsConnection::Create(TlsConnection::Role role, |
| std::string_view cert, |
| std::string_view priv_key, borrowed_fd fd) { |
| CHECK(!cert.empty()); |
| CHECK(!priv_key.empty()); |
| |
| return std::make_unique<TlsConnectionImpl>(role, cert, priv_key, fd); |
| } |
| |
| // static |
| bool TlsConnection::SetCertAndKey(SSL* ssl, std::string_view cert, std::string_view priv_key) { |
| CHECK(ssl); |
| // Note: declaring these in local scope is okay because |
| // SSL_set_chain_and_key will increase the refcount (bssl::UpRef). |
| auto x509_cert = TlsConnectionImpl::BufferFromPEM(cert); |
| auto evp_pkey = TlsConnectionImpl::EvpPkeyFromPEM(priv_key); |
| if (x509_cert == nullptr || evp_pkey == nullptr) { |
| return false; |
| } |
| |
| std::vector<CRYPTO_BUFFER*> cert_chain = { |
| x509_cert.get(), |
| }; |
| if (!SSL_set_chain_and_key(ssl, cert_chain.data(), cert_chain.size(), evp_pkey.get(), |
| nullptr)) { |
| LOG(ERROR) << "SSL_set_chain_and_key failed"; |
| return false; |
| } |
| |
| return true; |
| } |
| |
| } // namespace tls |
| } // namespace adb |