From 355d3520517768f29f60c30d4cb19a5edc71f58f Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sat, 3 Jan 2026 16:31:08 -0800 Subject: [PATCH 1/2] fix: Set init in `softmax_edge_neighbors` --- GNNlib/src/utils.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl index 8c739f3d9..c6cf34177 100644 --- a/GNNlib/src/utils.jl +++ b/GNNlib/src/utils.jl @@ -90,6 +90,9 @@ function softmax_edge_neighbors(g::AbstractGNNGraph, e) @assert size(e)[end] == g.num_edges end s, t = edge_index(g) + if isempty(t) + return zero(eltype(e)) + end max_ = gather(scatter(max, e, t), t) num = exp.(e .- max_) den = gather(scatter(+, num, t), t) From e012247bce5694f76ff78faaeff56bc72b578baa Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sun, 4 Jan 2026 10:40:52 -0800 Subject: [PATCH 2/2] fix: Dimension consistency --- GNNlib/src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl index c6cf34177..2f0361740 100644 --- a/GNNlib/src/utils.jl +++ b/GNNlib/src/utils.jl @@ -91,7 +91,7 @@ function softmax_edge_neighbors(g::AbstractGNNGraph, e) end s, t = edge_index(g) if isempty(t) - return zero(eltype(e)) + return zeros_like(e, eltype(e), (size(e)[1:end-1]..., 0)) end max_ = gather(scatter(max, e, t), t) num = exp.(e .- max_)