Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Coloring/Coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ function acyclic_coloring(g::UndirectedGraph)
firstVisitToTree = fill(_Edge(0, 0, 0), _num_edges(g))
color = fill(0, _num_vertices(g))
# disjoint set forest of edges in the graph
S = DataStructures.IntDisjointSets(_num_edges(g))
S = DataStructures.IntDisjointSet{Int}(_num_edges(g))
@inbounds for v in 1:_num_vertices(g)
n_neighbor = _num_neighbors(v, g)
start_neighbor = _start_neighbors(v, g)
Expand Down
220 changes: 220 additions & 0 deletions src/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,83 @@ function _forward_eval(
tmp_dot += v1 * v2
end
@s f.forward_storage[k] = tmp_dot
elseif node.index == 12 # hcat
idx1, idx2 = children_indices
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
nb_cols1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
col_size = f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
for j in _eachindex(f.sizes, ix1)
@j f.partials_storage[ix1] = one(T)
val = @j f.forward_storage[ix1]
@j f.forward_storage[k] = val
end
for j in _eachindex(f.sizes, ix2)
@j f.partials_storage[ix2] = one(T)
val = @j f.forward_storage[ix2]
_setindex!(
f.forward_storage,
val,
f.sizes,
k,
j + nb_cols1 * col_size,
)
end
elseif node.index == 13 # vcat
idx1, idx2 = children_indices
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
nb_rows1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 1)
nb_rows2 = f.sizes.ndims[ix2] <= 1 ? 1 : _size(f.sizes, ix2, 1)
nb_rows = nb_rows1 + nb_rows2
for j in _eachindex(f.sizes, ix1)
@j f.partials_storage[ix1] = one(T)
val = @j f.forward_storage[ix1]
_setindex!(
f.forward_storage,
val,
f.sizes,
k,
div(j-1, nb_rows1) * nb_rows + 1 + (j-1) % nb_rows1,
)
end
for j in _eachindex(f.sizes, ix2)
@j f.partials_storage[ix2] = one(T)
val = @j f.forward_storage[ix2]
_setindex!(
f.forward_storage,
val,
f.sizes,
k,
div(j-1, nb_rows1) * nb_rows +
1 +
(j-1) % nb_rows1 +
nb_rows1,
)
end
elseif node.index == 14 # norm
ix = children_arr[children_indices[1]]
tmp_norm_squared = zero(T)
for j in _eachindex(f.sizes, ix)
v = @j f.forward_storage[ix]
tmp_norm_squared += v * v
end
@s f.forward_storage[k] = sqrt(tmp_norm_squared)
for j in _eachindex(f.sizes, ix)
v = @j f.forward_storage[ix]
if tmp_norm_squared == 0
@j f.partials_storage[ix] = zero(T)
else
@j f.partials_storage[ix] = v / @s f.forward_storage[k]
end
end
elseif node.index == 16 # row
for j in _eachindex(f.sizes, k)
ix = children_arr[children_indices[j]]
@s f.partials_storage[ix] = one(T)
val = @s f.forward_storage[ix]
@j f.forward_storage[k] = val
end
else # atan, min, max
f_input = _UnsafeVectorView(d.jac_storage, N)
∇f = _UnsafeVectorView(d.user_output_buffer, N)
Expand Down Expand Up @@ -380,6 +457,149 @@ function _reverse_eval(f::_SubexpressionStorage)
end
end
continue
elseif op == :hcat
idx1, idx2 = children_indices
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
nb_cols1 =
f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
col_size =
f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
for j in _eachindex(f.sizes, ix1)
partial = @j f.partials_storage[ix1]
val = ifelse(
_getindex(f.reverse_storage, f.sizes, k, j) ==
0.0 && !isfinite(partial),
_getindex(f.reverse_storage, f.sizes, k, j),
_getindex(f.reverse_storage, f.sizes, k, j) *
partial,
)
@j f.reverse_storage[ix1] = val
end
for j in _eachindex(f.sizes, ix2)
partial = @j f.partials_storage[ix2]
val = ifelse(
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
) == 0.0 && !isfinite(partial),
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
),
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
) * partial,
)
@j f.reverse_storage[ix2] = val
end
continue
elseif op == :vcat
idx1, idx2 = children_indices
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
nb_rows1 =
f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 1)
nb_rows2 =
f.sizes.ndims[ix2] <= 1 ? 1 : _size(f.sizes, ix2, 1)
nb_rows = nb_rows1 + nb_rows2
row_size =
f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 2)
for j in _eachindex(f.sizes, ix1)
partial = @j f.partials_storage[ix1]
val = ifelse(
_getindex(
f.reverse_storage,
f.sizes,
k,
div(j-1, nb_rows1) * nb_rows +
1 +
(j-1) % nb_rows1,
) == 0.0 && !isfinite(partial),
_getindex(
f.reverse_storage,
f.sizes,
k,
div(j-1, nb_rows1) * nb_rows +
1 +
(j-1) % nb_rows1,
),
_getindex(
f.reverse_storage,
f.sizes,
k,
div(j-1, nb_rows1) * nb_rows +
1 +
(j-1) % nb_rows1,
) * partial,
)
@j f.reverse_storage[ix1] = val
end
for j in _eachindex(f.sizes, ix2)
partial = @j f.partials_storage[ix2]
val = ifelse(
_getindex(
f.reverse_storage,
f.sizes,
k,
div(j-1, nb_rows1) * nb_rows +
1 +
(j-1) % nb_rows1 +
nb_rows1,
) == 0.0 && !isfinite(partial),
_getindex(
f.reverse_storage,
f.sizes,
k,
div(j-1, nb_rows1) * nb_rows +
1 +
(j-1) % nb_rows1 +
nb_rows1,
),
_getindex(
f.reverse_storage,
f.sizes,
k,
div(j-1, nb_rows1) * nb_rows +
1 +
(j-1) % nb_rows1 +
nb_rows1,
) * partial,
)
@j f.reverse_storage[ix2] = val
end
continue
elseif op == :norm
# Node `k` is scalar, the jacobian w.r.t. the vectorized input
# child is a row vector whose entries are stored in `f.partials_storage`
rev_parent = @s f.reverse_storage[k]
for j in
_eachindex(f.sizes, children_arr[children_indices[1]])
ix = children_arr[children_indices[1]]
partial = @j f.partials_storage[ix]
val = ifelse(
rev_parent == 0.0 && !isfinite(partial),
rev_parent,
rev_parent * partial,
)
@j f.reverse_storage[ix] = val
end
continue
elseif op == :row
for j in _eachindex(f.sizes, k)
ix = children_arr[children_indices[j]]
rev_parent_j = @j f.reverse_storage[k]
# partial is 1 so we can ignore it
@s f.reverse_storage[ix] = rev_parent_j
end
continue
end
end
elseif node.type != MOI.Nonlinear.NODE_CALL_UNIVARIATE
Expand Down
74 changes: 66 additions & 8 deletions src/sizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,59 @@ function _infer_sizes(
op,
)
_add_size!(sizes, k, (N,))
elseif op == :row
_assert_scalar_children(
sizes,
children_arr,
children_indices,
op,
)
_add_size!(sizes, k, (1, N))
elseif op == :dot
# TODO assert all arguments have same size
elseif op == :norm
# TODO actually norm should be moved to univariate
elseif op == :+ || op == :-
# TODO assert all arguments have same size
_copy_size!(sizes, k, children_arr[first(children_indices)])
elseif op == :hcat
total_cols = 0
for c_idx in children_indices
total_cols +=
sizes.ndims[children_arr[c_idx]] <= 1 ? 1 :
_size(sizes, children_arr[c_idx], 2)
end
if sizes.ndims[children_arr[first(children_indices)]] == 0
shape = (1, total_cols)
else
@assert sizes.ndims[children_arr[first(
children_indices,
)]] <= 2 "Hcat with ndims > 2 is not supported yet"
shape = (
_size(sizes, children_arr[first(children_indices)], 1),
total_cols,
)
end
_add_size!(sizes, k, tuple(shape...))
elseif op == :vcat
total_rows = 0
for c_idx in children_indices
total_rows +=
sizes.ndims[children_arr[c_idx]] <= 1 ? 1 :
_size(sizes, children_arr[c_idx], 1)
end
if sizes.ndims[children_arr[first(children_indices)]] == 0
shape = (total_rows, 1)
else
@assert sizes.ndims[children_arr[first(
children_indices,
)]] <= 2 "Hcat with ndims > 2 is not supported yet"
shape = (
total_rows,
_size(sizes, children_arr[first(children_indices)], 2),
)
end
_add_size!(sizes, k, tuple(shape...))
elseif op == :*
# TODO assert compatible sizes and all ndims should be 0 or 2
first_matrix = findfirst(children_indices) do i
Expand All @@ -193,14 +241,24 @@ function _infer_sizes(
last_matrix = findfirst(children_indices) do i
return !iszero(sizes.ndims[children_arr[i]])
end
_add_size!(
sizes,
k,
(
_size(sizes, first_matrix, 1),
_size(sizes, last_matrix, sizes.ndims[last_matrix]),
),
)
if sizes.ndims[last_matrix] == 0 ||
sizes.ndims[first_matrix] == 0
_add_size!(sizes, k, (1, 1))
continue
else
_add_size!(
sizes,
k,
(
_size(sizes, first_matrix, 1),
_size(
sizes,
last_matrix,
sizes.ndims[last_matrix],
),
),
)
end
end
elseif op == :^ || op == :/
@assert N == 2
Expand Down
Loading
Loading