From 355d3520517768f29f60c30d4cb19a5edc71f58f Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Sat, 3 Jan 2026 16:31:08 -0800
Subject: [PATCH 1/2] fix: Set init in `softmax_edge_neighbors`
---
GNNlib/src/utils.jl | 3 +++
1 file changed, 3 insertions(+)
diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl
index 8c739f3d9..c6cf34177 100644
--- a/GNNlib/src/utils.jl
+++ b/GNNlib/src/utils.jl
@@ -90,6 +90,9 @@ function softmax_edge_neighbors(g::AbstractGNNGraph, e)
@assert size(e)[end] == g.num_edges
end
s, t = edge_index(g)
+ if isempty(t)
+ return zero(eltype(e))
+ end
max_ = gather(scatter(max, e, t), t)
num = exp.(e .- max_)
den = gather(scatter(+, num, t), t)
From e012247bce5694f76ff78faaeff56bc72b578baa Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Sun, 4 Jan 2026 10:40:52 -0800
Subject: [PATCH 2/2] fix: Dimension consistency
---
GNNlib/src/utils.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl
index c6cf34177..2f0361740 100644
--- a/GNNlib/src/utils.jl
+++ b/GNNlib/src/utils.jl
@@ -91,7 +91,7 @@ function softmax_edge_neighbors(g::AbstractGNNGraph, e)
end
s, t = edge_index(g)
if isempty(t)
- return zero(eltype(e))
+ return zeros_like(e, eltype(e), (size(e)[1:end-1]..., 0))
end
max_ = gather(scatter(max, e, t), t)
num = exp.(e .- max_)