diff --git a/src/ScaleInvariantAnalysis.jl b/src/ScaleInvariantAnalysis.jl index c272407..7173854 100644 --- a/src/ScaleInvariantAnalysis.jl +++ b/src/ScaleInvariantAnalysis.jl @@ -24,7 +24,7 @@ end """ - a = symscale(A; exact=false) + a = symscale(A; exact=false, regularize=false) Given a matrix `A` assumed to be symmetric, return a vector `a` representing the "scale of each axis," so that `|A[i,j]| ~ a[i] * a[j]` for all `i, j`. `a[i]` is @@ -39,9 +39,10 @@ transformations. With `exact=false`, the pattern of nonzeros in `A` is approximated as `u * u'`, where `sum(u) * u[i] = nz[i]` is the number of nonzero in row `i`. This results in an -`O(n^2)` rather than `O(n^3)` algorithm. +`O(n^2)` rather than `O(n^3)` algorithm. `regularize=true` adds a small offset to the +diagonal (relevant only when `exact=true`), which handles all-zero rows of `A`. """ -function symscale(A::AbstractMatrix; exact::Bool=false) +function symscale(A::AbstractMatrix; exact::Bool=false, regularize::Bool=false) ax = axes(A, 1) axes(A, 2) == ax || throw(ArgumentError("symscale requires a square matrix")) sumlogA, nz = _symscale(A, ax) @@ -49,13 +50,17 @@ function symscale(A::AbstractMatrix; exact::Bool=false) if !exact || all(==(n), nz) # Sherman-Morrison formula for efficiency offset = sum(sumlogA) / (2 * sum(nz)) + divsafe!(sumlogA, nz) return exp.(sumlogA ./ nz .- offset) end - return exp.(cholesky(Diagonal(nz) + isnz(A)) \ sumlogA) + τ = regularize ? sqrt(eps(eltype(sumlogA))) : zero(eltype(sumlogA)) + W = isnz(A) + divsafe!(sumlogA, vec(sum(W; dims=2)); sentinel=-1/τ) + return exp.(cholesky(Diagonal(nz) + isnz(A) + τ * I) \ sumlogA) end """ - a, b = matrixscale(A; exact=false) + a, b = matrixscale(A; exact=false, regularize=false) Given a matrix `A`, return vectors `a` and `b` representing the "scale of each axis," so that `|A[i,j]| ~ a[i] * b[j]` for all `i, j`. `a[i]` and `b[j]` are @@ -75,7 +80,7 @@ With `exact=false`, the pattern of nonzeros in `A` is approximated as `u * v'`, where `sum(u) * v[j] = mA[j]` and `sum(v) * u[i] = nA[i]`. This results in an `O(m*n)` rather than `O((m+n)^3)` algorithm. """ -function matrixscale(A::AbstractMatrix; exact::Bool=false) +function matrixscale(A::AbstractMatrix; exact::Bool=false, regularize::Bool=false) Base.require_one_based_indexing(A) ax1, ax2 = axes(A, 1), axes(A, 2) (s, ns), (t, mt) = _matrixscale(A, ax1, ax2) @@ -83,13 +88,20 @@ function matrixscale(A::AbstractMatrix; exact::Bool=false) if !exact || (all(==(n), ns) && all(==(m), mt)) z = sum(ns) @assert sum(mt) == z "Inconsistent nonzero counts in rows and columns" - a = exp.(s ./ ns .- sum(s) / (2z)) - b = exp.(t ./ mt .- sum(t) / (2z)) + offsets, offsett = sum(s) / (2z), sum(t) / (2z) + divsafe!(s, ns) + divsafe!(t, mt) + a = exp.(s ./ ns .- offsets) + b = exp.(t ./ mt .- offsett) return a, b end p = vcat(ns, -mt) W = isnz(A) - a12 = exp.(cholesky(Diagonal(vcat(ns, mt)) + odblocks(W) + p * p') \ vcat(s, t)) + T = promote_type(eltype(s), eltype(t)) + τ = regularize ? sqrt(eps(T)) : zero(T) + divsafe!(s, vec(sum(W; dims=2)); sentinel=-1/τ) + divsafe!(t, vec(sum(W; dims=1)); sentinel=-1/τ) + a12 = exp.(cholesky(Diagonal(vcat(ns, mt)) + odblocks(W) + p * p' + τ * I) \ vcat(s, t)) return a12[begin:begin+m-1], a12[m+begin:end] end diff --git a/src/utils.jl b/src/utils.jl index 5db1e5e..e5e14dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -34,6 +34,16 @@ end isnz(A) = .!iszero.(A) +function divsafe!(sumlog, nz; sentinel=-Inf) + for i in eachindex(sumlog, nz) + if iszero(nz[i]) + nz[i] = 1 + sumlog[i] = sentinel + end + end + return sumlog, nz +end + function odblocks(Anz::AbstractMatrix{T}) where T m, n = size(Anz) return [zeros(T, m, m) Anz; Anz' zeros(T, n, n)] diff --git a/test/runtests.jl b/test/runtests.jl index 7f0d594..1e465ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,6 +42,9 @@ end @test symscale([2.0 1.0; 1.0 3.0]) ≈ symscale([2.0 1.0; 1.0 3.0]; exact=true) ≈ exp.([3 1; 1 3] \ [log(2.0); log(3.0)]) @test symscale([1.0 -0.2; -0.2 0]; exact=true) ≈ [1, 0.2] @test symscale([1.0 0; 0 2]; exact=true) ≈ [1, sqrt(2)] + @test symscale([1.0 0; 0 0]) ≈ [1, 0] + @test_throws PosDefException symscale([1.0 0; 0 0]; exact=true) + @test symscale([1.0 0; 0 0]; exact=true, regularize=true) ≈ [1, 0] test_scaleinv(A -> symscale(A; exact=true), [2.0 1.0; 1.0 3.0], 1) a, b = matrixscale([2.0 1.0; 1.0 3.0]; exact=true) @test a ≈ b ≈ symscale([2.0 1.0; 1.0 3.0]; exact=true) @@ -52,6 +55,13 @@ end test_sumlog(A, a, b) a′, b′ = matrixscale(A) @test sum(log, a) ≈ sum(log, b) ≈ sum(log, a′) ≈ sum(log, b′) + a, b = matrixscale([1.0 0; 0 0]) + @test a ≈ [1, 0] + @test b ≈ [1, 0] + @test_throws PosDefException matrixscale([1.0 0; 0 0]; exact=true) + a, b = matrixscale([1.0 0; 0 0]; exact=true, regularize=true) + @test a ≈ [1, 0] + @test b ≈ [1, 0] @test condscale([1 0; 0 1e-8]) ≈ 1 A = [1.0 -0.2; -0.2 0]