Extend netlink class to fit the data structure

In order to get the tcp_info via netlink socket from kernel,
NetworkStack needs to use netlink class to pack and parse the
InetDiagReq. Current design hardcodes ididag_ext field in
InetDiagReqV2. The structure is also not allowed to take null
id to not to specify certain socket. Update the constructor and
backward support exising constructor.

Bug: 136162280
Test: atest FrameworksNetTests NetworkStackTests
Change-Id: Id66da1797da183ae3d99073f80bad1df929946dc
diff --git a/services/net/java/android/net/netlink/InetDiagMessage.java b/services/net/java/android/net/netlink/InetDiagMessage.java
index 31a2556..ca07630 100644
--- a/services/net/java/android/net/netlink/InetDiagMessage.java
+++ b/services/net/java/android/net/netlink/InetDiagMessage.java
@@ -26,6 +26,7 @@
 import static android.system.OsConstants.IPPROTO_UDP;
 import static android.system.OsConstants.NETLINK_INET_DIAG;
 
+import android.annotation.Nullable;
 import android.net.util.SocketUtils;
 import android.system.ErrnoException;
 import android.util.Log;
@@ -53,7 +54,35 @@
     private static final int TIMEOUT_MS = 500;
 
     public static byte[] InetDiagReqV2(int protocol, InetSocketAddress local,
-                                       InetSocketAddress remote, int family, short flags) {
+            InetSocketAddress remote, int family, short flags) {
+        return InetDiagReqV2(protocol, local, remote, family, flags, 0 /* pad */,
+                0 /* idiagExt */, StructInetDiagReqV2.INET_DIAG_REQ_V2_ALL_STATES);
+    }
+
+    /**
+     * Construct an inet_diag_req_v2 message. This method will throw {@code NullPointerException}
+     * if local and remote are not both null or both non-null.
+     *
+     * @param protocol the request protocol type. This should be set to one of IPPROTO_TCP,
+     *                 IPPROTO_UDP, or IPPROTO_UDPLITE.
+     * @param local local socket address of the target socket. This will be packed into a
+     *              {@Code StructInetDiagSockId}. Request to diagnose for all sockets if both of
+     *              local or remote address is null.
+     * @param remote remote socket address of the target socket. This will be packed into a
+     *              {@Code StructInetDiagSockId}. Request to diagnose for all sockets if both of
+     *              local or remote address is null.
+     * @param family the ip family of the request message. This should be set to either AF_INET or
+     *               AF_INET6 for IPv4 or IPv6 sockets respectively.
+     * @param flags message flags. See <linux_src>/include/uapi/linux/netlink.h.
+     * @param pad for raw socket protocol specification.
+     * @param idiagExt a set of flags defining what kind of extended information to report.
+     * @param state a bit mask that defines a filter of socket states.
+     *
+     * @return bytes array representation of the message
+     **/
+    public static byte[] InetDiagReqV2(int protocol, @Nullable InetSocketAddress local,
+            @Nullable InetSocketAddress remote, int family, short flags, int pad, int idiagExt,
+            int state) throws NullPointerException {
         final byte[] bytes = new byte[StructNlMsgHdr.STRUCT_SIZE + StructInetDiagReqV2.STRUCT_SIZE];
         final ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
         byteBuffer.order(ByteOrder.nativeOrder());
@@ -63,9 +92,9 @@
         nlMsgHdr.nlmsg_type = SOCK_DIAG_BY_FAMILY;
         nlMsgHdr.nlmsg_flags = flags;
         nlMsgHdr.pack(byteBuffer);
+        final StructInetDiagReqV2 inetDiagReqV2 =
+                new StructInetDiagReqV2(protocol, local, remote, family, pad, idiagExt, state);
 
-        final StructInetDiagReqV2 inetDiagReqV2 = new StructInetDiagReqV2(protocol, local, remote,
-                family);
         inetDiagReqV2.pack(byteBuffer);
         return bytes;
     }
diff --git a/services/net/java/android/net/netlink/StructInetDiagReqV2.java b/services/net/java/android/net/netlink/StructInetDiagReqV2.java
index 49a9325..2268113 100644
--- a/services/net/java/android/net/netlink/StructInetDiagReqV2.java
+++ b/services/net/java/android/net/netlink/StructInetDiagReqV2.java
@@ -16,10 +16,10 @@
 
 package android.net.netlink;
 
-import static java.nio.ByteOrder.BIG_ENDIAN;
+import android.annotation.Nullable;
+
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
 
 /**
  * struct inet_diag_req_v2
@@ -40,41 +40,58 @@
 public class StructInetDiagReqV2 {
     public static final int STRUCT_SIZE = 8 + StructInetDiagSockId.STRUCT_SIZE;
 
-    private final byte sdiag_family;
-    private final byte sdiag_protocol;
-    private final StructInetDiagSockId id;
-    private final int INET_DIAG_REQ_V2_ALL_STATES = (int) 0xffffffff;
-
+    private final byte mSdiagFamily;
+    private final byte mSdiagProtocol;
+    private final byte mIdiagExt;
+    private final byte mPad;
+    private final StructInetDiagSockId mId;
+    private final int mState;
+    public static final int INET_DIAG_REQ_V2_ALL_STATES = (int) 0xffffffff;
 
     public StructInetDiagReqV2(int protocol, InetSocketAddress local, InetSocketAddress remote,
-                               int family) {
-        sdiag_family = (byte) family;
-        sdiag_protocol = (byte) protocol;
-        id = new StructInetDiagSockId(local, remote);
+            int family) {
+        this(protocol, local, remote, family, 0 /* pad */, 0 /* extension */,
+                INET_DIAG_REQ_V2_ALL_STATES);
+    }
+
+    public StructInetDiagReqV2(int protocol, @Nullable InetSocketAddress local,
+            @Nullable InetSocketAddress remote, int family, int pad, int extension, int state)
+            throws NullPointerException {
+        mSdiagFamily = (byte) family;
+        mSdiagProtocol = (byte) protocol;
+        // Request for all sockets if no specific socket is requested. Specify the local and remote
+        // socket address information for target request socket.
+        if ((local == null) != (remote == null)) {
+            throw new NullPointerException("Local and remote must be both null or both non-null");
+        }
+        mId = ((local != null && remote != null) ? new StructInetDiagSockId(local, remote) : null);
+        mPad = (byte) pad;
+        mIdiagExt = (byte) extension;
+        mState = state;
     }
 
     public void pack(ByteBuffer byteBuffer) {
         // The ByteOrder must have already been set by the caller.
-        byteBuffer.put((byte) sdiag_family);
-        byteBuffer.put((byte) sdiag_protocol);
-        byteBuffer.put((byte) 0);
-        byteBuffer.put((byte) 0);
-        byteBuffer.putInt(INET_DIAG_REQ_V2_ALL_STATES);
-        id.pack(byteBuffer);
+        byteBuffer.put((byte) mSdiagFamily);
+        byteBuffer.put((byte) mSdiagProtocol);
+        byteBuffer.put((byte) mIdiagExt);
+        byteBuffer.put((byte) mPad);
+        byteBuffer.putInt(mState);
+        if (mId != null) mId.pack(byteBuffer);
     }
 
     @Override
     public String toString() {
-        final String familyStr = NetlinkConstants.stringForAddressFamily(sdiag_family);
-        final String protocolStr = NetlinkConstants.stringForAddressFamily(sdiag_protocol);
+        final String familyStr = NetlinkConstants.stringForAddressFamily(mSdiagFamily);
+        final String protocolStr = NetlinkConstants.stringForAddressFamily(mSdiagProtocol);
 
         return "StructInetDiagReqV2{ "
                 + "sdiag_family{" + familyStr + "}, "
                 + "sdiag_protocol{" + protocolStr + "}, "
-                + "idiag_ext{" + 0 + ")}, "
-                + "pad{" + 0 + "}, "
-                + "idiag_states{" + Integer.toHexString(INET_DIAG_REQ_V2_ALL_STATES) + "}, "
-                + id.toString()
+                + "idiag_ext{" + mIdiagExt + ")}, "
+                + "pad{" + mPad + "}, "
+                + "idiag_states{" + Integer.toHexString(mState) + "}, "
+                + ((mId != null) ? mId.toString() : "inet_diag_sockid=null")
                 + "}";
     }
 }
diff --git a/tests/net/java/android/net/netlink/InetDiagSocketTest.java b/tests/net/java/android/net/netlink/InetDiagSocketTest.java
index 46e27c1..1f2bb0a 100644
--- a/tests/net/java/android/net/netlink/InetDiagSocketTest.java
+++ b/tests/net/java/android/net/netlink/InetDiagSocketTest.java
@@ -30,6 +30,7 @@
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 import android.app.Instrumentation;
 import android.content.Context;
@@ -283,6 +284,107 @@
         assertArrayEquals(INET_DIAG_REQ_V2_TCP_INET6_BYTES, msg);
     }
 
+    // Hexadecimal representation of InetDiagReqV2 request with extension, INET_DIAG_INFO.
+    private static final String INET_DIAG_REQ_V2_TCP_INET_INET_DIAG_HEX =
+            // struct nlmsghdr
+            "48000000" +     // length = 72
+            "1400" +         // type = SOCK_DIAG_BY_FAMILY
+            "0100" +         // flags = NLM_F_REQUEST
+            "00000000" +     // seqno
+            "00000000" +     // pid (0 == kernel)
+            // struct inet_diag_req_v2
+            "02" +           // family = AF_INET
+            "06" +           // protcol = IPPROTO_TCP
+            "02" +           // idiag_ext = INET_DIAG_INFO
+            "00" +           // pad
+            "ffffffff" +   // idiag_states
+            // inet_diag_sockid
+            "3039" +         // idiag_sport = 12345
+            "d431" +         // idiag_dport = 54321
+            "01020304000000000000000000000000" + // idiag_src = 1.2.3.4
+            "08080404000000000000000000000000" + // idiag_dst = 8.8.4.4
+            "00000000" +     // idiag_if
+            "ffffffffffffffff"; // idiag_cookie = INET_DIAG_NOCOOKIE
+
+    private static final byte[] INET_DIAG_REQ_V2_TCP_INET_INET_DIAG_BYTES =
+            HexEncoding.decode(INET_DIAG_REQ_V2_TCP_INET_INET_DIAG_HEX.toCharArray(), false);
+    private static final int TCP_ALL_STATES = 0xffffffff;
+    @Test
+    public void testInetDiagReqV2TcpInetWithExt() throws Exception {
+        InetSocketAddress local = new InetSocketAddress(
+                InetAddress.getByName("1.2.3.4"), 12345);
+        InetSocketAddress remote = new InetSocketAddress(InetAddress.getByName("8.8.4.4"),
+                54321);
+        byte[] msg = InetDiagMessage.InetDiagReqV2(IPPROTO_TCP, local, remote, AF_INET,
+                NLM_F_REQUEST, 0 /* pad */, 2 /* idiagExt */, TCP_ALL_STATES);
+
+        assertArrayEquals(INET_DIAG_REQ_V2_TCP_INET_INET_DIAG_BYTES, msg);
+
+        local = new InetSocketAddress(
+                InetAddress.getByName("fe80::86c9:b2ff:fe6a:ed4b"), 42462);
+        remote = new InetSocketAddress(InetAddress.getByName("8.8.8.8"),
+                47473);
+        msg = InetDiagMessage.InetDiagReqV2(IPPROTO_TCP, local, remote, AF_INET6,
+                NLM_F_REQUEST, 0 /* pad */, 0 /* idiagExt */, TCP_ALL_STATES);
+
+        assertArrayEquals(INET_DIAG_REQ_V2_TCP_INET6_BYTES, msg);
+    }
+
+    // Hexadecimal representation of InetDiagReqV2 request with no socket specified.
+    private static final String INET_DIAG_REQ_V2_TCP_INET6_NO_ID_SPECIFIED_HEX =
+            // struct nlmsghdr
+            "48000000" +     // length = 72
+            "1400" +         // type = SOCK_DIAG_BY_FAMILY
+            "0100" +         // flags = NLM_F_REQUEST
+            "00000000" +     // seqno
+            "00000000" +     // pid (0 == kernel)
+            // struct inet_diag_req_v2
+            "0a" +           // family = AF_INET6
+            "06" +           // protcol = IPPROTO_TCP
+            "00" +           // idiag_ext
+            "00" +           // pad
+            "ffffffff" +     // idiag_states
+            // inet_diag_sockid
+            "0000" +         // idiag_sport
+            "0000" +         // idiag_dport
+            "00000000000000000000000000000000" + // idiag_src
+            "00000000000000000000000000000000" + // idiag_dst
+            "00000000" +     // idiag_if
+            "0000000000000000"; // idiag_cookie
+
+    private static final byte[] INET_DIAG_REQ_V2_TCP_INET6_NO_ID_SPECIFED_BYTES =
+            HexEncoding.decode(INET_DIAG_REQ_V2_TCP_INET6_NO_ID_SPECIFIED_HEX.toCharArray(), false);
+
+    @Test
+    public void testInetDiagReqV2TcpInet6NoIdSpecified() throws Exception {
+        InetSocketAddress local = new InetSocketAddress(
+                InetAddress.getByName("fe80::fe6a:ed4b"), 12345);
+        InetSocketAddress remote = new InetSocketAddress(InetAddress.getByName("8.8.4.4"),
+                54321);
+        // Verify no socket specified if either local or remote socket address is null.
+        byte[] msgExt = InetDiagMessage.InetDiagReqV2(IPPROTO_TCP, null, null, AF_INET6,
+                NLM_F_REQUEST, 0 /* pad */, 0 /* idiagExt */, TCP_ALL_STATES);
+        byte[] msg;
+        try {
+            msg = InetDiagMessage.InetDiagReqV2(IPPROTO_TCP, null, remote, AF_INET6,
+                    NLM_F_REQUEST);
+            fail("Both remote and local should be null, expected UnknownHostException");
+        } catch (NullPointerException e) {
+        }
+
+        try {
+            msg = InetDiagMessage.InetDiagReqV2(IPPROTO_TCP, local, null, AF_INET6,
+                    NLM_F_REQUEST, 0 /* pad */, 0 /* idiagExt */, TCP_ALL_STATES);
+            fail("Both remote and local should be null, expected UnknownHostException");
+        } catch (NullPointerException e) {
+        }
+
+        msg = InetDiagMessage.InetDiagReqV2(IPPROTO_TCP, null, null, AF_INET6,
+                NLM_F_REQUEST, 0 /* pad */, 0 /* idiagExt */, TCP_ALL_STATES);
+        assertArrayEquals(INET_DIAG_REQ_V2_TCP_INET6_NO_ID_SPECIFED_BYTES, msg);
+        assertArrayEquals(INET_DIAG_REQ_V2_TCP_INET6_NO_ID_SPECIFED_BYTES, msgExt);
+    }
+
     // Hexadecimal representation of InetDiagReqV2 request.
     private static final String INET_DIAG_MSG_HEX =
             // struct nlmsghdr