
Conversation

@lenianiva (Contributor) commented Jan 18, 2026

Resolves #637

Fixes 3 problems:

  1. In gat_conv during backpropagation: "DimensionMismatch: arrays could not be broadcast to a common size". This is solved by removing the conditional used to compute Wxi and Wxj:

    Wxj = l.dense_x(xj)
    Wxj = reshape(Wxj, chout, heads, :)
    Wxi = l.dense_x(xi)
    Wxi = reshape(Wxi, chout, heads, :)

  2. The reshape(x, :, size(x, 3)) in gat_conv produces incompatible sizes when x is empty. This PR computes the target size from the first two axes of x instead.
  3. The empty-array problem: "DimensionMismatch: variable with size(x) == (1, 1, 0) cannot have a gradient with size(dx) == (4, 1, 0)". This is fixed by a patch to ChainRulesCore.jl ("fix: Allow arbitrary reshape in projection if array is zero sized", JuliaDiff/ChainRulesCore.jl#702).
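The fix for problem 2 can be sketched as follows (a minimal illustration, not the exact gat_conv code). When the node axis is empty, a Colon dimension is ambiguous, since 0 == c * 0 for any c; computing the flattened size explicitly from the first two axes is well-defined in every case:

```julia
# x imitates a (chout, heads, num_nodes) feature array with zero nodes.
chout, heads = 4, 2
x = rand(Float32, chout, heads, 0)

# reshape(x, :, size(x, 3)) cannot infer the colon dimension here,
# because any value satisfies 0 == c * 0.
# Deriving the size from the first two axes avoids the ambiguity:
y = reshape(x, size(x, 1) * size(x, 2), size(x, 3))
@assert size(y) == (8, 0)
```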

Test script:

using GNNGraphs, GraphNeuralNetworks, NNlib, Flux

graph = GNNHeteroGraph(
    Dict(
        (:A, :a, :B) => ([1, 2], [3, 4]),
        (:B, :a, :A) => ([1], [2]),
        (:C, :a, :A) => (Int[], Int[]),
        (:A, :a, :C) => (Int[], Int[]),
        (:D, :a, :A) => (Int[], Int[]),
        (:E, :a, :A) => (Int[], Int[]),
        (:E, :a, :D) => (Int[], Int[]),
        (:D, :a, :E) => (Int[], Int[]),
    );
    num_nodes = Dict(
        :A => 3, :B => 5,
        :C => 7, :D => 0, :E => 0,
    )
)

width = 9
layer = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(width => width, NNlib.elu; dropout = Float32(0.25), add_self_loops = true) for
        (src, edge, dst) in keys(graph.edata)
    ];
)
layer2 = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(width => width, NNlib.elu; dropout = Float32(0.25), add_self_loops = true) for
        (src, edge, dst) in keys(graph.edata)
    ];
)

x = (
    A = rand(Float32, width, 3),
    B = rand(Float32, width, 5),
    C = rand(Float32, width, 7),
    D = rand(Float32, width, 0),
    E = rand(Float32, width, 0),
)

x1 = layer(graph, x)
x2 = layer2(graph, x1)
@info "$x2"

g = Flux.gradient(x) do x
    y = layer(graph, x)
    sum(y[:A])
end

@lenianiva (Contributor, author) commented Jan 18, 2026

How would this work if bias is set to false?

Also, we need the same patch in gatv2. Simply adding a new conditional in gat_conv only solves the problem for inference; during training it still fails for some reason.

Comment on lines +127 to +130
Wxj = l.dense_x(xj)
Wxj = reshape(Wxj, chout, heads, :)
Wxi = l.dense_x(xi)
Wxi = reshape(Wxi, chout, heads, :)
Member commented:

Suggested change:

    Wxj = l.dense_x(xj)
    Wxj = reshape(Wxj, chout, heads, :)
    if xi !== xj
        Wxi = l.dense_x(xi)
        Wxi = reshape(Wxi, chout, heads, :)
    else
        Wxi = Wxj
    end

Would this work?

@lenianiva (Contributor, author) commented Jan 19, 2026

This doesn't work. The example above seems to trigger both branches, and Zygote confuses one branch for the other. I have seen this kind of behaviour from Zygote before.

@CarloLucibello (Member) commented:

Regarding the wrong gradient shape for empty arrays: do you have any clue why it is happening? Ideally ChainRules' projection shouldn't be patched; it should just receive a dx in the correct shape.

@lenianiva (Contributor, author) commented Jan 19, 2026

> Regarding the wrong gradient shape for empty arrays: do you have any clue why it is happening? Ideally ChainRules' projection shouldn't be patched; it should just receive a dx in the correct shape.

I have absolutely no idea. One hint: the error only occurs for zero-sized arrays. I think it has something to do with how the gradients are computed and is outside the scope of this package.
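For context, the reason the ChainRulesCore patch (JuliaDiff/ChainRulesCore.jl#702) is safe to apply: any zero-sized dx can be reshaped to the zero-sized primal shape, because both arrays contain zero elements. A minimal sketch of the zero-sized case, using the shapes from the error message above (this is an illustration of the idea, not ChainRulesCore's actual projection code):

```julia
x_size = (1, 1, 0)            # size of the primal variable x
dx = zeros(Float32, 4, 1, 0)  # mismatched gradient shape from the error

# Both arrays hold zero elements, so reshaping one into the other's
# shape is always valid, regardless of how the dimensions differ:
@assert length(dx) == prod(x_size)
dx_fixed = reshape(dx, x_size)
@assert size(dx_fixed) == (1, 1, 0)
```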



Development

Successfully merging this pull request may close these issues.

GATConv doesn't work on hetero graphs with empty edge arrays or during backpropagation
