From e90e515a8e013d8cc24a09e0a02d2aca690f8058 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Wed, 24 Dec 2025 01:24:36 -0800 Subject: [PATCH 1/2] feat: Reduce nodes over heterogeneous graphs --- GNNlib/src/utils.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl index 8c739f3d9..f281b297d 100644 --- a/GNNlib/src/utils.jl +++ b/GNNlib/src/utils.jl @@ -27,6 +27,17 @@ function reduce_nodes(aggr, indicator::AbstractVector, x) return NNlib.scatter(aggr, x, indicator) end +""" + reduce_nodes(aggr, node_type, g, x) + +Return the graph-wise aggregation of the node features `x` on type `node_type` +given a heterogeneous graph `g`. The aggregation operator `aggr` can be `+`, +`mean`, `max`, or `min`. +""" +function reduce_nodes(aggr, node_type, g::GNNHeteroGraph, x) + return NNlib.scatter(aggr, x[node_type], graph_indicator(g, node_type)) +end + """ reduce_edges(aggr, g, e) From 81b9e57336dc4e45f075d92bc465973d9a6fcdab Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sat, 3 Jan 2026 17:03:29 -0800 Subject: [PATCH 2/2] test: Add test for reduce --- GNNlib/test/utils.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/GNNlib/test/utils.jl b/GNNlib/test/utils.jl index bf06f86fd..784329db9 100644 --- a/GNNlib/test/utils.jl +++ b/GNNlib/test/utils.jl @@ -19,6 +19,17 @@ @test r2 == r end + @testset "reduce_nodes" begin + g = rand_bipartite_heterograph((5, 10), 20) + x = ( + A = [Float32(i) for j = 1:1, i = 1:g.num_nodes[:A]], + B = [Float32(0) for j = 1:2, _ = 1:g.num_nodes[:B]], + ) + expected = sum(i for i = 1:g.num_nodes[:A]) + result = reduce_nodes(+, :A, g, x) + @test result == [expected;;] + end + @testset "reduce_edges" begin r = reduce_edges(mean, g, e) @test size(r) == (De, g.num_graphs)