diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl index 8c739f3d9..2f0361740 100644 --- a/GNNlib/src/utils.jl +++ b/GNNlib/src/utils.jl @@ -90,6 +90,9 @@ function softmax_edge_neighbors(g::AbstractGNNGraph, e) @assert size(e)[end] == g.num_edges end s, t = edge_index(g) + if isempty(t) + return zeros_like(e, eltype(e), (size(e)[1:end-1]..., 0)) + end max_ = gather(scatter(max, e, t), t) num = exp.(e .- max_) den = gather(scatter(+, num, t), t)