Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions src/ScaleInvariantAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ end


"""
a = symscale(A; exact=false)
a = symscale(A; exact=false, regularize=false)

Given a matrix `A` assumed to be symmetric, return a vector `a` representing the
"scale of each axis," so that `|A[i,j]| ~ a[i] * a[j]` for all `i, j`. `a[i]` is
Expand All @@ -39,23 +39,28 @@ transformations.

With `exact=false`, the pattern of nonzeros in `A` is approximated as `u * u'`,
where `sum(u) * u[i] = nz[i]` is the number of nonzero in row `i`. This results in an
`O(n^2)` rather than `O(n^3)` algorithm.
`O(n^2)` rather than `O(n^3)` algorithm. `regularize=true` adds a small offset to the
diagonal (relevant only when `exact=true`), which handles all-zero rows of `A`.
"""
function symscale(A::AbstractMatrix; exact::Bool=false)
function symscale(A::AbstractMatrix; exact::Bool=false, regularize::Bool=false)
ax = axes(A, 1)
axes(A, 2) == ax || throw(ArgumentError("symscale requires a square matrix"))
sumlogA, nz = _symscale(A, ax)
n = length(ax)
if !exact || all(==(n), nz)
# Sherman-Morrison formula for efficiency
offset = sum(sumlogA) / (2 * sum(nz))
divsafe!(sumlogA, nz)
return exp.(sumlogA ./ nz .- offset)
end
return exp.(cholesky(Diagonal(nz) + isnz(A)) \ sumlogA)
τ = regularize ? sqrt(eps(eltype(sumlogA))) : zero(eltype(sumlogA))
W = isnz(A)
divsafe!(sumlogA, vec(sum(W; dims=2)); sentinel=-1/τ)
return exp.(cholesky(Diagonal(nz) + isnz(A) + τ * I) \ sumlogA)
end

"""
a, b = matrixscale(A; exact=false)
a, b = matrixscale(A; exact=false, regularize=false)

Given a matrix `A`, return vectors `a` and `b` representing the "scale of each
axis," so that `|A[i,j]| ~ a[i] * b[j]` for all `i, j`. `a[i]` and `b[j]` are
Expand All @@ -75,21 +80,28 @@ With `exact=false`, the pattern of nonzeros in `A` is approximated as `u * v'`,
where `sum(u) * v[j] = mA[j]` and `sum(v) * u[i] = nA[i]`. This results in an
`O(m*n)` rather than `O((m+n)^3)` algorithm.
"""
function matrixscale(A::AbstractMatrix; exact::Bool=false)
function matrixscale(A::AbstractMatrix; exact::Bool=false, regularize::Bool=false)
Base.require_one_based_indexing(A)
ax1, ax2 = axes(A, 1), axes(A, 2)
(s, ns), (t, mt) = _matrixscale(A, ax1, ax2)
m, n = length(ax1), length(ax2)
if !exact || (all(==(n), ns) && all(==(m), mt))
z = sum(ns)
@assert sum(mt) == z "Inconsistent nonzero counts in rows and columns"
a = exp.(s ./ ns .- sum(s) / (2z))
b = exp.(t ./ mt .- sum(t) / (2z))
offsets, offsett = sum(s) / (2z), sum(t) / (2z)
divsafe!(s, ns)
divsafe!(t, mt)
a = exp.(s ./ ns .- offsets)
b = exp.(t ./ mt .- offsett)
return a, b
end
p = vcat(ns, -mt)
W = isnz(A)
a12 = exp.(cholesky(Diagonal(vcat(ns, mt)) + odblocks(W) + p * p') \ vcat(s, t))
T = promote_type(eltype(s), eltype(t))
τ = regularize ? sqrt(eps(T)) : zero(T)
divsafe!(s, vec(sum(W; dims=2)); sentinel=-1/τ)
divsafe!(t, vec(sum(W; dims=1)); sentinel=-1/τ)
a12 = exp.(cholesky(Diagonal(vcat(ns, mt)) + odblocks(W) + p * p' + τ * I) \ vcat(s, t))
return a12[begin:begin+m-1], a12[m+begin:end]
end

Expand Down
10 changes: 10 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ end

isnz(A) = .!iszero.(A)

function divsafe!(sumlog, nz; sentinel=-Inf)
for i in eachindex(sumlog, nz)
if iszero(nz[i])
nz[i] = 1
sumlog[i] = sentinel
end
end
return sumlog, nz
end

function odblocks(Anz::AbstractMatrix{T}) where T
m, n = size(Anz)
return [zeros(T, m, m) Anz; Anz' zeros(T, n, n)]
Expand Down
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ end
@test symscale([2.0 1.0; 1.0 3.0]) ≈ symscale([2.0 1.0; 1.0 3.0]; exact=true) ≈ exp.([3 1; 1 3] \ [log(2.0); log(3.0)])
@test symscale([1.0 -0.2; -0.2 0]; exact=true) ≈ [1, 0.2]
@test symscale([1.0 0; 0 2]; exact=true) ≈ [1, sqrt(2)]
@test symscale([1.0 0; 0 0]) ≈ [1, 0]
@test_throws PosDefException symscale([1.0 0; 0 0]; exact=true)
@test symscale([1.0 0; 0 0]; exact=true, regularize=true) ≈ [1, 0]
test_scaleinv(A -> symscale(A; exact=true), [2.0 1.0; 1.0 3.0], 1)
a, b = matrixscale([2.0 1.0; 1.0 3.0]; exact=true)
@test a ≈ b ≈ symscale([2.0 1.0; 1.0 3.0]; exact=true)
Expand All @@ -52,6 +55,13 @@ end
test_sumlog(A, a, b)
a′, b′ = matrixscale(A)
@test sum(log, a) ≈ sum(log, b) ≈ sum(log, a′) ≈ sum(log, b′)
a, b = matrixscale([1.0 0; 0 0])
@test a ≈ [1, 0]
@test b ≈ [1, 0]
@test_throws PosDefException matrixscale([1.0 0; 0 0]; exact=true)
a, b = matrixscale([1.0 0; 0 0]; exact=true, regularize=true)
@test a ≈ [1, 0]
@test b ≈ [1, 0]

@test condscale([1 0; 0 1e-8]) ≈ 1
A = [1.0 -0.2; -0.2 0]
Expand Down