Apply 3D transformations to gradient shaders.

This fixes only linear gradients. Sweep and radial gradients, as well as
bitmap shaders, will be fixed in a future commit.

Change-Id: I4eee4ff62e9bbf3b9339fc111a780167449ecfef
diff --git a/libs/hwui/ProgramCache.cpp b/libs/hwui/ProgramCache.cpp
index 3e9412c..d1a1b45 100644
--- a/libs/hwui/ProgramCache.cpp
+++ b/libs/hwui/ProgramCache.cpp
@@ -35,9 +35,6 @@
         "uniform mat4 transform;\n";
 const char* gVS_Header_Uniforms_HasGradient[3] = {
         // Linear
-        "uniform float gradientLength;\n"
-        "uniform vec2 gradient;\n"
-        "uniform vec2 gradientStart;\n"
         "uniform mat4 screenSpace;\n",
         // Circular
         "uniform vec2 gradientStart;\n"
@@ -69,8 +66,7 @@
         "    outTexCoords = texCoords;\n";
 const char* gVS_Main_OutGradient[3] = {
         // Linear
-        "    vec4 location = screenSpace * position;\n"
-        "    index = dot(location.xy - gradientStart, gradient) * gradientLength;\n",
+        "    index = (screenSpace * position).x;\n",
         // Circular
         "    vec4 location = screenSpace * position;\n"
         "    circular = (gradientMatrix * vec4(location.xy - gradientStart, 0.0, 0.0)).xy;\n",
diff --git a/libs/hwui/SkiaShader.cpp b/libs/hwui/SkiaShader.cpp
index 165c0da..83de2b2 100644
--- a/libs/hwui/SkiaShader.cpp
+++ b/libs/hwui/SkiaShader.cpp
@@ -153,11 +153,31 @@
 // Linear gradient shader
 ///////////////////////////////////////////////////////////////////////////////
 
+static void toUnitMatrix(const SkPoint pts[2], SkMatrix* matrix) {
+    SkVector vec = pts[1] - pts[0];
+    const float mag = vec.length();
+    const float inv = mag ? 1.0f / mag : 0;
+
+    vec.scale(inv);
+    matrix->setSinCos(-vec.fY, vec.fX, pts[0].fX, pts[0].fY);
+    matrix->postTranslate(-pts[0].fX, -pts[0].fY);
+    matrix->postScale(inv, inv);
+}
+
 SkiaLinearGradientShader::SkiaLinearGradientShader(float* bounds, uint32_t* colors,
         float* positions, int count, SkShader* key, SkShader::TileMode tileMode,
         SkMatrix* matrix, bool blend):
         SkiaShader(kLinearGradient, key, tileMode, tileMode, matrix, blend),
         mBounds(bounds), mColors(colors), mPositions(positions), mCount(count) {
+    SkPoint points[2];
+    points[0].set(bounds[0], bounds[1]);
+    points[1].set(bounds[2], bounds[3]);
+
+    SkMatrix unitMatrix;
+    toUnitMatrix(points, &unitMatrix);
+    mUnitMatrix.load(unitMatrix);
+
+    updateLocalMatrix(matrix);
 }
 
 SkiaLinearGradientShader::~SkiaLinearGradientShader() {
@@ -172,6 +192,23 @@
     description.gradientType = ProgramDescription::kGradientLinear;
 }
 
+void SkiaLinearGradientShader::computeScreenSpaceMatrix(mat4& screenSpace, const mat4& modelView) {
+    screenSpace.loadMultiply(mUnitMatrix, mShaderMatrix);
+    screenSpace.multiply(modelView);
+}
+
+void SkiaLinearGradientShader::updateLocalMatrix(const SkMatrix* matrix) {
+    if (matrix) {
+        mat4 localMatrix(*matrix);
+        mShaderMatrix.loadInverse(localMatrix);
+    }
+}
+
+void SkiaLinearGradientShader::setMatrix(SkMatrix* matrix) {
+    SkiaShader::setMatrix(matrix);
+    updateLocalMatrix(matrix);
+}
+
 void SkiaLinearGradientShader::setupProgram(Program* program, const mat4& modelView,
         const Snapshot& snapshot, GLuint* textureUnit) {
     GLuint textureSlot = (*textureUnit)++;
@@ -182,34 +219,19 @@
         texture = mGradientCache->addLinearGradient(mKey, mColors, mPositions, mCount, mTileX);
     }
 
-    Rect start(mBounds[0], mBounds[1], mBounds[2], mBounds[3]);
-    if (mMatrix) {
-        mat4 shaderMatrix(*mMatrix);
-        shaderMatrix.mapPoint(start.left, start.top);
-        shaderMatrix.mapPoint(start.right, start.bottom);
-    }
-    snapshot.transform->mapRect(start);
-
-    const float gradientX = start.right - start.left;
-    const float gradientY = start.bottom - start.top;
-
-    mat4 screenSpace(*snapshot.transform);
-    screenSpace.multiply(modelView);
+    mat4 screenSpace;
+    computeScreenSpaceMatrix(screenSpace, modelView);
 
     // Uniforms
     bindTexture(texture->id, gTileModes[mTileX], gTileModes[mTileY], textureSlot);
     glUniform1i(program->getUniform("gradientSampler"), textureSlot);
-    glUniform2f(program->getUniform("gradientStart"), start.left, start.top);
-    glUniform2f(program->getUniform("gradient"), gradientX, gradientY);
-    glUniform1f(program->getUniform("gradientLength"),
-            1.0f / (gradientX * gradientX + gradientY * gradientY));
     glUniformMatrix4fv(program->getUniform("screenSpace"), 1, GL_FALSE, &screenSpace.data[0]);
 }
 
 void SkiaLinearGradientShader::updateTransforms(Program* program, const mat4& modelView,
         const Snapshot& snapshot) {
-    mat4 screenSpace(*snapshot.transform);
-    screenSpace.multiply(modelView);
+    mat4 screenSpace;
+    computeScreenSpaceMatrix(screenSpace, modelView);
     glUniformMatrix4fv(program->getUniform("screenSpace"), 1, GL_FALSE, &screenSpace.data[0]);
 }
 
diff --git a/libs/hwui/SkiaShader.h b/libs/hwui/SkiaShader.h
index 9f8778f..2c1eb35 100644
--- a/libs/hwui/SkiaShader.h
+++ b/libs/hwui/SkiaShader.h
@@ -77,7 +77,7 @@
             const Snapshot& snapshot) {
     }
 
-    void setMatrix(SkMatrix* matrix) {
+    virtual void setMatrix(SkMatrix* matrix) {
         mMatrix = matrix;
     }
 
@@ -139,7 +139,15 @@
             GLuint* textureUnit);
     void updateTransforms(Program* program, const mat4& modelView, const Snapshot& snapshot);
 
+    void setMatrix(SkMatrix* matrix);
+
 private:
+    void updateLocalMatrix(const SkMatrix* matrix);
+    void computeScreenSpaceMatrix(mat4& screenSpace, const mat4& modelView);
+
+    mat4 mUnitMatrix;
+    mat4 mShaderMatrix;
+
     float* mBounds;
     uint32_t* mColors;
     float* mPositions;
diff --git a/tests/HwAccelerationTest/src/com/android/test/hwui/GradientsActivity.java b/tests/HwAccelerationTest/src/com/android/test/hwui/GradientsActivity.java
index b70f3a9..769bfdd 100644
--- a/tests/HwAccelerationTest/src/com/android/test/hwui/GradientsActivity.java
+++ b/tests/HwAccelerationTest/src/com/android/test/hwui/GradientsActivity.java
@@ -24,7 +24,10 @@
 import android.graphics.Paint;
 import android.graphics.Shader;
 import android.os.Bundle;
+import android.view.Gravity;
 import android.view.View;
+import android.widget.FrameLayout;
+import android.widget.SeekBar;
 
 @SuppressWarnings({"UnusedDeclaration"})
 public class GradientsActivity extends Activity {
@@ -32,7 +35,59 @@
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
 
-        setContentView(new ShadersView(this));
+        final FrameLayout layout = new FrameLayout(this);
+        final ShadersView shadersView = new ShadersView(this);
+        final GradientView gradientView = new GradientView(this);
+        final SeekBar rotateView = new SeekBar(this);
+        rotateView.setMax(360);
+        rotateView.setOnSeekBarChangeListener(new SeekBar.OnSeekBarChangeListener() {
+            @Override
+            public void onStopTrackingTouch(SeekBar seekBar) {
+            }
+
+            @Override
+            public void onStartTrackingTouch(SeekBar seekBar) {
+            }
+
+            @Override
+            public void onProgressChanged(SeekBar seekBar, int progress, boolean fromUser) {
+                gradientView.setRotationY((float)progress);
+            }
+        });
+        
+        layout.addView(shadersView);
+        layout.addView(gradientView, new FrameLayout.LayoutParams(
+                200, 200, Gravity.CENTER));
+        layout.addView(rotateView, new FrameLayout.LayoutParams(
+                300, FrameLayout.LayoutParams.WRAP_CONTENT,
+                Gravity.CENTER_HORIZONTAL | Gravity.BOTTOM));
+
+        setContentView(layout);
+    }
+    
+    static class GradientView extends View {
+        private final Paint mPaint;
+
+        GradientView(Context c) {
+            super(c);
+
+            LinearGradient gradient = new LinearGradient(0, 0, 200, 0, 0xFF000000, 0,
+                    Shader.TileMode.CLAMP);
+            mPaint = new Paint();
+            mPaint.setShader(gradient);
+        }
+
+        @Override
+        protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
+            super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+            setMeasuredDimension(200, 200);
+        }
+
+        @Override
+        protected void onDraw(Canvas canvas) {
+            super.onDraw(canvas);
+            canvas.drawRect(0.0f, 0.0f, getWidth(), getHeight(), mPaint);
+        }
     }
 
     static class ShadersView extends View {