Skip to content

Commit a098ea4

Browse files
Address PR review comments for GELU approximate parameter
- Move Approximate enum from GELU module class to neutral TorchSharp namespace as GELUApproximate, removing Tensor/functional layer dependency on Modules layer
- Add CharSet, BestFitMapping, ThrowOnUnmappableChar attributes to THSTensor_gelu/gelu_ DllImport declarations to match existing LPStr-based imports pattern
- Update all references in Tensor.cs, GELU.cs, and tests

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1ea4b6a commit a098ea4

5 files changed

Lines changed: 30 additions & 27 deletions

File tree

src/TorchSharp/NN/Activation/GELU.cs

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,7 @@ namespace Modules
1414
/// </summary>
1515
public sealed class GELU : ParameterLessModule<Tensor, Tensor>
1616
{
17-
/// <summary>
18-
/// Specifies the approximation method for GELU.
19-
/// </summary>
20-
public enum Approximate
21-
{
22-
/// <summary>
23-
/// Exact GELU computation.
24-
/// </summary>
25-
none,
26-
/// <summary>
27-
/// Tanh-based approximation.
28-
/// </summary>
29-
tanh
30-
}
31-
32-
internal GELU(bool inplace, Approximate approximate = Approximate.none) : base(nameof(GELU))
17+
internal GELU(bool inplace, GELUApproximate approximate = GELUApproximate.none) : base(nameof(GELU))
3318
{
3419
this.inplace = inplace;
3520
this.approximate = approximate;
@@ -42,7 +27,7 @@ public override Tensor forward(Tensor tensor)
4227

4328
public bool inplace {get; set; }
4429

45-
public Approximate approximate { get; set; }
30+
public GELUApproximate approximate { get; set; }
4631
}
4732
}
4833

@@ -72,7 +57,7 @@ public static GELU GELU(bool inplace)
7257
/// </summary>
7358
/// <param name="approximate">The approximation method to use. Default: none</param>
7459
/// <param name="inplace">Do the operation in-place. Default: False</param>
75-
public static GELU GELU(GELU.Approximate approximate, bool inplace = false)
60+
public static GELU GELU(GELUApproximate approximate, bool inplace = false)
7661
{
7762
return new GELU(inplace, approximate);
7863
}
@@ -95,7 +80,7 @@ public static Tensor gelu(Tensor x, bool inplace)
9580
/// <param name="x">The input tensor</param>
9681
/// <param name="approximate">The approximation method to use.</param>
9782
/// <param name="inplace">Do the operation in-place. Default: False</param>
98-
public static Tensor gelu(Tensor x, GELU.Approximate approximate, bool inplace = false)
83+
public static Tensor gelu(Tensor x, GELUApproximate approximate, bool inplace = false)
9984
{
10085
return inplace ? x.gelu_(approximate).alias() : x.gelu(approximate);
10186
}

src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -706,10 +706,10 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
706706
[DllImport("LibTorchSharp")]
707707
internal static extern void THSTensor_elu_(IntPtr tensor, IntPtr alpha, IntPtr scale, IntPtr input_scale);
708708

709-
[DllImport("LibTorchSharp")]
709+
[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
710710
internal static extern IntPtr THSTensor_gelu(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);
711711

712-
[DllImport("LibTorchSharp")]
712+
[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
713713
internal static extern IntPtr THSTensor_gelu_(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);
714714

715715
[DllImport("LibTorchSharp")]
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2+
namespace TorchSharp
3+
{
4+
/// <summary>
5+
/// Specifies the approximation method for the GELU activation function.
6+
/// </summary>
7+
public enum GELUApproximate
8+
{
9+
/// <summary>
10+
/// Exact GELU computation.
11+
/// </summary>
12+
none,
13+
/// <summary>
14+
/// Tanh-based approximation.
15+
/// </summary>
16+
tanh
17+
}
18+
}

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2983,9 +2983,9 @@ public Tensor gelu()
29832983
return new Tensor(res);
29842984
}
29852985

2986-
public Tensor gelu(TorchSharp.Modules.GELU.Approximate approximate)
2986+
public Tensor gelu(GELUApproximate approximate)
29872987
{
2988-
var res = NativeMethods.THSTensor_gelu(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none");
2988+
var res = NativeMethods.THSTensor_gelu(Handle, approximate == GELUApproximate.tanh ? "tanh" : "none");
29892989
if (res == IntPtr.Zero)
29902990
CheckForErrors();
29912991
return new Tensor(res);
@@ -2999,9 +2999,9 @@ public Tensor gelu_()
29992999
return new Tensor(res);
30003000
}
30013001

3002-
public Tensor gelu_(TorchSharp.Modules.GELU.Approximate approximate)
3002+
public Tensor gelu_(GELUApproximate approximate)
30033003
{
3004-
var res = NativeMethods.THSTensor_gelu_(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none");
3004+
var res = NativeMethods.THSTensor_gelu_(Handle, approximate == GELUApproximate.tanh ? "tanh" : "none");
30053005
if (res == IntPtr.Zero)
30063006
CheckForErrors();
30073007
return new Tensor(res);

test/TorchSharpTest/NN.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ public void EvaluateGELU()
621621
[Fact]
622622
public void EvaluateGELUWithTanhApproximate()
623623
{
624-
var rel = GELU(Modules.GELU.Approximate.tanh);
624+
var rel = GELU(GELUApproximate.tanh);
625625

626626
foreach (var device in TestUtils.AvailableDevices()) {
627627
var input = torch.randn(new long[] { 64, 8 }, device: device) * 25.0;
@@ -636,7 +636,7 @@ public void EvaluateGELUWithTanhApproximate()
636636
// Verify that tanh approximate produces different results from exact
637637
var x = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });
638638
var exact = torch.nn.functional.gelu(x);
639-
var approx = torch.nn.functional.gelu(x, Modules.GELU.Approximate.tanh);
639+
var approx = torch.nn.functional.gelu(x, GELUApproximate.tanh);
640640
Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));
641641
}
642642

0 commit comments

Comments
 (0)