From c380b31b8daa484320d610901a14158567ebd598 Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Tue, 22 Apr 2025 20:41:23 +0200 Subject: [PATCH] add reshapeInput function --- docs/src/calculus.md | 1 + src/ProximalOperators.jl | 1 + src/calculus/reshapeInput.jl | 56 ++++++++++++ test/runtests.jl | 1 + test/test_reshapeInput.jl | 162 +++++++++++++++++++++++++++++++++++ 5 files changed, 221 insertions(+) create mode 100644 src/calculus/reshapeInput.jl create mode 100644 test/test_reshapeInput.jl diff --git a/docs/src/calculus.md b/docs/src/calculus.md index 7468a1d9..c0e32969 100644 --- a/docs/src/calculus.md +++ b/docs/src/calculus.md @@ -32,4 +32,5 @@ Precompose PrecomposeDiagonal Tilt Translate +ReshapeInput ``` diff --git a/src/ProximalOperators.jl b/src/ProximalOperators.jl index bd3fc6e7..a1f11a75 100644 --- a/src/ProximalOperators.jl +++ b/src/ProximalOperators.jl @@ -91,6 +91,7 @@ include("calculus/precomposeDiagonal.jl") include("calculus/regularize.jl") include("calculus/separableSum.jl") include("calculus/slicedSeparableSum.jl") +include("calculus/reshapeInput.jl") include("calculus/sqrDistL2.jl") include("calculus/tilt.jl") include("calculus/translate.jl") diff --git a/src/calculus/reshapeInput.jl b/src/calculus/reshapeInput.jl new file mode 100644 index 00000000..b5cfa4f9 --- /dev/null +++ b/src/calculus/reshapeInput.jl @@ -0,0 +1,56 @@ +# wrap a function to reshape the input + +export ReshapeInput + +""" + ReshapeInput(f, expected_shape) + +Wrap a function to reshape the input. +It is useful when the function `f` expects a specific shape of the input, but you want to pass it a different shape. + +```julia +julia> f = ReshapeInput(IndballRank(5), (10, 10)) +ReshapeInput(IndBallRank{Int64}(5), (10, 10)) + +julia> f(rand(100)) +Inf +``` +""" +struct ReshapeInput{F, S} + f::F + expected_shape::S +end + +function (f::ReshapeInput)(x) + if size(x) != f.expected_shape + x = reshape(x, f.expected_shape) + end + return f.f(x) +end + +function prox!(y, f::ReshapeInput, x, gamma) + if size(x) != f.expected_shape + x = reshape(x, f.expected_shape) + end + if size(y) != f.expected_shape + y = reshape(y, f.expected_shape) + end + return prox!(y, f.f, x, gamma) +end + +function gradient!(y, f::ReshapeInput, x) + if size(x) != f.expected_shape + x = reshape(x, f.expected_shape) + end + if size(y) != f.expected_shape + y = reshape(y, f.expected_shape) + end + return gradient!(y, f.f, x) +end + +function prox_naive(f::ReshapeInput, x, gamma) + if size(x) != f.expected_shape + x = reshape(x, f.expected_shape) + end + return prox_naive(f.f, x, gamma) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3ece6ad5..58a4e070 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -156,6 +156,7 @@ end include("test_separableSum.jl") include("test_slicedSeparableSum.jl") include("test_sum.jl") + include("test_reshapeInput.jl") end @testset "Equivalences" begin diff --git a/test/test_reshapeInput.jl b/test/test_reshapeInput.jl new file mode 100644 index 00000000..1169f0b4 --- /dev/null +++ b/test/test_reshapeInput.jl @@ -0,0 +1,162 @@ +using LinearAlgebra +using ProximalOperators +using Test + +# Define a simple test function that we can use with ReshapeInput +struct SimpleTestFunc end + +# Make it callable - returns squared norm +# This function requires 2D input (matrix), and will error for vectors or higher-dimensional arrays +function (::SimpleTestFunc)(x) + if ndims(x) != 2 + throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array")) + end + return sum(abs2, x) +end + +# Define a prox! method for SimpleTestFunc +function ProximalOperators.prox!(y, f::SimpleTestFunc, x, gamma) + if ndims(x) != 2 + throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array")) + end + # Simple soft-thresholding prox: prox(||·||^2) = x / (1 + 2*gamma) + y .= x ./ (1 + 2 * gamma) + return sum(abs2, y) +end + +# Define a gradient! method for SimpleTestFunc +function ProximalOperators.gradient!(y, f::SimpleTestFunc, x) + if ndims(x) != 2 + throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array")) + end + # Gradient of squared norm: 2*x + y .= 2 .* x + return sum(abs2, y) +end + + + +@testset "ReshapeInput Tests" begin + + @testset "Basic Function Call with Correct Shape" begin + # Create a ReshapeInput wrapper + f = ReshapeInput(SimpleTestFunc(), (2, 2)) + + # Create input with correct shape + x = reshape(1.0:4.0, 2, 2) + result = f(x) + + # Should return squared norm of all elements: 1 + 4 + 9 + 16 = 30 + expected = sum(abs2, x) + @test result ≈ expected + end + + @testset "Function Call with Shape Reshaping" begin + # Create a ReshapeInput wrapper expecting (2, 2) + f = ReshapeInput(SimpleTestFunc(), (2, 2)) + + # Create input as a vector (different shape) + x = vec(reshape(1.0:4.0, 2, 2)) # [1, 2, 3, 4] + result = f(x) + + # Should reshape to (2, 2) internally and compute squared norm + x_reshaped = reshape(x, 2, 2) + expected = sum(abs2, x_reshaped) + @test result ≈ expected + end + + @testset "Function Call with Multiple Reshaping" begin + # Create a ReshapeInput wrapper expecting (3, 4) + f = ReshapeInput(SimpleTestFunc(), (3, 4)) + + # Create input as a vector of 12 elements + x = collect(1.0:12.0) + result = f(x) + + # Should reshape to (3, 4) and compute squared norm + x_reshaped = reshape(x, 3, 4) + expected = sum(abs2, x_reshaped) + @test result ≈ expected + end + + @testset "prox! with Correct Shape" begin + # Create a ReshapeInput wrapper + f = ReshapeInput(SimpleTestFunc(), (2, 2)) + + # Create input and output with correct shape + x = reshape(1.0:4.0, 2, 2) + y = zeros(2, 2) + gamma = 0.5 + + result = prox!(y, f, x, gamma) + + # prox of squared norm with soft-thresholding + expected_y = x ./ (1 + 2 * gamma) + expected_result = sum(abs2, expected_y) + + @test y ≈ expected_y + @test result ≈ expected_result + end + + @testset "prox! with Shape Reshaping" begin + # Create a ReshapeInput wrapper expecting (2, 2) + f = ReshapeInput(SimpleTestFunc(), (2, 2)) + + # Create input and output as vectors + x = collect(1.0:4.0) + y = zeros(4) + gamma = 0.5 + + result = prox!(y, f, x, gamma) + + # Should internally reshape to (2, 2) + x_reshaped = reshape(x, 2, 2) + expected_y_reshaped = x_reshaped ./ (1 + 2 * gamma) + expected_result = sum(abs2, expected_y_reshaped) + + # y should contain the reshaped result flattened back + y_expected = vec(expected_y_reshaped) + @test y ≈ y_expected + @test result ≈ expected_result + end + + @testset "gradient! with Correct Shape" begin + # Create a ReshapeInput wrapper + f = ReshapeInput(SimpleTestFunc(), (2, 2)) + + # Create input and output with correct shape + x = reshape(1.0:4.0, 2, 2) + y = zeros(2, 2) + + result = gradient!(y, f, x) + + # Gradient of squared norm: 2*x + expected_y = 2 .* x + expected_result = sum(abs2, expected_y) + + @test y ≈ expected_y + @test result ≈ expected_result + end + + @testset "gradient! with Shape Reshaping" begin + # Create a ReshapeInput wrapper expecting (2, 2) + f = ReshapeInput(SimpleTestFunc(), (2, 2)) + + # Create input and output as vectors + x = collect(1.0:4.0) + y = zeros(4) + + result = gradient!(y, f, x) + + # Should internally reshape to (2, 2) + x_reshaped = reshape(x, 2, 2) + expected_y_reshaped = 2 .* x_reshaped + expected_result = sum(abs2, expected_y_reshaped) + + # y should contain the reshaped result flattened back + y_expected = vec(expected_y_reshaped) + @test y ≈ y_expected + @test result ≈ expected_result + end + +end