Skip to content

Commit 278ae58

Browse files
committed
Start of gradient registry
Signed-off-by: Ryan Nett <rnett@calpoly.edu>
1 parent 5baabb6 commit 278ae58

File tree

5 files changed

+328
-1
lines changed

5 files changed

+328
-1
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Targeted by JavaCPP version 1.5.4: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
/** GradFunc is the signature for all gradient functions in GradOpRegistry.
13+
* Implementations should add operations to compute the gradient outputs of
14+
* 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'. */
15+
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
16+
public class GradFunc extends FunctionPointer {
17+
static { Loader.load(); }
18+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
19+
public GradFunc(Pointer p) { super(p); }
20+
protected GradFunc() { allocate(); }
21+
private native void allocate();
22+
public native @ByVal NativeStatus call(@Const @ByRef TF_Scope scope, @Cast("const tensorflow::Operation*") @ByRef TF_Operation op,
23+
@Cast("tensorflow::Output*") @StdVector TF_Output grad_inputs,
24+
@Cast("tensorflow::Output*") @StdVector TF_Output grad_outputs);
25+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Targeted by JavaCPP version 1.5.4: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
/** GradOpRegistry maintains a static registry of gradient functions.
13+
* Gradient functions are indexed in the registry by the forward op name (i.e.
14+
* "MatMul" -> MatMulGrad func). */
15+
@Namespace("tensorflow::ops") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
16+
public class GradOpRegistry extends Pointer {
17+
static { Loader.load(); }
18+
/** Default native constructor. */
19+
public GradOpRegistry() { super((Pointer)null); allocate(); }
20+
/** Native array allocator. Access with {@link Pointer#position(long)}. */
21+
public GradOpRegistry(long size) { super((Pointer)null); allocateArray(size); }
22+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
23+
public GradOpRegistry(Pointer p) { super(p); }
24+
private native void allocate();
25+
private native void allocateArray(long size);
26+
@Override public GradOpRegistry position(long position) {
27+
return (GradOpRegistry)super.position(position);
28+
}
29+
@Override public GradOpRegistry getPointer(long i) {
30+
return new GradOpRegistry(this).position(position + i);
31+
}
32+
33+
/** Registers 'func' as the gradient function for 'op'.
34+
* Returns true if registration was successful, check fails otherwise. */
35+
public native @Cast("bool") boolean Register(@StdString BytePointer op, GradFunc func);
36+
public native @Cast("bool") boolean Register(@StdString String op, GradFunc func);
37+
38+
/** Sets 'func' to the gradient function for 'op' and returns Status OK if
39+
* the gradient function for 'op' exists in the registry.
40+
* Note that 'func' can be null for ops that have registered no-gradient with
41+
* the registry.
42+
* Returns error status otherwise. */
43+
public native @ByVal NativeStatus Lookup(@StdString BytePointer op, @ByPtrPtr GradFunc func);
44+
public native @ByVal NativeStatus Lookup(@StdString String op, @ByPtrPtr GradFunc func);
45+
46+
/** Returns a pointer to the global gradient function registry. */
47+
public static native GradOpRegistry Global();
48+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Targeted by JavaCPP version 1.5.4: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
// #endif
12+
13+
/** \ingroup core
14+
* Denotes success or failure of a call in Tensorflow. */
15+
@Name("tensorflow::Status") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
16+
public class NativeStatus extends Pointer {
17+
static { Loader.load(); }
18+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
19+
public NativeStatus(Pointer p) { super(p); }
20+
21+
/** Create a success status. */
22+
23+
/** \brief Create a status with the specified error code and msg as a
24+
* human-readable string containing more detailed information. */
25+
26+
/** \brief Create a status with the specified error code, msg, and stack trace
27+
* as a human-readable string containing more detailed information. */
28+
// #ifndef SWIG
29+
// #endif
30+
31+
/** Copy the specified status. */
32+
public native @ByRef @Name("operator =") NativeStatus put(@Const @ByRef NativeStatus s);
33+
// #ifndef SWIG
34+
// #endif // SWIG
35+
36+
public static native @ByVal NativeStatus OK();
37+
38+
/** Returns true iff the status indicates success. */
39+
public native @Cast("bool") boolean ok();
40+
41+
42+
43+
public native @StdString BytePointer error_message();
44+
45+
public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef NativeStatus x);
46+
47+
///
48+
public native @Cast("bool") @Name("operator !=") boolean notEquals(@Const @ByRef NativeStatus x);
49+
50+
/** \brief If {@code ok()}, stores {@code new_status} into {@code *this}. If {@code !ok()},
51+
* preserves the current status, but may augment with additional
52+
* information about {@code new_status}.
53+
*
54+
* Convenient way of keeping track of the first error encountered.
55+
* Instead of:
56+
* {@code if (overall_status.ok()) overall_status = new_status}
57+
* Use:
58+
* {@code overall_status.Update(new_status);} */
59+
public native void Update(@Const @ByRef NativeStatus new_status);
60+
61+
/** \brief Return a string representation of this status suitable for
62+
* printing. Returns the string {@code "OK"} for success. */
63+
public native @StdString BytePointer ToString();
64+
65+
// Ignores any errors. This method does nothing except potentially suppress
66+
// complaints from any tools that are checking that errors are not dropped on
67+
// the floor.
68+
public native void IgnoreError();
69+
}

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4459,6 +4459,183 @@ public static native void TFE_ContextExportRunMetadata(TFE_Context ctx,
44594459
// #endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
44604460

44614461

4462+
// Parsed from tensorflow/cc/framework/grad_op_registry.h
4463+
4464+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4465+
4466+
Licensed under the Apache License, Version 2.0 (the "License");
4467+
you may not use this file except in compliance with the License.
4468+
You may obtain a copy of the License at
4469+
4470+
http://www.apache.org/licenses/LICENSE-2.0
4471+
4472+
Unless required by applicable law or agreed to in writing, software
4473+
distributed under the License is distributed on an "AS IS" BASIS,
4474+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4475+
See the License for the specific language governing permissions and
4476+
limitations under the License.
4477+
==============================================================================*/
4478+
4479+
// #ifndef TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
4480+
// #define TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
4481+
4482+
// #include <unordered_map>
4483+
4484+
// #include "tensorflow/cc/framework/ops.h"
4485+
// #include "tensorflow/cc/framework/scope.h"
4486+
// Targeting ../GradFunc.java
4487+
4488+
4489+
// Targeting ../GradOpRegistry.java
4490+
4491+
4492+
4493+
// namespace ops
4494+
4495+
// Macros used to define gradient functions for ops.
4496+
// #define REGISTER_GRADIENT_OP(name, fn)
4497+
// REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, fn)
4498+
4499+
// #define REGISTER_NO_GRADIENT_OP(name)
4500+
// REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, nullptr)
4501+
4502+
// #define REGISTER_GRADIENT_OP_UNIQ_HELPER(ctr, name, fn)
4503+
// REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn)
4504+
4505+
// #define REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn)
4506+
// static bool unused_ret_val_##ctr =
4507+
// ::tensorflow::ops::GradOpRegistry::Global()->Register(name, fn)
4508+
4509+
// namespace tensorflow
4510+
4511+
// #endif // TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
4512+
4513+
4514+
// Parsed from tensorflow/core/platform/status.h
4515+
4516+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
4517+
4518+
Licensed under the Apache License, Version 2.0 (the "License");
4519+
you may not use this file except in compliance with the License.
4520+
You may obtain a copy of the License at
4521+
4522+
http://www.apache.org/licenses/LICENSE-2.0
4523+
4524+
Unless required by applicable law or agreed to in writing, software
4525+
distributed under the License is distributed on an "AS IS" BASIS,
4526+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4527+
See the License for the specific language governing permissions and
4528+
limitations under the License.
4529+
==============================================================================*/
4530+
4531+
// #ifndef TENSORFLOW_CORE_PLATFORM_STATUS_H_
4532+
// #define TENSORFLOW_CORE_PLATFORM_STATUS_H_
4533+
4534+
// #include <functional>
4535+
// #include <iosfwd>
4536+
// #include <memory>
4537+
// #include <string>
4538+
4539+
// #include "tensorflow/core/platform/logging.h"
4540+
// #include "tensorflow/core/platform/macros.h"
4541+
// #include "tensorflow/core/platform/stringpiece.h"
4542+
// #include "tensorflow/core/platform/types.h"
4543+
// #include "tensorflow/core/protobuf/error_codes.pb.h"
4544+
4545+
// A struct representing a frame in a stack trace.
4546+
4547+
// #if defined(__clang__)
4548+
// Only clang supports warn_unused_result as a type annotation.
4549+
// Targeting ../NativeStatus.java
4550+
4551+
4552+
4553+
// Helper class to manage multiple child status values.
4554+
4555+
4556+
4557+
4558+
4559+
// #ifndef SWIG
4560+
4561+
4562+
4563+
// #endif // SWIG
4564+
4565+
4566+
4567+
4568+
4569+
/** \ingroup core */
4570+
@Namespace("tensorflow") public static native @Cast("std::ostream*") @ByRef @Name("operator <<") Pointer shiftLeft(@Cast("std::ostream*") @ByRef Pointer os, @Const @ByRef NativeStatus x);
4571+
4572+
@Namespace("tensorflow") public static native @StdString BytePointer TfCheckOpHelperOutOfLine(
4573+
@Const @ByRef NativeStatus v, @Cast("const char*") BytePointer msg);
4574+
@Namespace("tensorflow") public static native @StdString BytePointer TfCheckOpHelperOutOfLine(
4575+
@Const @ByRef NativeStatus v, String msg);
4576+
4577+
@Namespace("tensorflow") public static native @StdString BytePointer TfCheckOpHelper(@ByVal NativeStatus v,
4578+
@Cast("const char*") BytePointer msg);
4579+
@Namespace("tensorflow") public static native @StdString BytePointer TfCheckOpHelper(@ByVal NativeStatus v,
4580+
String msg);
4581+
4582+
// #define TF_DO_CHECK_OK(val, level)
4583+
// while (auto _result = ::tensorflow::TfCheckOpHelper(val, #val))
4584+
// LOG(level) << *(_result)
4585+
4586+
// #define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL)
4587+
// #define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL)
4588+
4589+
// DEBUG only version of TF_CHECK_OK. Compiler still parses 'val' even in opt
4590+
// mode.
4591+
// #ifndef NDEBUG
4592+
// #define TF_DCHECK_OK(val) TF_CHECK_OK(val)
4593+
// #else
4594+
// #define TF_DCHECK_OK(val)
4595+
// while (false && (::tensorflow::Status::OK() == (val))) LOG(FATAL)
4596+
// #endif
4597+
4598+
// namespace tensorflow
4599+
4600+
// #endif // TENSORFLOW_CORE_PLATFORM_STATUS_H_
4601+
4602+
4603+
// Parsed from tensorflow/c/tf_status_helper.h
4604+
4605+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4606+
4607+
Licensed under the Apache License, Version 2.0 (the "License");
4608+
you may not use this file except in compliance with the License.
4609+
You may obtain a copy of the License at
4610+
4611+
http://www.apache.org/licenses/LICENSE-2.0
4612+
4613+
Unless required by applicable law or agreed to in writing, software
4614+
distributed under the License is distributed on an "AS IS" BASIS,
4615+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4616+
See the License for the specific language governing permissions and
4617+
limitations under the License.
4618+
==============================================================================*/
4619+
4620+
// #ifndef TENSORFLOW_C_TF_STATUS_HELPER_H_
4621+
// #define TENSORFLOW_C_TF_STATUS_HELPER_H_
4622+
4623+
// #include "tensorflow/c/tf_status.h"
4624+
// #include "tensorflow/core/platform/status.h"
4625+
4626+
// Set the attribute of "tf_status" from the attributes of "status".
4627+
@Namespace("tensorflow") public static native void Set_TF_Status_from_Status(TF_Status tf_status,
4628+
@Const @ByRef NativeStatus status);
4629+
4630+
// Returns a "status" from "tf_status".
4631+
@Namespace("tensorflow") public static native @ByVal NativeStatus StatusFromTF_Status(@Const TF_Status tf_status);
4632+
// namespace internal
4633+
4634+
// namespace tensorflow
4635+
4636+
// #endif // TENSORFLOW_C_TF_STATUS_HELPER_H_
4637+
4638+
44624639
// Targeting ../TF_Graph.java
44634640

44644641

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
"tensorflow_adapters.h",
5454
"tensorflow/c/eager/c_api.h",
5555
"tensorflow/cc/framework/scope.h",
56+
"tensorflow/cc/framework/grad_op_registry.h",
57+
"tensorflow/core/platform/status.h",
58+
"tensorflow/c/tf_status_helper.h",
5659
// "tensorflow/cc/framework/ops.h",
5760
"tensorflow/c/c_api_internal.h",
5861
},
@@ -204,7 +207,7 @@ public void map(InfoMap infoMap) {
204207
.put(new Info("c_api_internal.h")
205208
.linePatterns("struct TF_OperationDescription \\{", "\\};",
206209
"struct TF_Graph \\{", "\\};"))
207-
.put(new Info("TF_CAPI_EXPORT", "TF_Bool", "TF_GUARDED_BY").cppTypes().annotations())
210+
.put(new Info("TF_CAPI_EXPORT", "TF_Bool", "TF_GUARDED_BY", "TF_MUST_USE_RESULT").cppTypes().annotations())
208211
.put(new Info("TF_Buffer::data").javaText("public native @Const Pointer data(); public native TF_Buffer data(Pointer data);"))
209212
.put(new Info("TF_Status").pointerTypes("TF_Status").base("org.tensorflow.internal.c_api.AbstractTF_Status"))
210213
.put(new Info("TF_Buffer").pointerTypes("TF_Buffer").base("org.tensorflow.internal.c_api.AbstractTF_Buffer"))
@@ -243,7 +246,11 @@ public void map(InfoMap infoMap) {
243246
.put(new Info("absl::Span", "tensorflow::gtl::ArraySlice").annotations("@Span"))
244247
.put(new Info("tensorflow::Output").javaNames("TF_Output").cast())
245248
.put(new Info("tensorflow::Operation").javaNames("TF_Operation").cast())
249+
.put(new Info("tensorflow::Status").pointerTypes("NativeStatus").purify())
246250
.put(new Info("tensorflow::CompositeOpScopes",
251+
"tensorflow::StackFrame",
252+
"tensorflow::StatusGroup",
253+
"tensorflow::internal::TF_StatusDeleter",
247254
"tensorflow::GraphDef",
248255
"tensorflow::Scope::graph_as_shared_ptr",
249256
"tensorflow::Scope::ToGraphDef",
@@ -259,6 +266,7 @@ public void map(InfoMap infoMap) {
259266
"tensorflow::Scope::WithAssignedDevice",
260267
"tensorflow::Scope::status",
261268
"tensorflow::Scope::UpdateStatus",
269+
"tensorflow::Status::code",
262270
"tensorflow::CreateOutputWithScope",
263271
"TF_OperationDescription::colocation_constraints"
264272
).skip());

0 commit comments

Comments
 (0)