Merge "[view-compiler] DexBuilder: Add more instructions" am: 88f0e734b7 am: a310f87ffa
am: 19b6b983b8

Change-Id: I6273dcba4021657f068868cad072156bb15d600e
diff --git a/startop/view_compiler/dex_builder.cc b/startop/view_compiler/dex_builder.cc
index 906d64c..94879d0 100644
--- a/startop/view_compiler/dex_builder.cc
+++ b/startop/view_compiler/dex_builder.cc
@@ -61,18 +61,46 @@
     case Instruction::Op::kInvokeDirect:
       out << "kInvokeDirect";
       return out;
+    case Instruction::Op::kInvokeStatic:
+      out << "kInvokeStatic";
+      return out;
+    case Instruction::Op::kInvokeInterface:
+      out << "kInvokeInterface";
+      return out;
     case Instruction::Op::kBindLabel:
       out << "kBindLabel";
       return out;
     case Instruction::Op::kBranchEqz:
       out << "kBranchEqz";
       return out;
+    case Instruction::Op::kBranchNEqz:
+      out << "kBranchNEqz";
+      return out;
     case Instruction::Op::kNew:
       out << "kNew";
       return out;
   }
 }
 
+std::ostream& operator<<(std::ostream& out, const Value& value) {
+  if (value.is_register()) {
+    out << "Register(" << value.value() << ")";
+  } else if (value.is_parameter()) {
+    out << "Parameter(" << value.value() << ")";
+  } else if (value.is_immediate()) {
+    out << "Immediate(" << value.value() << ")";
+  } else if (value.is_string()) {
+    out << "String(" << value.value() << ")";
+  } else if (value.is_label()) {
+    out << "Label(" << value.value() << ")";
+  } else if (value.is_type()) {
+    out << "Type(" << value.value() << ")";
+  } else {
+    out << "UnknownValue";
+  }
+  return out;
+}
+
 void* TrackingAllocator::Allocate(size_t size) {
   std::unique_ptr<uint8_t[]> buffer = std::make_unique<uint8_t[]>(size);
   void* raw_buffer = buffer.get();
@@ -289,10 +317,16 @@
       return EncodeInvoke(instruction, art::Instruction::INVOKE_VIRTUAL);
     case Instruction::Op::kInvokeDirect:
       return EncodeInvoke(instruction, art::Instruction::INVOKE_DIRECT);
+    case Instruction::Op::kInvokeStatic:
+      return EncodeInvoke(instruction, art::Instruction::INVOKE_STATIC);
+    case Instruction::Op::kInvokeInterface:
+      return EncodeInvoke(instruction, art::Instruction::INVOKE_INTERFACE);
     case Instruction::Op::kBindLabel:
       return BindLabel(instruction.args()[0]);
     case Instruction::Op::kBranchEqz:
       return EncodeBranch(art::Instruction::IF_EQZ, instruction);
+    case Instruction::Op::kBranchNEqz:
+      return EncodeBranch(art::Instruction::IF_NEZ, instruction);
     case Instruction::Op::kNew:
       return EncodeNew(instruction);
   }
@@ -353,7 +387,9 @@
 
   // If there is a return value, add a move-result instruction
   if (instruction.dest().has_value()) {
-    Encode11x(art::Instruction::MOVE_RESULT, RegisterValue(*instruction.dest()));
+    Encode11x(instruction.result_is_object() ? art::Instruction::MOVE_RESULT_OBJECT
+                                             : art::Instruction::MOVE_RESULT,
+              RegisterValue(*instruction.dest()));
   }
 
   max_args_ = std::max(max_args_, instruction.args().size());
@@ -447,7 +483,7 @@
     auto& ir_node = dex_file_->methods_map[new_index];
     SLICER_CHECK(ir_node == nullptr);
     ir_node = decl;
-    decl->orig_index = new_index;
+    decl->orig_index = decl->index = new_index;
 
     entry = {id, decl};
   }
diff --git a/startop/view_compiler/dex_builder.h b/startop/view_compiler/dex_builder.h
index adf82bf..45596ac 100644
--- a/startop/view_compiler/dex_builder.h
+++ b/startop/view_compiler/dex_builder.h
@@ -147,8 +147,11 @@
     kMove,
     kInvokeVirtual,
     kInvokeDirect,
+    kInvokeStatic,
+    kInvokeInterface,
     kBindLabel,
     kBranchEqz,
+    kBranchNEqz,
     kNew
   };
 
@@ -163,19 +166,53 @@
   // For most instructions, which take some number of arguments and have an optional return value.
   template <typename... T>
   static inline Instruction OpWithArgs(Op opcode, std::optional<const Value> dest, T... args) {
-    return Instruction{opcode, /*method_id*/ 0, dest, args...};
+    return Instruction{opcode, /*method_id=*/0, /*result_is_object=*/false, dest, args...};
   }
   // For method calls.
   template <typename... T>
   static inline Instruction InvokeVirtual(size_t method_id, std::optional<const Value> dest,
                                           Value this_arg, T... args) {
-    return Instruction{Op::kInvokeVirtual, method_id, dest, this_arg, args...};
+    return Instruction{
+        Op::kInvokeVirtual, method_id, /*result_is_object=*/false, dest, this_arg, args...};
+  }
+  // Returns an object
+  template <typename... T>
+  static inline Instruction InvokeVirtualObject(size_t method_id, std::optional<const Value> dest,
+                                                Value this_arg, T... args) {
+    return Instruction{
+        Op::kInvokeVirtual, method_id, /*result_is_object=*/true, dest, this_arg, args...};
   }
   // For direct calls (basically, constructors).
   template <typename... T>
   static inline Instruction InvokeDirect(size_t method_id, std::optional<const Value> dest,
                                          Value this_arg, T... args) {
-    return Instruction{Op::kInvokeDirect, method_id, dest, this_arg, args...};
+    return Instruction{
+        Op::kInvokeDirect, method_id, /*result_is_object=*/false, dest, this_arg, args...};
+  }
+  // Returns an object
+  template <typename... T>
+  static inline Instruction InvokeDirectObject(size_t method_id, std::optional<const Value> dest,
+                                               Value this_arg, T... args) {
+    return Instruction{
+        Op::kInvokeDirect, method_id, /*result_is_object=*/true, dest, this_arg, args...};
+  }
+  // For static calls.
+  template <typename... T>
+  static inline Instruction InvokeStatic(size_t method_id, std::optional<const Value> dest,
+                                         T... args) {
+    return Instruction{Op::kInvokeStatic, method_id, /*result_is_object=*/false, dest, args...};
+  }
+  // Returns an object
+  template <typename... T>
+  static inline Instruction InvokeStaticObject(size_t method_id, std::optional<const Value> dest,
+                                               T... args) {
+    return Instruction{Op::kInvokeStatic, method_id, /*result_is_object=*/true, dest, args...};
+  }
+  // For interface calls.
+  template <typename... T>
+  static inline Instruction InvokeInterface(size_t method_id, std::optional<const Value> dest,
+                                            T... args) {
+    return Instruction{Op::kInvokeInterface, method_id, /*result_is_object=*/false, dest, args...};
   }
 
   ///////////////
@@ -184,21 +221,27 @@
 
   Op opcode() const { return opcode_; }
   size_t method_id() const { return method_id_; }
+  bool result_is_object() const { return result_is_object_; }
   const std::optional<const Value>& dest() const { return dest_; }
   const std::vector<const Value>& args() const { return args_; }
 
  private:
   inline Instruction(Op opcode, size_t method_id, std::optional<const Value> dest)
-      : opcode_{opcode}, method_id_{method_id}, dest_{dest}, args_{} {}
+      : opcode_{opcode}, method_id_{method_id}, result_is_object_{false}, dest_{dest}, args_{} {}
 
   template <typename... T>
-  inline constexpr Instruction(Op opcode, size_t method_id, std::optional<const Value> dest,
-                               T... args)
-      : opcode_{opcode}, method_id_{method_id}, dest_{dest}, args_{args...} {}
+  inline constexpr Instruction(Op opcode, size_t method_id, bool result_is_object,
+                               std::optional<const Value> dest, T... args)
+      : opcode_{opcode},
+        method_id_{method_id},
+        result_is_object_{result_is_object},
+        dest_{dest},
+        args_{args...} {}
 
   const Op opcode_;
   // The index of the method to invoke, for kInvokeVirtual and similar opcodes.
   const size_t method_id_{0};
+  const bool result_is_object_;
   const std::optional<const Value> dest_;
   const std::vector<const Value> args_;
 };
@@ -244,6 +287,8 @@
 
   // TODO: add builders for more instructions
 
+  DexBuilder* dex_file() const { return dex_; }
+
  private:
   void EncodeInstructions();
   void EncodeInstruction(const Instruction& instruction);
diff --git a/startop/view_compiler/dex_builder_test/src/android/startop/test/DexBuilderTest.java b/startop/view_compiler/dex_builder_test/src/android/startop/test/DexBuilderTest.java
index e20f3a9..1508ad9e 100644
--- a/startop/view_compiler/dex_builder_test/src/android/startop/test/DexBuilderTest.java
+++ b/startop/view_compiler/dex_builder_test/src/android/startop/test/DexBuilderTest.java
@@ -84,6 +84,15 @@
   }
 
   @Test
+  public void returnIfNotZero() throws Exception {
+    ClassLoader loader = loadDexFile("simple.dex");
+    Class clazz = loader.loadClass("android.startop.test.testcases.SimpleTests");
+    Method method = clazz.getMethod("returnIfNotZero", int.class);
+    Assert.assertEquals(3, method.invoke(null, 0));
+    Assert.assertEquals(5, method.invoke(null, 17));
+  }
+
+  @Test
   public void backwardsBranch() throws Exception {
     ClassLoader loader = loadDexFile("simple.dex");
     Class clazz = loader.loadClass("android.startop.test.testcases.SimpleTests");
@@ -124,4 +133,22 @@
     Assert.assertEquals("b", method.invoke(null, 0));
     Assert.assertEquals("a", method.invoke(null, 1));
   }
+
+  @Test
+  public void invokeStaticReturnObject() throws Exception {
+    ClassLoader loader = loadDexFile("simple.dex");
+    Class clazz = loader.loadClass("android.startop.test.testcases.SimpleTests");
+    Method method = clazz.getMethod("invokeStaticReturnObject", int.class, int.class);
+    Assert.assertEquals("10", method.invoke(null, 10, 10));
+    Assert.assertEquals("a", method.invoke(null, 10, 16));
+    Assert.assertEquals("5", method.invoke(null, 5, 16));
+  }
+
+  @Test
+  public void invokeVirtualReturnObject() throws Exception {
+    ClassLoader loader = loadDexFile("simple.dex");
+    Class clazz = loader.loadClass("android.startop.test.testcases.SimpleTests");
+    Method method = clazz.getMethod("invokeVirtualReturnObject", String.class, int.class);
+    Assert.assertEquals("bc", method.invoke(null, "abc", 1));
+  }
 }
diff --git a/startop/view_compiler/dex_testcase_generator.cc b/startop/view_compiler/dex_testcase_generator.cc
index e2bf43bc..2781aa5 100644
--- a/startop/view_compiler/dex_testcase_generator.cc
+++ b/startop/view_compiler/dex_testcase_generator.cc
@@ -108,6 +108,27 @@
   }
   returnIfZero.Encode();
 
+  // int returnIfNotZero(int x) { if (x != 0) { return 5; } else { return 3; } }
+  MethodBuilder returnIfNotZero{cbuilder.CreateMethod(
+      "returnIfNotZero", Prototype{TypeDescriptor::Int(), TypeDescriptor::Int()})};
+  {
+    Value resultIfNotZero{returnIfNotZero.MakeRegister()};
+    Value else_target{returnIfNotZero.MakeLabel()};
+    returnIfNotZero.AddInstruction(Instruction::OpWithArgs(
+        Instruction::Op::kBranchNEqz, /*dest=*/{}, Value::Parameter(0), else_target));
+    // else branch
+    returnIfNotZero.BuildConst4(resultIfNotZero, 3);
+    returnIfNotZero.AddInstruction(
+        Instruction::OpWithArgs(Instruction::Op::kReturn, /*dest=*/{}, resultIfNotZero));
+    // then branch
+    returnIfNotZero.AddInstruction(
+        Instruction::OpWithArgs(Instruction::Op::kBindLabel, /*dest=*/{}, else_target));
+    returnIfNotZero.BuildConst4(resultIfNotZero, 5);
+    returnIfNotZero.AddInstruction(
+        Instruction::OpWithArgs(Instruction::Op::kReturn, /*dest=*/{}, resultIfNotZero));
+  }
+  returnIfNotZero.Encode();
+
   // Make sure backwards branches work too.
   //
   // Pseudo code for test:
@@ -216,6 +237,38 @@
     method.Encode();
   }(returnStringIfZeroBA);
 
+  // Make sure we can invoke static methods that return an object
+  // String invokeStaticReturnObject(int n, int radix) { return java.lang.Integer.toString(n,
+  // radix); }
+  MethodBuilder invokeStaticReturnObject{
+      cbuilder.CreateMethod("invokeStaticReturnObject",
+                            Prototype{string_type, TypeDescriptor::Int(), TypeDescriptor::Int()})};
+  [&](MethodBuilder& method) {
+    Value result{method.MakeRegister()};
+    MethodDeclData to_string{dex_file.GetOrDeclareMethod(
+        TypeDescriptor::FromClassname("java.lang.Integer"),
+        "toString",
+        Prototype{string_type, TypeDescriptor::Int(), TypeDescriptor::Int()})};
+    method.AddInstruction(Instruction::InvokeStaticObject(
+        to_string.id, result, Value::Parameter(0), Value::Parameter(1)));
+    method.BuildReturn(result, /*is_object=*/true);
+    method.Encode();
+  }(invokeStaticReturnObject);
+
+  // Make sure we can invoke virtual methods that return an object
+  // String invokeVirtualReturnObject(String s, int n) { return s.substring(n); }
+  MethodBuilder invokeVirtualReturnObject{cbuilder.CreateMethod(
+      "invokeVirtualReturnObject", Prototype{string_type, string_type, TypeDescriptor::Int()})};
+  [&](MethodBuilder& method) {
+    Value result{method.MakeRegister()};
+    MethodDeclData substring{dex_file.GetOrDeclareMethod(
+        string_type, "substring", Prototype{string_type, TypeDescriptor::Int()})};
+    method.AddInstruction(Instruction::InvokeVirtualObject(
+        substring.id, result, Value::Parameter(0), Value::Parameter(1)));
+    method.BuildReturn(result, /*is_object=*/true);
+    method.Encode();
+  }(invokeVirtualReturnObject);
+
   slicer::MemView image{dex_file.CreateImage()};
   std::ofstream out_file(outdir + "/simple.dex");
   out_file.write(image.ptr<const char>(), image.size());