Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/calculus.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ Precompose
PrecomposeDiagonal
Tilt
Translate
ReshapeInput
```
1 change: 1 addition & 0 deletions src/ProximalOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ include("calculus/precomposeDiagonal.jl")
include("calculus/regularize.jl")
include("calculus/separableSum.jl")
include("calculus/slicedSeparableSum.jl")
include("calculus/reshapeInput.jl")
include("calculus/sqrDistL2.jl")
include("calculus/tilt.jl")
include("calculus/translate.jl")
Expand Down
56 changes: 56 additions & 0 deletions src/calculus/reshapeInput.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# wrap a function to reshape the input

export ReshapeInput

"""
ReshapeInput(f, expected_shape)

Wrap a function to reshape the input.
It is useful when the function `f` expects a specific shape of the input, but you want to pass it a different shape.

```julia
julia> f = ReshapeInput(IndballRank(5), (10, 10))
ReshapeInput(IndBallRank{Int64}(5), (10, 10))

julia> f(rand(100))
Inf
```
"""
struct ReshapeInput{F, S}
f::F
expected_shape::S
end

function (f::ReshapeInput)(x)
if size(x) != f.expected_shape
x = reshape(x, f.expected_shape)
end
return f.f(x)
end

function prox!(y, f::ReshapeInput, x, gamma)
if size(x) != f.expected_shape
x = reshape(x, f.expected_shape)
end
if size(y) != f.expected_shape
y = reshape(y, f.expected_shape)
end
return prox!(y, f.f, x, gamma)
end

function gradient!(y, f::ReshapeInput, x)
if size(x) != f.expected_shape
x = reshape(x, f.expected_shape)
end
if size(y) != f.expected_shape
y = reshape(y, f.expected_shape)
end
return gradient!(y, f.f, x)
end

function prox_naive(f::ReshapeInput, x, gamma)
if size(x) != f.expected_shape
x = reshape(x, f.expected_shape)
end
return prox_naive(f.f, x, gamma)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ end
include("test_separableSum.jl")
include("test_slicedSeparableSum.jl")
include("test_sum.jl")
include("test_reshapeInput.jl")
end

@testset "Equivalences" begin
Expand Down
162 changes: 162 additions & 0 deletions test/test_reshapeInput.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
using LinearAlgebra
using ProximalOperators
using Test

# Define a simple test function that we can use with ReshapeInput
struct SimpleTestFunc end

# Make it callable - returns squared norm
# This function requires 2D input (matrix), and will error for vectors or higher-dimensional arrays
function (::SimpleTestFunc)(x)
if ndims(x) != 2
throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array"))
end
return sum(abs2, x)
end

# Define a prox! method for SimpleTestFunc
function ProximalOperators.prox!(y, f::SimpleTestFunc, x, gamma)
if ndims(x) != 2
throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array"))
end
# Simple soft-thresholding prox: prox(||·||^2) = x / (1 + 2*gamma)
y .= x ./ (1 + 2 * gamma)
return sum(abs2, y)
end

# Define a gradient! method for SimpleTestFunc
function ProximalOperators.gradient!(y, f::SimpleTestFunc, x)
if ndims(x) != 2
throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array"))
end
# Gradient of squared norm: 2*x
y .= 2 .* x
return sum(abs2, y)
end



@testset "ReshapeInput Tests" begin

@testset "Basic Function Call with Correct Shape" begin
# Create a ReshapeInput wrapper
f = ReshapeInput(SimpleTestFunc(), (2, 2))

# Create input with correct shape
x = reshape(1.0:4.0, 2, 2)
result = f(x)

# Should return squared norm of all elements: 1 + 4 + 9 + 16 = 30
expected = sum(abs2, x)
@test result ≈ expected
end

@testset "Function Call with Shape Reshaping" begin
# Create a ReshapeInput wrapper expecting (2, 2)
f = ReshapeInput(SimpleTestFunc(), (2, 2))

# Create input as a vector (different shape)
x = vec(reshape(1.0:4.0, 2, 2)) # [1, 2, 3, 4]
result = f(x)

# Should reshape to (2, 2) internally and compute squared norm
x_reshaped = reshape(x, 2, 2)
expected = sum(abs2, x_reshaped)
@test result ≈ expected
end

@testset "Function Call with Multiple Reshaping" begin
# Create a ReshapeInput wrapper expecting (3, 4)
f = ReshapeInput(SimpleTestFunc(), (3, 4))

# Create input as a vector of 12 elements
x = collect(1.0:12.0)
result = f(x)

# Should reshape to (3, 4) and compute squared norm
x_reshaped = reshape(x, 3, 4)
expected = sum(abs2, x_reshaped)
@test result ≈ expected
end

@testset "prox! with Correct Shape" begin
# Create a ReshapeInput wrapper
f = ReshapeInput(SimpleTestFunc(), (2, 2))

# Create input and output with correct shape
x = reshape(1.0:4.0, 2, 2)
y = zeros(2, 2)
gamma = 0.5

result = prox!(y, f, x, gamma)

# prox of squared norm with soft-thresholding
expected_y = x ./ (1 + 2 * gamma)
expected_result = sum(abs2, expected_y)

@test y ≈ expected_y
@test result ≈ expected_result
end

@testset "prox! with Shape Reshaping" begin
# Create a ReshapeInput wrapper expecting (2, 2)
f = ReshapeInput(SimpleTestFunc(), (2, 2))

# Create input and output as vectors
x = collect(1.0:4.0)
y = zeros(4)
gamma = 0.5

result = prox!(y, f, x, gamma)

# Should internally reshape to (2, 2)
x_reshaped = reshape(x, 2, 2)
expected_y_reshaped = x_reshaped ./ (1 + 2 * gamma)
expected_result = sum(abs2, expected_y_reshaped)

# y should contain the reshaped result flattened back
y_expected = vec(expected_y_reshaped)
@test y ≈ y_expected
@test result ≈ expected_result
end

@testset "gradient! with Correct Shape" begin
# Create a ReshapeInput wrapper
f = ReshapeInput(SimpleTestFunc(), (2, 2))

# Create input and output with correct shape
x = reshape(1.0:4.0, 2, 2)
y = zeros(2, 2)

result = gradient!(y, f, x)

# Gradient of squared norm: 2*x
expected_y = 2 .* x
expected_result = sum(abs2, expected_y)

@test y ≈ expected_y
@test result ≈ expected_result
end

@testset "gradient! with Shape Reshaping" begin
# Create a ReshapeInput wrapper expecting (2, 2)
f = ReshapeInput(SimpleTestFunc(), (2, 2))

# Create input and output as vectors
x = collect(1.0:4.0)
y = zeros(4)

result = gradient!(y, f, x)

# Should internally reshape to (2, 2)
x_reshaped = reshape(x, 2, 2)
expected_y_reshaped = 2 .* x_reshaped
expected_result = sum(abs2, expected_y_reshaped)

# y should contain the reshaped result flattened back
y_expected = vec(expected_y_reshaped)
@test y ≈ y_expected
@test result ≈ expected_result
end

end
Loading