diff --git a/scripts/nn/layers/gelu.dml b/scripts/nn/layers/gelu.dml
new file mode 100644
index 00000000000..23c1d407be1
--- /dev/null
+++ b/scripts/nn/layers/gelu.dml
@@ -0,0 +1,75 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Gaussian Error Linear Unit (GELU) nonlinearity layer.
+ */
+
+forward = function(matrix[double] X)
+    return (matrix[double] out) {
+  /*
+   * Computes the forward pass for a GELU nonlinearity layer, via
+   * its tanh approximation.
+   *
+   * Performs an element-wise evaluation of
+   *    `GELU(x) = x * CDF(x)`,
+   * where CDF is the cumulative distribution function of the
+   * standard normal distribution:
+   *    `CDF(x) = 0.5 * (1 + erf(x/sqrt(2)))`
+   * This implementation uses the tanh approximation:
+   *    `CDF(x) =~ 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715x^3)))`
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (any, any).
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as `X`.
+   */
+  # Approximate CDF, evaluated element-wise with the builtin tanh.
+  cdf = 0.5 * (1 + tanh(sqrt(2 / pi) * (X + 0.044715 * X^3)))
+  out = cdf * X
+}
+
+backward = function(matrix[double] dout, matrix[double] X)
+    return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a GELU nonlinearity layer, via
+   * its tanh approximation.
+   *
+   * With `u(x) = sqrt(2/pi) * (x + 0.044715x^3)` and `T = tanh(u(x))`,
+   * the forward pass is `0.5 * x * (1 + T)`, so by the product and
+   * chain rules its element-wise derivative is
+   *    `0.5 * (1 + T) + 0.5 * x * (1 - T^2) * u'(x)`,
+   * where `u'(x) = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)`.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of same shape as `X`.
+   *  - X: Previous input data matrix, of shape (any, any).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of same shape as `X`.
+   */
+  a = sqrt(2 / pi)  # scale factor of the tanh argument
+  b = 0.044715  # cubic coefficient of the tanh argument
+  T = tanh(a * (X + b * X^3))
+  dT = 1 - T^2  # derivative of tanh with respect to its argument
+  dX = dout * (0.5 * (1 + T) + 0.5 * X * dT * a * (1 + 3 * b * X^2))
+}
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
index 3b002871d73..a9922cf35f7 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
@@ -124,6 +124,11 @@ public void resnet() {
 		run("resnet_bottleneck.dml");
 	}
 
+	@Test
+	public void gelu() {
+		run("gelu.dml");
+	}
+
 	@Override
 	protected void run(String name) {
 		super.run("component/" + name);
diff --git a/src/test/scripts/applications/nn/component/gelu.dml b/src/test/scripts/applications/nn/component/gelu.dml
new file mode 100644
index 00000000000..3d7ea833458
--- /dev/null
+++ b/src/test/scripts/applications/nn/component/gelu.dml
@@ -0,0 +1,70 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("nn/layers/gelu.dml") as gelu
+source("src/test/scripts/applications/nn/util.dml") as test_util
+
+# Each test feeds a fixed 2x2 input through gelu::forward and gelu::backward
+# and compares the results with hard-coded expected values. `dout` is all
+# ones, so the expected gradient is the element-wise derivative of GELU.
+
+gelu_test1 = function() {
+  print("Testing GELU, test 1")
+
+  X = matrix("1.  -0.5
+              0.   2.", rows=2, cols=2)
+  dout = matrix("1 1
+                 1 1", rows=2, cols=2)
+  out_expected = matrix("0.841192 -0.154286
+                         0.        1.9545977", rows=2, cols=2)
+  gradient_expected = matrix("1.0829641  0.13263011
+                              0.5        1.0860993", rows=2, cols=2)
+
+  out = gelu::forward(X)
+
+  test_util::check_all_close(out, out_expected, 0.00001)
+
+  gradient = gelu::backward(dout, X)
+  test_util::check_all_close(gradient, gradient_expected, 0.00001)
+}
+
+gelu_test2 = function() {
+  print("Testing GELU, test 2")
+
+  X = matrix("0.5 -1.5
+              1.  -2.", rows=2, cols=2)
+  dout = matrix("1 1
+                 1 1", rows=2, cols=2)
+  out_expected = matrix("0.345714 -0.10042843
+                         0.841192 -0.04540229", rows=2, cols=2)
+  gradient_expected = matrix("0.8673699 -0.1277108
+                              1.0829641 -0.08609922", rows=2, cols=2)
+
+  out = gelu::forward(X)
+
+  test_util::check_all_close(out, out_expected, 0.00001)
+
+  gradient = gelu::backward(dout, X)
+  test_util::check_all_close(gradient, gradient_expected, 0.00001)
+}
+
+gelu_test1()
+gelu_test2()