fix: Gate empty result in GATConv #638
base: master
Conversation
How would this work if …? Also, we need the same patch on gatv2. Simply adding a new conditional in …
```julia
Wxj = l.dense_x(xj)
Wxj = reshape(Wxj, chout, heads, :)
Wxi = l.dense_x(xi)
Wxi = reshape(Wxi, chout, heads, :)
```
Suggested change:

```julia
Wxj = l.dense_x(xj)
Wxj = reshape(Wxj, chout, heads, :)
if xi !== xj
    Wxi = l.dense_x(xi)
    Wxi = reshape(Wxi, chout, heads, :)
else
    Wxi = Wxj
end
```

would work?
This doesn't work. It seems the example above triggers both branches, and Zygote confuses one branch for the other. I have seen this kind of behaviour from Zygote before.
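A minimal sketch of the pattern being discussed, reduced to a hypothetical function `f` (not the actual GATConv code): the same function is differentiated once with aliased inputs and once with distinct ones, so both sides of the `!==` branch get exercised.

```julia
using Zygote

# Hypothetical reduction of the `xi !== xj` pattern from the suggestion above.
f(x, y) = begin
    Wx = 2 .* x
    Wy = y !== x ? 3 .* y : Wx   # branch on aliasing
    sum(Wx .+ Wy)
end

x = ones(3)
g_aliased  = Zygote.gradient(f, x, x)    # takes the `else` (aliased) branch
g_distinct = Zygote.gradient(f, x, 2x)   # takes the `if` (distinct) branch
```

Here `g_distinct` should be `([2, 2, 2], [3, 3, 3])`; whether the aliased call interacts badly with the distinct one is exactly the kind of branch confusion being reported.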
Regarding the wrong gradient shape for the empty array, do you have any clue why it is happening? Ideally ChainRules' projection shouldn't need to be patched; it should just receive a `dx` in the correct shape.
I have absolutely no idea. A hint is that the error only occurs for 0-sized arrays. I think it has something to do with how gradients are computed and is outside the scope of this package.
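A minimal sketch of the projection situation being discussed, assuming ChainRulesCore's `ProjectTo` API; the actual behavior depends on whether the JuliaDiff/ChainRulesCore.jl#702 patch is in the installed version.

```julia
using ChainRulesCore

x  = zeros(Float32, 1, 1, 0)   # primal: zero-sized array
dx = zeros(Float32, 4, 1, 0)   # cotangent with a different (also empty) shape

p = ProjectTo(x)
# Without the #702 patch, `p(dx)` throws the DimensionMismatch quoted above,
# even though both arrays contain zero elements; the patch relaxes the
# per-dimension size check for zero-sized arrays.
```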
Resolves #637
Fixes 3 problems:

1. In `gat_conv`, during backpropagation: "DimensionMismatch: arrays could not be broadcast to a common size". This could be solved by removing the conditional used to calculate `Wxi`, `Wxj`.
2. `reshape(x, :, size(x, 3))` in `gat_conv` creates incompatible sizes when `x` is empty. This PR determines the size using the first two axes of `x` instead.
3. `DimensionMismatch: variable with size(x) == (1, 1, 0) cannot have a gradient with size(dx) == (4, 1, 0)`. This could be fixed by a patch for ChainRulesCore.jl: "fix: Allow arbitrary reshape in projection if array is zero sized" (JuliaDiff/ChainRulesCore.jl#702).

Test script:
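As an illustration of the reshape fix in point 2 (a minimal sketch, not the PR's actual test script):

```julia
# Empty feature array: 4 channels, 3 heads, 0 nodes/edges.
x = rand(Float32, 4, 3, 0)

# The colon form `reshape(x, :, size(x, 3))` misbehaves when `x` is empty,
# as described above: with 0 remaining elements the colon dimension
# cannot be determined from the array length.

# Computing the leading size from the first two axes always works:
y = reshape(x, size(x, 1) * size(x, 2), size(x, 3))
size(y)  # (12, 0)
```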