From e90e515a8e013d8cc24a09e0a02d2aca690f8058 Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Wed, 24 Dec 2025 01:24:36 -0800
Subject: [PATCH 1/2] feat: Reduce nodes over heterogeneous graphs
---
GNNlib/src/utils.jl | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl
index 8c739f3d9..f281b297d 100644
--- a/GNNlib/src/utils.jl
+++ b/GNNlib/src/utils.jl
@@ -27,6 +27,17 @@ function reduce_nodes(aggr, indicator::AbstractVector, x)
return NNlib.scatter(aggr, x, indicator)
end
+"""
+ reduce_nodes(aggr, node_type, g, x)
+
+Return the graph-wise aggregation of the node features `x` on type `node_type`
+given a heterogeneous graph `g`. The aggregation operator `aggr` can be `+`,
+`mean`, `max`, or `min`.
+"""
+function reduce_nodes(aggr, node_type, g::GNNHeteroGraph, x)
+ return NNlib.scatter(aggr, x[node_type], graph_indicator(g, node_type))
+end
+
"""
reduce_edges(aggr, g, e)
From 81b9e57336dc4e45f075d92bc465973d9a6fcdab Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Sat, 3 Jan 2026 17:03:29 -0800
Subject: [PATCH 2/2] test: Add test for reduce
---
GNNlib/test/utils.jl | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/GNNlib/test/utils.jl b/GNNlib/test/utils.jl
index bf06f86fd..784329db9 100644
--- a/GNNlib/test/utils.jl
+++ b/GNNlib/test/utils.jl
@@ -19,6 +19,17 @@
@test r2 == r
end
+ @testset "reduce_nodes" begin
+ g = rand_bipartite_heterograph((5, 10), 20)
+ x = (
+ A = [Float32(i) for j = 1:1, i = 1:g.num_nodes[:A]],
+ B = [Float32(0) for j = 1:2, _ = 1:g.num_nodes[:B]],
+ )
+ expected = sum(i for i = 1:g.num_nodes[:A])
+ result = reduce_nodes(+, :A, g, x)
+ @test result == [expected;;]
+ end
+
@testset "reduce_edges" begin
r = reduce_edges(mean, g, e)
@test size(r) == (De, g.num_graphs)