adb: make disconnect stop reconnection immediately.

Make `adb disconnect` remove transports immediately, instead of
waiting for their next reconnection cycle.

Test: adb connect unreachable:12345; adb devices; adb disconnect; adb devices
Change-Id: I35c8b57344e847575596d09216fc636be47dde64
diff --git a/test_adb.py b/test_adb.py
index d4c98e4..86c13d0 100755
--- a/test_adb.py
+++ b/test_adb.py
@@ -28,6 +28,7 @@
 import struct
 import subprocess
 import threading
+import time
 import unittest
 
 
@@ -90,7 +91,7 @@
     server_thread.start()
 
     try:
-        yield port
+        yield port, writesock
     finally:
         writesock.close()
         server_thread.join()
@@ -120,7 +121,7 @@
 def adb_server():
     """Context manager for an ADB server.
 
-    This creates an ADB server and returns the port it"s listening on.
+    This creates an ADB server and returns the port it's listening on.
     """
 
     port = 5038
@@ -342,7 +343,7 @@
         Bug: http://b/78991667
         """
         with adb_server() as server_port:
-            with fake_adbd() as port:
+            with fake_adbd() as (port, _):
                 serial = "emulator-{}".format(port - 1)
                 # Ensure that the emulator is not there.
                 try:
@@ -380,7 +381,7 @@
         """
         for protocol in (socket.AF_INET, socket.AF_INET6):
             try:
-                with fake_adbd(protocol=protocol) as port:
+                with fake_adbd(protocol=protocol) as (port, _):
                     serial = "localhost:{}".format(port)
                     with adb_connect(self, serial):
                         pass
@@ -391,7 +392,7 @@
     def test_already_connected(self):
         """Ensure that an already-connected device stays connected."""
 
-        with fake_adbd() as port:
+        with fake_adbd() as (port, _):
             serial = "localhost:{}".format(port)
             with adb_connect(self, serial):
                 # b/31250450: this always returns 0 but probably shouldn't.
@@ -403,7 +404,7 @@
     def test_reconnect(self):
         """Ensure that a disconnected device reconnects."""
 
-        with fake_adbd() as port:
+        with fake_adbd() as (port, _):
             serial = "localhost:{}".format(port)
             with adb_connect(self, serial):
                 output = subprocess.check_output(["adb", "-s", serial,
@@ -439,6 +440,65 @@
                         "error: device '{}' not found".format(serial).encode("utf8"))
 
 
+class DisconnectionTest(unittest.TestCase):
+    """Tests for `adb disconnect`."""
+
+    def test_disconnect(self):
+        """Ensure that `adb disconnect` takes effect immediately.
+
+        A device that has gone offline must disappear from `adb devices` as
+        soon as it is disconnected, not on its next reconnection attempt.
+        """
+
+        def _devices(port):
+            """Return `adb devices` (server on `port`) as [serial, state] pairs."""
+            output = subprocess.check_output(["adb", "-P", str(port), "devices"])
+            # Drop the first line, which is the listing header, not a device.
+            return [x.split("\t") for x in output.decode("utf8").strip().splitlines()[1:]]
+
+        def _wait_for_devices(port, expected, timeout):
+            """Poll `adb devices` until it equals `expected` or `timeout` expires.
+
+            Polling (rather than one fixed sleep) keeps the test fast when the
+            server reacts quickly and less flaky on a loaded machine.
+            """
+            deadline = time.time() + timeout
+            devices = _devices(port)
+            while devices != expected and time.time() < deadline:
+                time.sleep(0.05)
+                devices = _devices(port)
+            self.assertEqual(devices, expected)
+
+        with adb_server() as server_port:
+            with fake_adbd() as (port, sock):
+                device_name = "localhost:{}".format(port)
+                output = subprocess.check_output(["adb", "-P", str(server_port),
+                                                  "connect", device_name])
+                self.assertEqual(output.strip(),
+                                 "connected to {}".format(device_name).encode("utf8"))
+
+                self.assertEqual(_devices(server_port), [[device_name, "device"]])
+
+                # Send a deliberately malformed packet (24 zero bytes) to make
+                # the device go offline. The explicit "<6I" format produces the
+                # same bytes regardless of the host's native int size/alignment.
+                packet = struct.pack("<6I", 0, 0, 0, 0, 0, 0)
+                sock.sendall(packet)
+
+                # Going offline is not what is under test, so allow ample time.
+                _wait_for_devices(server_port, [[device_name, "offline"]], timeout=5.0)
+
+                # Disconnect the device; check_output fails the test on a
+                # non-zero exit status.
+                subprocess.check_output(["adb", "-P", str(server_port),
+                                         "disconnect", device_name])
+
+                # Keep this bound short: the transport must be removed right
+                # away, not on a later reconnection attempt.
+                _wait_for_devices(server_port, [], timeout=0.5)
+
+
 def main():
     """Main entrypoint."""
     random.seed(0)