Pytorch triple backward#200
Conversation
|
I'm going to run the full test suite tomorrow. |
vbharadwaj-bk
left a comment
There was a problem hiding this comment.
95% looks good. I think the stream_test.py modifications are redundant, since each of the custom ops have already been tested. But otherwise good.
| ): | ||
| assert self.torch_op | ||
|
|
||
| in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) |
There was a problem hiding this comment.
TODO-someday: I wonder if we can combine all of these derivative functions into one to compact this file.
| @pytest.fixture(scope="class") | ||
| def problem(self, dtype, with_jax): | ||
| if with_jax: | ||
| pytest.skip("N/A for JAX") |
There was a problem hiding this comment.
TODO-someday: we could expand this test to include JAX. But not in this commit.
| return (X, Y, W, edge_index[0], edge_index[1]) | ||
|
|
||
|
|
||
| @pytest.fixture |
There was a problem hiding this comment.
Hmm. I don't think we should have any modifications to stream_test.py. Because triple_backward is a composition of existing ops that all work fine with streams, I see no reason why their composition shouldn't pass stream tests. We need to test anything that's implemented as a custom op to make sure that the stream information is lowered correctly onto the kernel, but then any composition of those operators should be ok. Let's shrink the diff here.
| self.check_result(result, fieldname) | ||
|
|
||
|
|
||
| class TestTripleBackwardConvDirectOps: |
There was a problem hiding this comment.
Remind me, what is the purpose of these DirectOps tests?
Adding triple backward support for higher order training to pytorch