Skip to content

Commit f6c0480

Browse files
committed
Add vector adapter
Signed-off-by: Ryan Nett <rnett@calpoly.edu>
1 parent 593731c commit f6c0480

File tree

7 files changed

+177
-15
lines changed

7 files changed

+177
-15
lines changed

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/GradFunc.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import org.bytedeco.javacpp.annotation.ByVal;
1010
import org.bytedeco.javacpp.annotation.Const;
1111
import org.bytedeco.javacpp.annotation.Properties;
12-
import org.bytedeco.javacpp.annotation.StdVector;
12+
import org.bytedeco.javacpp.annotation.StdMove;
1313

1414

1515
/**
@@ -39,6 +39,6 @@ protected GradFunc() {
3939

4040
public native @ByVal
4141
NativeStatus call(@Const @ByRef TF_Scope scope, @Const @ByRef NativeOperation op,
42-
@StdVector NativeOutput grad_inputs,
43-
@StdVector NativeOutput grad_outputs);
42+
@StdMove NativeOutputVector grad_inputs,
43+
NativeOutputVector grad_outputs);
4444
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import org.bytedeco.javacpp.Loader;
6+
import org.bytedeco.javacpp.Pointer;
7+
import org.bytedeco.javacpp.annotation.ByRef;
8+
import org.bytedeco.javacpp.annotation.ByVal;
9+
import org.bytedeco.javacpp.annotation.Cast;
10+
import org.bytedeco.javacpp.annotation.Const;
11+
import org.bytedeco.javacpp.annotation.Index;
12+
import org.bytedeco.javacpp.annotation.Name;
13+
import org.bytedeco.javacpp.annotation.NoOffset;
14+
import org.bytedeco.javacpp.annotation.Properties;
15+
import org.bytedeco.javacpp.annotation.StdMove;
16+
17+
@Name("std::vector<tensorflow::Output>")
18+
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
19+
public class NativeOutputVector extends Pointer {
20+
21+
static {
22+
Loader.load();
23+
}
24+
25+
/**
26+
* Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}.
27+
*/
28+
public NativeOutputVector(Pointer p) {
29+
super(p);
30+
}
31+
32+
public NativeOutputVector(NativeOutput value) {
33+
this(1);
34+
put(0, value);
35+
}
36+
37+
public NativeOutputVector(NativeOutput... array) {
38+
this(array.length);
39+
put(array);
40+
}
41+
42+
public NativeOutputVector() {
43+
allocate();
44+
}
45+
46+
public NativeOutputVector(long n) {
47+
allocate(n);
48+
}
49+
50+
private native void allocate();
51+
52+
private native void allocate(@Cast("size_t") long n);
53+
54+
public native @Name("operator =")
55+
@ByRef
56+
NativeOutputVector put(@ByRef @StdMove NativeOutputVector x);
57+
58+
public boolean empty() {
59+
return size() == 0;
60+
}
61+
62+
public native long size();
63+
64+
public void clear() {
65+
resize(0);
66+
}
67+
68+
public native void resize(@Cast("size_t") long n);
69+
70+
@Index(function = "at")
71+
public native @ByRef
72+
NativeOutput get(@Cast("size_t") long i);
73+
74+
public native NativeOutputVector put(@Cast("size_t") long i, NativeOutput value);
75+
76+
public native @ByVal
77+
Iterator insert(@ByVal Iterator pos, @ByRef NativeOutput value);
78+
79+
public native @ByVal
80+
Iterator erase(@ByVal Iterator pos);
81+
82+
public native @ByVal
83+
Iterator begin();
84+
85+
public native @ByVal
86+
Iterator end();
87+
88+
@NoOffset
89+
@Name("iterator")
90+
public static class Iterator extends Pointer {
91+
92+
public Iterator(Pointer p) {
93+
super(p);
94+
}
95+
96+
public Iterator() {
97+
}
98+
99+
public native @Name("operator ++")
100+
@ByRef
101+
Iterator increment();
102+
103+
public native @Name("operator ==")
104+
boolean equals(@ByRef Iterator it);
105+
106+
public native @Name("operator *")
107+
@ByRef
108+
@Const
109+
NativeOutput get();
110+
}
111+
112+
public NativeOutput[] get() {
113+
NativeOutput[] array = new NativeOutput[size() < Integer.MAX_VALUE ? (int) size()
114+
: Integer.MAX_VALUE];
115+
for (int i = 0; i < array.length; i++) {
116+
array[i] = get(i);
117+
}
118+
return array;
119+
}
120+
121+
@Override
122+
public String toString() {
123+
return java.util.Arrays.toString(get());
124+
}
125+
126+
public NativeOutput pop_back() {
127+
long size = size();
128+
NativeOutput value = get(size - 1);
129+
resize(size - 1);
130+
return value;
131+
}
132+
133+
public NativeOutputVector push_back(NativeOutput value) {
134+
long size = size();
135+
resize(size + 1);
136+
return put(size, value);
137+
}
138+
139+
public NativeOutputVector put(NativeOutput value) {
140+
if (size() != 1) {
141+
resize(1);
142+
}
143+
return put(0, value);
144+
}
145+
146+
public NativeOutputVector put(NativeOutput... array) {
147+
if (size() != array.length) {
148+
resize(array.length);
149+
}
150+
for (int i = 0; i < array.length; i++) {
151+
put(i, array[i]);
152+
}
153+
return this;
154+
}
155+
}
156+

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ public class tensorflow extends org.tensorflow.internal.c_api.presets.tensorflow
7474
Loader.load();
7575
}
7676

77+
// Targeting ../NativeOutputVector.java
78+
7779
// Parsed from tensorflow/core/platform/ctstring_internal.h
7880

7981
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GradientAdapterHelpers.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.Iterator;
2121
import java.util.List;
2222
import org.tensorflow.internal.c_api.NativeOutput;
23+
import org.tensorflow.internal.c_api.NativeOutputVector;
2324
import org.tensorflow.internal.c_api.Node;
2425

2526
/**
@@ -34,10 +35,10 @@ public class GradientAdapterHelpers {
3435
* @param g the graph the outputs are in
3536
* @param nativeOutputs the native outputs to convert
3637
*/
37-
public static List<Output<?>> fromNativeOutputs(Graph g, NativeOutput nativeOutputs) {
38-
List<Output<?>> gradInputs = new ArrayList<>((int) nativeOutputs.capacity());
38+
public static List<Output<?>> fromNativeOutputs(Graph g, NativeOutputVector nativeOutputs) {
39+
List<Output<?>> gradInputs = new ArrayList<>((int) nativeOutputs.size());
3940
for (int i = 0; i < nativeOutputs.capacity(); i++) {
40-
NativeOutput output = nativeOutputs.position(i);
41+
NativeOutput output = nativeOutputs.get(i);
4142
gradInputs.add(new Output<>(getGraphOp(g, output.node()),
4243
output.index()));
4344
}
@@ -50,12 +51,13 @@ public static List<Output<?>> fromNativeOutputs(Graph g, NativeOutput nativeOutp
5051
* @param outputs the outputs to put
5152
* @param nativeOutputs the native array to put the outputs into
5253
*/
53-
public static void putToNativeOutputs(List<Operand<?>> outputs, NativeOutput nativeOutputs) {
54-
nativeOutputs.capacity(outputs.size());
54+
public static void putToNativeOutputs(List<Operand<?>> outputs,
55+
NativeOutputVector nativeOutputs) {
56+
nativeOutputs.resize(outputs.size());
5557
for (int i = 0; i < outputs.size(); i++) {
5658
Output<?> output = outputs.get(i).asOutput();
5759
Node node = ((GraphOperation) output.op()).getUnsafeNativeHandle().node();
58-
nativeOutputs.position(i).put(new NativeOutput(node, output.index()));
60+
nativeOutputs.put(i, new NativeOutput(node, output.index()));
5961
}
6062
}
6163

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,8 @@ public void map(InfoMap infoMap) {
304304
.annotations("@StdString").valueTypes("BytePointer", "String")
305305
.pointerTypes("BytePointer"))
306306
.put(new Info("absl::Span", "tensorflow::gtl::ArraySlice").annotations("@Span"))
307+
.put(new Info("std::vector<tensorflow::Output>").valueTypes("@StdMove NativeOutputVector")
308+
.pointerTypes("NativeOutputVector").define())
307309
.put(new Info("tensorflow::Output").javaNames("NativeOutput"))
308310
.put(new Info("tensorflow::Operation").javaNames("NativeOperation"))
309311
.put(new Info("tensorflow::Status").javaNames("NativeStatus").purify())

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import org.tensorflow.Output;
2828
import org.tensorflow.internal.c_api.GradFunc;
2929
import org.tensorflow.internal.c_api.NativeOperation;
30-
import org.tensorflow.internal.c_api.NativeOutput;
30+
import org.tensorflow.internal.c_api.NativeOutputVector;
3131
import org.tensorflow.internal.c_api.NativeStatus;
3232
import org.tensorflow.internal.c_api.TF_Scope;
3333
import org.tensorflow.internal.c_api.TF_Status;
@@ -44,8 +44,8 @@ public RawGradientAdapter(RawCustomGradient gradient) {
4444
}
4545

4646
@Override
47-
public NativeStatus call(TF_Scope scope, NativeOperation op, NativeOutput grad_inputs,
48-
NativeOutput grad_outputs) {
47+
public NativeStatus call(TF_Scope scope, NativeOperation op, NativeOutputVector grad_inputs,
48+
NativeOutputVector grad_outputs) {
4949
try (PointerScope pointerScope = new PointerScope()) {
5050
Graph g = Graph.findGraphForPointer(scope.graph());
5151
if (g == null) {

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import org.tensorflow.Output;
2828
import org.tensorflow.internal.c_api.GradFunc;
2929
import org.tensorflow.internal.c_api.NativeOperation;
30-
import org.tensorflow.internal.c_api.NativeOutput;
30+
import org.tensorflow.internal.c_api.NativeOutputVector;
3131
import org.tensorflow.internal.c_api.NativeStatus;
3232
import org.tensorflow.internal.c_api.TF_Scope;
3333
import org.tensorflow.internal.c_api.TF_Status;
@@ -46,8 +46,8 @@ public TypedGradientAdapter(CustomGradient<T> gradient, Class<T> opClass) {
4646
}
4747

4848
@Override
49-
public NativeStatus call(TF_Scope scope, NativeOperation op, NativeOutput grad_inputs,
50-
NativeOutput grad_outputs) {
49+
public NativeStatus call(TF_Scope scope, NativeOperation op, NativeOutputVector grad_inputs,
50+
NativeOutputVector grad_outputs) {
5151
try (PointerScope pointerScope = new PointerScope()) {
5252
Graph g = Graph.findGraphForPointer(scope.graph());
5353
if (g == null) {

0 commit comments

Comments
 (0)