File tree Expand file tree Collapse file tree 3 files changed +15
-3
lines changed
tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow Expand file tree Collapse file tree 3 files changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -152,7 +152,8 @@ private TensorFlow() {}
152152 }
153153
154154 /**
155- * Keeps references to custom gradient functions to prevent them from being deallocated.
155+ * Keeps references to custom gradient functions to prevent them from being deallocated. All
156+ * access of this set should be synchronized on this class.
156157 *
157158 * <p><b>Required for correctness</b>
158159 */
@@ -170,6 +171,9 @@ private static synchronized boolean hasGradient(String opType) {
170171 /**
171172 * Register a custom gradient function for ops of {@code opType} type.
172173 *
174+ * <p>Creates the gradient based off of a {@link GraphOperation}. To operate on the op input class
175+ * instead use {@link CustomGradient}.
176+ *
173177 * <p>Note that this only works with graph gradients, and will eventually be deprecated in favor
174178 * of unified gradient support once it is fully supported by tensorflow core.
175179 *
@@ -193,7 +197,8 @@ public static synchronized boolean registerCustomGradient(
193197 /**
194198 * Register a custom gradient function for ops of {@code inputClass}'s op type. The actual op type
195199 * is detected from the class's {@link OpInputsMetadata} annotation. As such, it only works on
196- * generated op classes or custom op classes with the correct annotations.
200+ * generated op classes or custom op classes with the correct annotations. To operate on the
201+ * {@link org.tensorflow.GraphOperation} directly use {@link RawCustomGradient}.
197202 *
198203 * @param inputClass the inputs class of op to register the gradient for.
199204 * @param gradient the gradient function to use
Original file line number Diff line number Diff line change 2626 * A custom gradient for ops of type {@link T}. Should be registered using {@link
2727 * TensorFlow#registerCustomGradient(Class, CustomGradient)}.
2828 *
29+ * <p>Creates the gradient based off of an instance of the op inputs class, which is created using
30+ * reflection. To operate on the {@link org.tensorflow.GraphOperation} directly use {@link
31+ * RawCustomGradient}.
32+ *
2933 * @param <T> the type of op this gradient is for.
3034 */
3135@ SuppressWarnings ("rawtypes" )
3236@ FunctionalInterface
33- public interface CustomGradient <T extends RawOpInputs > {
37+ public interface CustomGradient <T extends RawOpInputs <?> > {
3438
3539 /**
3640 * Calculate the gradients for {@code op}.
Original file line number Diff line number Diff line change 2727 * A custom gradient for an op of unspecified type. Should be registered using {@link
2828 * TensorFlow#registerCustomGradient(String, RawCustomGradient)}.
2929 *
30+ * <p>Creates the gradient based off of a {@link GraphOperation}. To operate on the op input class
31+ * instead use {@link CustomGradient}.
32+ *
3033 * <p>The op type of {@code op} will depend on the op type string passed to the registration method.
3134 * Note that the registration method can be called more than once, resulting this gradient function
3235 * being used for multiple different op types.
You can’t perform that action at this time.
0 commit comments