Skip to content

Commit 4504a6f

Browse files
committed
Clarify the difference between CustomGradient and RawCustomGradient
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent ed1da29 commit 4504a6f

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ private TensorFlow() {}
152152
}
153153

154154
/**
155-
* Keeps references to custom gradient functions to prevent them from being deallocated.
155+
* Keeps references to custom gradient functions to prevent them from being deallocated. All
156+
* access of this set should be synchronized on this class.
156157
*
157158
* <p><b>Required for correctness</b>
158159
*/
@@ -170,6 +171,9 @@ private static synchronized boolean hasGradient(String opType) {
170171
/**
171172
* Register a custom gradient function for ops of {@code opType} type.
172173
*
174+
* <p>Creates the gradient based off of a {@link GraphOperation}. To operate on the op input class
175+
* instead use {@link CustomGradient}.
176+
*
173177
* <p>Note that this only works with graph gradients, and will eventually be deprecated in favor
174178
* of unified gradient support once it is fully supported by tensorflow core.
175179
*
@@ -193,7 +197,8 @@ public static synchronized boolean registerCustomGradient(
193197
/**
194198
* Register a custom gradient function for ops of {@code inputClass}'s op type. The actual op type
195199
* is detected from the class's {@link OpInputsMetadata} annotation. As such, it only works on
196-
* generated op classes or custom op classes with the correct annotations.
200+
* generated op classes or custom op classes with the correct annotations. To operate on the
201+
* {@link org.tensorflow.GraphOperation} directly use {@link RawCustomGradient}.
197202
*
198203
* @param inputClass the inputs class of op to register the gradient for.
199204
* @param gradient the gradient function to use

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,15 @@
2626
* A custom gradient for ops of type {@link T}. Should be registered using {@link
2727
* TensorFlow#registerCustomGradient(Class, CustomGradient)}.
2828
*
29+
* <p>Creates the gradient based off of an instance of the op inputs class, which is created using
30+
* reflection. To operate on the {@link org.tensorflow.GraphOperation} directly use {@link
31+
* RawCustomGradient}.
32+
*
2933
* @param <T> the type of op this gradient is for.
3034
*/
3135
@SuppressWarnings("rawtypes")
3236
@FunctionalInterface
33-
public interface CustomGradient<T extends RawOpInputs> {
37+
public interface CustomGradient<T extends RawOpInputs<?>> {
3438

3539
/**
3640
* Calculate the gradients for {@code op}.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
* A custom gradient for an op of unspecified type. Should be registered using {@link
2828
* TensorFlow#registerCustomGradient(String, RawCustomGradient)}.
2929
*
30+
* <p>Creates the gradient based off of a {@link GraphOperation}. To operate on the op input class
31+
* instead use {@link CustomGradient}.
32+
*
3033
* <p>The op type of {@code op} will depend on the op type string passed to the registration method.
3134
* Note that the registration method can be called more than once, resulting this gradient function
3235
* being used for multiple different op types.

0 commit comments

Comments
 (0)