diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index b7993c6d2..3ce3e3690 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -6,6 +6,7 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
__API Changes__:
Fix `torch.jit.ScriptModule.zero_grad`.
+Add ReadOnlySpan overloads to many methods.
# NuGet Version 0.105.2
diff --git a/src/TorchSharp/Autograd.cs b/src/TorchSharp/Autograd.cs
index 4c73fce46..c043225da 100644
--- a/src/TorchSharp/Autograd.cs
+++ b/src/TorchSharp/Autograd.cs
@@ -135,9 +135,9 @@ public static IList grad(IList outputs, IList inputs, IL
using var grads = new PinnedArray();
using var results = new PinnedArray();
- IntPtr outsRef = outs.CreateArray(outputs.Select(p => p.Handle).ToArray());
- IntPtr insRef = ins.CreateArray(inputs.Select(p => p.Handle).ToArray());
- IntPtr gradsRef = grad_outputs == null ? IntPtr.Zero : grads.CreateArray(grad_outputs.Select(p => p.Handle).ToArray());
+ IntPtr outsRef = outs.CreateArray(outputs.ToHandleArray());
+ IntPtr insRef = ins.CreateArray(inputs.ToHandleArray());
+ IntPtr gradsRef = grad_outputs == null ? IntPtr.Zero : grads.CreateArray(grad_outputs.ToHandleArray());
long gradsLength = grad_outputs == null ? 0 : grads.Array.Length;
THSAutograd_grad(outsRef, outs.Array.Length, insRef, ins.Array.Length, gradsRef, gradsLength, retain_graph, create_graph, allow_unused, results.CreateArray);
@@ -178,9 +178,9 @@ public static void backward(IList tensors, IList grad_tensors =
using var ts = new PinnedArray();
using var gts = new PinnedArray();
using var ins = new PinnedArray();
- IntPtr tensRef = ts.CreateArray(tensors.Select(p => p.Handle).ToArray());
- IntPtr gradsRef = grad_tensors == null ? IntPtr.Zero : gts.CreateArray(grad_tensors.Select(p => p.Handle).ToArray());
- IntPtr insRef = inputs == null ? IntPtr.Zero : ins.CreateArray(inputs.Select(p => p.Handle).ToArray());
+ IntPtr tensRef = ts.CreateArray(tensors.ToHandleArray());
+ IntPtr gradsRef = grad_tensors == null ? IntPtr.Zero : gts.CreateArray(grad_tensors.ToHandleArray());
+ IntPtr insRef = inputs == null ? IntPtr.Zero : ins.CreateArray(inputs.ToHandleArray());
long insLength = inputs == null ? 0 : ins.Array.Length;
long gradsLength = grad_tensors == null ? 0 : gts.Array.Length;
diff --git a/src/TorchSharp/AutogradFunction.cs b/src/TorchSharp/AutogradFunction.cs
index 390ce94c6..49adb2ee1 100644
--- a/src/TorchSharp/AutogradFunction.cs
+++ b/src/TorchSharp/AutogradFunction.cs
@@ -148,7 +148,7 @@ internal List ComputeVariableInput(object[] args)
internal void SetNextEdges(List inputVars, bool isExecutable)
{
using var l = new PinnedArray();
- THSAutograd_CSharpNode_setNextEdges(handle, l.CreateArrayWithSize(inputVars.Select(v => v.Handle).ToArray()), isExecutable);
+ THSAutograd_CSharpNode_setNextEdges(handle, l.CreateArrayWithSize(inputVars.ToHandleArray()), isExecutable);
CheckForErrors();
}
@@ -166,10 +166,10 @@ internal List WrapOutputs(List inputVars, List outputs,
using var outputArr = new PinnedArray();
using var resultsArr = new PinnedArray();
- var varsPtr = varsArr.CreateArrayWithSize(inputVars.Select(v => v.Handle).ToArray());
- var diffsPtr = diffArr.CreateArrayWithSize(_context.NonDifferentiableTensors.Select(v => v.Handle).ToArray());
- var dirtyPtr = diffArr.CreateArrayWithSize(_context.DirtyTensors.Select(v => v.Handle).ToArray());
- var outputPtr = outputArr.CreateArrayWithSize(outputs.Select(v => v.Handle).ToArray());
+ var varsPtr = varsArr.CreateArrayWithSize(inputVars.ToHandleArray());
+ var diffsPtr = diffArr.CreateArrayWithSize(_context.NonDifferentiableTensors.ToHandleArray());
+ var dirtyPtr = diffArr.CreateArrayWithSize(_context.DirtyTensors.ToHandleArray());
+ var outputPtr = outputArr.CreateArrayWithSize(outputs.ToHandleArray());
THSAutograd_Function_wrapOutputs(varsPtr, diffsPtr, dirtyPtr, outputPtr, isExecutable ? handle : new(), resultsArr.CreateArray);
CheckForErrors();
diff --git a/src/TorchSharp/LinearAlgebra.cs b/src/TorchSharp/LinearAlgebra.cs
index 436650ac7..91a22e3b2 100644
--- a/src/TorchSharp/LinearAlgebra.cs
+++ b/src/TorchSharp/LinearAlgebra.cs
@@ -444,7 +444,7 @@ public static Tensor multi_dot(IList tensors)
}
using (var parray = new PinnedArray()) {
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
var res = THSLinalg_multi_dot(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero)
torch.CheckForErrors();
diff --git a/src/TorchSharp/NN/Utils/RNNUtils.cs b/src/TorchSharp/NN/Utils/RNNUtils.cs
index ab0b62cc5..eb486a912 100644
--- a/src/TorchSharp/NN/Utils/RNNUtils.cs
+++ b/src/TorchSharp/NN/Utils/RNNUtils.cs
@@ -55,7 +55,7 @@ public static (torch.Tensor, torch.Tensor) pad_packed_sequence(PackedSequence se
/// The padded tensor
public static torch.Tensor pad_sequence(IEnumerable sequences, bool batch_first = false, double padding_value = 0.0)
{
- var sequences_arg = sequences.Select(p => p.Handle).ToArray();
+ var sequences_arg = sequences.ToHandleArray();
var res = THSNN_pad_sequence(sequences_arg, sequences_arg.Length, batch_first, padding_value);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new torch.Tensor(res);
@@ -69,7 +69,7 @@ public static torch.Tensor pad_sequence(IEnumerable sequences, boo
/// The packed batch of variable length sequences
public static PackedSequence pack_sequence(IEnumerable sequences, bool enforce_sorted = true)
{
- var sequences_arg = sequences.Select(p => p.Handle).ToArray();
+ var sequences_arg = sequences.ToHandleArray();
var res = THSNN_pack_sequence(sequences_arg, sequences_arg.Length, enforce_sorted);
if (res.IsInvalid) { torch.CheckForErrors(); }
return new PackedSequence(res);
diff --git a/src/TorchSharp/Optimizers/LBFGS.cs b/src/TorchSharp/Optimizers/LBFGS.cs
index 1249b5ba5..a06424dce 100644
--- a/src/TorchSharp/Optimizers/LBFGS.cs
+++ b/src/TorchSharp/Optimizers/LBFGS.cs
@@ -47,7 +47,7 @@ public static LBFGS LBFGS(IEnumerable parameters, double lr = 0.01, l
if (!max_eval.HasValue) max_eval = 5 * max_iter / 4;
using var parray = new PinnedArray();
- IntPtr paramsRef = parray.CreateArray(parameters.Select(p => p.Handle).ToArray());
+ IntPtr paramsRef = parray.CreateArray(parameters.ToHandleArray());
var res = THSNN_LBFGS_ctor(paramsRef, parray.Array.Length, lr, max_iter, max_eval.Value, tolerange_grad, tolerance_change, history_size);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
diff --git a/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs b/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs
index c55ec9f4c..ce70d2d23 100644
--- a/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs
+++ b/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs
@@ -44,7 +44,39 @@ public static Tensor cat(IList tensors, long dim = 0)
}
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
+
+ var res = THSTensor_cat(tensorsRef, parray.Array.Length, dim);
+ if (res == IntPtr.Zero) CheckForErrors();
+ return new Tensor(res);
+ }
+
+        // https://pytorch.org/docs/stable/generated/torch.cat
+        /// <summary>
+        /// Concatenates the given sequence of tensors in the given dimension.
+        /// </summary>
+        /// <param name="tensors">A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.</param>
+        /// <param name="dim">The dimension over which the tensors are concatenated</param>
+        /// <remarks>All tensors must either have the same shape (except in the concatenating dimension) or be empty.</remarks>
+        public static Tensor cat(Tensor[] tensors, long dim = 0) => torch.cat((ReadOnlySpan<Tensor>)tensors, dim);
+
+        // https://pytorch.org/docs/stable/generated/torch.cat
+        /// <summary>
+        /// Concatenates the given sequence of tensors in the given dimension.
+        /// </summary>
+        /// <param name="tensors">A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.</param>
+        /// <param name="dim">The dimension over which the tensors are concatenated</param>
+        /// <remarks>All tensors must either have the same shape (except in the concatenating dimension) or be empty.</remarks>
+        public static Tensor cat(ReadOnlySpan<Tensor> tensors, long dim = 0)
+        {
+            switch (tensors.Length)
+            {
+            case <= 0: throw new ArgumentException("The sequence of tensors must be non-empty.", nameof(tensors));
+            case 1: return tensors[0].alias();
+            }
+
+            using var parray = new PinnedArray<IntPtr>();
+            IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
var res = THSTensor_cat(tensorsRef, parray.Array.Length, dim);
if (res == IntPtr.Zero) CheckForErrors();
@@ -60,6 +92,24 @@ public static Tensor cat(IList tensors, long dim = 0)
/// All tensors must either have the same shape (except in the concatenating dimension) or be empty.
public static Tensor concat(IList tensors, long dim = 0) => torch.cat(tensors, dim);
+ // https://pytorch.org/docs/stable/generated/torch.concat
+ ///
+ /// Alias of torch.cat()
+ ///
+ /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
+ /// The dimension over which the tensors are concatenated
+ /// All tensors must either have the same shape (except in the concatenating dimension) or be empty.
+ public static Tensor concat(Tensor[] tensors, long dim = 0) => torch.cat(tensors, dim);
+
+ // https://pytorch.org/docs/stable/generated/torch.concat
+ ///
+ /// Alias of torch.cat()
+ ///
+ /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
+ /// The dimension over which the tensors are concatenated
+ /// All tensors must either have the same shape (except in the concatenating dimension) or be empty.
+ public static Tensor concat(ReadOnlySpan tensors, long dim = 0) => torch.cat(tensors, dim);
+
// https://pytorch.org/docs/stable/generated/torch.conj
///
/// Returns a view of input with a flipped conjugate bit. If input has a non-complex dtype, this function just returns input.
@@ -99,45 +149,53 @@ public static Tensor[] dsplit(Tensor input, (long, long, long, long) indices_or_
///
/// Stack tensors in sequence depthwise (along third axis).
///
- ///
- ///
+ /// An array of input tensors.
+ /// A tensor containing the input tensors stacked along the third axis (depth-wise).
/// This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d().
public static Tensor dstack(params Tensor[] tensors)
- => dstack((IEnumerable)tensors);
+ => dstack(tensors.ToHandleArray());
// https://pytorch.org/docs/stable/generated/torch.dstack
///
/// Stack tensors in sequence depthwise (along third axis).
///
- ///
- ///
+ /// A list of input tensors.
+ /// A tensor containing the input tensors stacked along the third axis (depth-wise).
/// This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d().
public static Tensor dstack(IList tensors)
- {
- using (var parray = new PinnedArray()) {
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ => dstack(tensors.ToHandleArray());
- var res = THSTensor_dstack(tensorsRef, parray.Array.Length);
- if (res == IntPtr.Zero) { torch.CheckForErrors(); }
- return new Tensor(res);
- }
- }
+ // https://pytorch.org/docs/stable/generated/torch.dstack
+ ///
+ /// Stack tensors in sequence depthwise (along third axis).
+ ///
+ /// A span of input tensors.
+ /// A tensor containing the input tensors stacked along the third axis (depth-wise).
+ /// This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d().
+ public static Tensor dstack(ReadOnlySpan tensors)
+ => dstack(tensors.ToHandleArray());
+ // https://pytorch.org/docs/stable/generated/torch.dstack
///
/// Stack tensors in sequence depthwise (along third axis).
///
- ///
- ///
+ /// A sequence of input tensors.
+ /// A tensor containing the input tensors stacked along the third axis (depth-wise).
/// This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d().
public static Tensor dstack(IEnumerable tensors)
+ => dstack(tensors.ToHandleArray());
+
+ static Tensor dstack(IntPtr[] tensors)
{
- using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
- var res = THSTensor_dstack(tensorsRef, parray.Array.Length);
- if (res == IntPtr.Zero) { CheckForErrors(); }
- return new Tensor(res);
+ using (var parray = new PinnedArray()) {
+ IntPtr tensorsRef = parray.CreateArray(tensors);
+
+ var res = THSTensor_dstack(tensorsRef, parray.Array.Length);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new Tensor(res);
+ }
}
-
+
// https://pytorch.org/docs/stable/generated/torch.gather
///
/// Gathers values along an axis specified by dim.
@@ -189,39 +247,42 @@ public static Tensor[] hsplit(Tensor input, (long, long, long, long) indices_or_
///
/// Stack tensors in sequence horizontally (column wise).
///
- ///
- ///
+ /// A list of input tensors.
+ /// A tensor containing the input tensors stacked horizontally (column-wise).
public static Tensor hstack(IList tensors)
- {
- using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
-
- var res = THSTensor_hstack(tensorsRef, parray.Array.Length);
- if (res == IntPtr.Zero) { torch.CheckForErrors(); }
- return new Tensor(res);
- }
+ => hstack(tensors.ToHandleArray());
// https://pytorch.org/docs/stable/generated/torch.hstack
///
/// Stack tensors in sequence horizontally (column wise).
///
- ///
- ///
+ /// An array of input tensors.
+ /// A tensor containing the input tensors stacked horizontally (column-wise).
public static Tensor hstack(params Tensor[] tensors)
- {
- return hstack((IEnumerable)tensors);
- }
+ => hstack(tensors.ToHandleArray());
// https://pytorch.org/docs/stable/generated/torch.hstack
///
/// Stack tensors in sequence horizontally (column wise).
///
- ///
- ///
+ /// A sequence of input tensors.
+ /// A tensor containing the input tensors stacked horizontally (column-wise).
public static Tensor hstack(IEnumerable tensors)
+ => hstack(tensors.ToHandleArray());
+
+ // https://pytorch.org/docs/stable/generated/torch.hstack
+ ///
+ /// Stack tensors in sequence horizontally (column wise).
+ ///
+ /// A span of input tensors.
+ /// A tensor containing the input tensors stacked horizontally (column-wise).
+ public static Tensor hstack(ReadOnlySpan tensors)
+ => hstack(tensors.ToHandleArray());
+
+ static Tensor hstack(IntPtr[] tensors)
{
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors);
var res = THSTensor_hstack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
@@ -474,7 +535,7 @@ public static Tensor[] split(Tensor tensor, long[] split_size_or_sections, long
public static Tensor stack(IEnumerable tensors, long dim = 0)
{
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
var res = THSTensor_stack(tensorsRef, parray.Array.Length, dim);
if (res == IntPtr.Zero) { CheckForErrors(); }
@@ -557,12 +618,33 @@ public static Tensor[] vsplit(Tensor input, long[] indices_or_sections)
///
/// Stack tensors in sequence vertically (row wise).
///
- ///
- ///
+ /// A list of input tensors.
+ /// A tensor containing the input tensors stacked vertically (row-wise).
public static Tensor vstack(IList tensors)
+ => vstack(tensors.ToHandleArray());
+
+ // https://pytorch.org/docs/stable/generated/torch.vstack
+ ///
+ /// Stack tensors in sequence vertically (row wise).
+ ///
+ /// An array of input tensors.
+ /// A tensor containing the input tensors stacked vertically (row-wise).
+ public static Tensor vstack(Tensor[] tensors)
+ => vstack(tensors.ToHandleArray());
+
+ // https://pytorch.org/docs/stable/generated/torch.vstack
+ ///
+ /// Stack tensors in sequence vertically (row wise).
+ ///
+ /// A span of input tensors.
+ /// A tensor containing the input tensors stacked vertically (row-wise).
+ public static Tensor vstack(ReadOnlySpan tensors)
+ => vstack(tensors.ToHandleArray());
+
+ static Tensor vstack(IntPtr[] tensors)
{
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors);
var res = THSTensor_vstack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
diff --git a/src/TorchSharp/Tensor/torch.OtherOperations.cs b/src/TorchSharp/Tensor/torch.OtherOperations.cs
index 4edfcf715..b4b092c4f 100644
--- a/src/TorchSharp/Tensor/torch.OtherOperations.cs
+++ b/src/TorchSharp/Tensor/torch.OtherOperations.cs
@@ -45,7 +45,7 @@ public static partial class torch
public static Tensor block_diag(params Tensor[] tensors)
{
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
var res = THSTensor_block_diag(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
@@ -71,7 +71,7 @@ public static IList broadcast_tensors(params Tensor[] tensors)
using (var pa = new PinnedArray())
using (var parray = new PinnedArray()) {
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
THSTensor_broadcast_tensors(tensorsRef, tensors.Length, pa.CreateArray);
CheckForErrors();
@@ -124,24 +124,36 @@ public static Tensor bucketize(Tensor input, Tensor boundaries, bool outInt32 =
///
/// Do cartesian product of the given sequence of tensors.
///
- ///
- public static Tensor cartesian_prod(IList tensors)
+ /// A list of input tensors.
+ /// A tensor containing the Cartesian product of the input .
+ public static Tensor cartesian_prod(IList tensors) => cartesian_prod(tensors.ToHandleArray());
+
+ // https://pytorch.org/docs/stable/generated/torch.cartesian_prod
+ ///
+ /// Do cartesian product of the given sequence of tensors.
+ ///
+ /// An array of input tensors.
+ /// A tensor containing the Cartesian product of the input .
+ public static Tensor cartesian_prod(params Tensor[] tensors) => cartesian_prod(tensors.ToHandleArray());
+
+ // https://pytorch.org/docs/stable/generated/torch.cartesian_prod
+ ///
+ /// Do cartesian product of the given sequence of tensors.
+ ///
+ /// A span of input tensors.
+ /// A tensor containing the Cartesian product of the input .
+ public static Tensor cartesian_prod(ReadOnlySpan tensors) => cartesian_prod(tensors.ToHandleArray());
+
+ static Tensor cartesian_prod(IntPtr[] tensors)
{
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors);
var res = THSTensor_cartesian_prod(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
- // https://pytorch.org/docs/stable/generated/torch.cartesian_prod
- ///
- /// Do cartesian product of the given sequence of tensors.
- ///
- ///
- public static Tensor cartesian_prod(params Tensor[] tensors) => cartesian_prod((IList)tensors);
-
// https://pytorch.org/docs/stable/generated/torch.cdist
///
/// Computes batched the p-norm distance between each pair of the two collections of row vectors.
@@ -350,7 +362,7 @@ public static Tensor diag_embed(Tensor input, long offset = 0L, long dim1 = -2L,
public static Tensor einsum(string equation, params Tensor[] tensors)
{
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
var res = THSTensor_einsum(equation, tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
@@ -512,7 +524,7 @@ public static Tensor[] meshgrid(IEnumerable tensors, string indexing = "
IntPtr[] ptrArray;
using (var parray = new PinnedArray()) {
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
_ = THSTensor_meshgrid(tensorsRef, parray.Array.Length, indexing, parray.CreateArray);
CheckForErrors();
ptrArray = parray.Array;
diff --git a/src/TorchSharp/Tensor/torch.cs b/src/TorchSharp/Tensor/torch.cs
index 6892d2b69..856df4dd2 100644
--- a/src/TorchSharp/Tensor/torch.cs
+++ b/src/TorchSharp/Tensor/torch.cs
@@ -29,10 +29,28 @@ public static partial class torch
///
/// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
/// The dimension over which the tensors are concatenated
- ///
+ /// A tensor resulting from concatenating the input tensors along .
/// All tensors must either have the same shape (except in the concatenating dimension) or be empty.
public static Tensor concatenate(IList tensors, long axis = 0) => torch.cat(tensors, axis);
+ ///
+ /// Concatenates the given sequence of tensors along the given axis (dimension).
+ ///
+ /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
+ /// The dimension over which the tensors are concatenated
+ /// A tensor resulting from concatenating the input tensors along .
+ /// All tensors must either have the same shape (except in the concatenating dimension) or be empty.
+ public static Tensor concatenate(Tensor[] tensors, long axis = 0) => torch.cat(tensors, axis);
+
+ ///
+ /// Concatenates the given sequence of tensors along the given axis (dimension).
+ ///
+ /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
+ /// The dimension over which the tensors are concatenated
+ /// A tensor resulting from concatenating the input tensors along .
+ /// All tensors must either have the same shape (except in the concatenating dimension) or be empty.
+ public static Tensor concatenate(ReadOnlySpan tensors, long axis = 0) => torch.cat(tensors, axis);
+
///
/// Returns a tensor with all the dimensions of input of size 1 removed. When dim is given, a squeeze operation is done only in the given dimension.
///
@@ -53,12 +71,30 @@ public static partial class torch
/// Creates a new tensor by horizontally stacking the input tensors.
///
/// A list of input tensors.
- ///
+ /// A tensor formed by horizontally stacking the inputs. Zero- or one-dimensional tensors are first reshaped into (numel, 1) columns.
+ /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally.
+ public static Tensor column_stack(IList tensors) => column_stack(tensors.ToHandleArray());
+
+ ///
+ /// Creates a new tensor by horizontally stacking the input tensors.
+ ///
+ /// An array of input tensors.
+ /// A tensor formed by horizontally stacking the inputs. Zero- or one-dimensional tensors are first reshaped into (numel, 1) columns.
/// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally.
- public static Tensor column_stack(IList tensors)
+ public static Tensor column_stack(params Tensor[] tensors) => column_stack(tensors.ToHandleArray());
+
+ ///
+ /// Creates a new tensor by horizontally stacking the input tensors.
+ ///
+ /// A span of input tensors.
+ /// A tensor formed by horizontally stacking the inputs. Zero- or one-dimensional tensors are first reshaped into (numel, 1) columns.
+ /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally.
+ public static Tensor column_stack(ReadOnlySpan tensors) => column_stack(tensors.ToHandleArray());
+
+ static Tensor column_stack(IntPtr[] tensors)
{
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors);
var res = THSTensor_column_stack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
@@ -66,35 +102,36 @@ public static Tensor column_stack(IList tensors)
}
///
- /// Creates a new tensor by horizontally stacking the input tensors.
+ /// Stack tensors in sequence vertically (row wise).
///
/// A list of input tensors.
- ///
- /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally.
- public static Tensor column_stack(params Tensor[] tensors) => column_stack((IList)tensors);
+ /// A tensor formed by stacking the inputs row-wise (vertically).
+ public static Tensor row_stack(IList tensors) => row_stack(tensors.ToHandleArray());
///
/// Stack tensors in sequence vertically (row wise).
///
- ///
- ///
- public static Tensor row_stack(IList tensors)
+ /// An array of input tensors.
+ /// A tensor formed by stacking the inputs row-wise (vertically).
+ public static Tensor row_stack(params Tensor[] tensors) => row_stack(tensors.ToHandleArray());
+
+ ///
+ /// Stack tensors in sequence vertically (row wise).
+ ///
+ /// A span of input tensors.
+ /// A tensor formed by stacking the inputs row-wise (vertically).
+ public static Tensor row_stack(ReadOnlySpan tensors) => row_stack(tensors.ToHandleArray());
+
+ static Tensor row_stack(IntPtr[] tensors)
{
using var parray = new PinnedArray();
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors);
var res = THSTensor_row_stack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(res);
}
- ///
- /// Stack tensors in sequence vertically (row wise).
- ///
- ///
- ///
- public static Tensor row_stack(params Tensor[] tensors) => row_stack((IList)tensors);
-
///
/// Removes a tensor dimension.
///
diff --git a/src/TorchSharp/Torch.cs b/src/TorchSharp/Torch.cs
index c59196bfe..f35c5df71 100644
--- a/src/TorchSharp/Torch.cs
+++ b/src/TorchSharp/Torch.cs
@@ -393,7 +393,7 @@ public static partial class utils
public static double clip_grad_norm_(IEnumerable tensors, double max_norm, double norm_type = 2.0)
{
using (var parray = new PinnedArray()) {
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
var value = THSTensor_clip_grad_norm_(tensorsRef, parray.Array.Length, max_norm, norm_type);
CheckForErrors();
return value;
@@ -409,7 +409,7 @@ public static double clip_grad_norm_(IEnumerable tensors, dou
public static void clip_grad_value_(IEnumerable tensors, double clip_value)
{
using (var parray = new PinnedArray()) {
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
THSTensor_clip_grad_value_(tensorsRef, parray.Array.Length, clip_value);
CheckForErrors();
}
@@ -423,7 +423,7 @@ public static void clip_grad_value_(IEnumerable tensors, doub
public static Tensor parameters_to_vector(IEnumerable tensors)
{
using (var parray = new PinnedArray()) {
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
var res = THSTensor_parameters_to_vector(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero)
@@ -441,7 +441,7 @@ public static Tensor parameters_to_vector(IEnumerable tensors
public static void vector_to_parameters(Tensor vec, IEnumerable tensors)
{
using (var parray = new PinnedArray()) {
- IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+ IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
THSTensor_vector_to_parameters(vec.Handle, tensorsRef, parray.Array.Length);
CheckForErrors();
diff --git a/src/TorchSharp/Utils/OverloadHelper.cs b/src/TorchSharp/Utils/OverloadHelper.cs
new file mode 100644
index 000000000..14316a0a1
--- /dev/null
+++ b/src/TorchSharp/Utils/OverloadHelper.cs
@@ -0,0 +1,40 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace TorchSharp
+{
+    static class OverloadHelper
+    {
+        public static IntPtr[] ToHandleArray(this ReadOnlySpan<torch.Tensor> span)
+        {
+            if (span.Length == 0)
+                return Array.Empty<IntPtr>();
+
+            var result = new IntPtr[span.Length];
+            for (int i = 0; i < span.Length; i++)
+                result[i] = span[i].Handle;
+
+            return result;
+        }
+
+        public static IntPtr[] ToHandleArray(this IList<torch.Tensor> list)
+        {
+            if (list.Count == 0)
+                return Array.Empty<IntPtr>();
+
+            var result = new IntPtr[list.Count];
+            for (int i = 0; i < list.Count; i++)
+                result[i] = list[i].Handle;
+
+            return result;
+        }
+
+        public static IntPtr[] ToHandleArray(this torch.Tensor[] array) => ToHandleArray((ReadOnlySpan<torch.Tensor>)array);
+
+        public static IntPtr[] ToHandleArray(this IEnumerable<torch.Tensor> enumerable)
+        {
+            return enumerable.Select(t => t.Handle).ToArray();
+        }
+    }
+}
diff --git a/src/TorchSharp/Utils/PinnedArray.cs b/src/TorchSharp/Utils/PinnedArray.cs
index 19eb9ce63..333a69a4c 100644
--- a/src/TorchSharp/Utils/PinnedArray.cs
+++ b/src/TorchSharp/Utils/PinnedArray.cs
@@ -1,7 +1,5 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
-using System.Collections.Generic;
-using System.Reflection;
using System.Runtime.InteropServices;
namespace TorchSharp