Conversation
Pull request overview
This PR works around a PyTorch bug where torch.clamp handles NaN values inconsistently depending on whether they are passed as scalars or as tensors (pytorch/pytorch#172067). The fix adds special handling for NaN values when min and max are given as scalars.
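A minimal reproduction of the inconsistency described in the linked issue (the exact outputs depend on the PyTorch version; per the report, the two calls below are expected to agree but do not):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
nan = float("nan")

# Per pytorch/pytorch#172067, a NaN bound passed as a Python scalar is
# handled differently from the same NaN bound passed as a 0-d tensor.
print(torch.clamp(x, max=nan))                # NaN as a Python scalar
print(torch.clamp(x, max=torch.tensor(nan)))  # NaN as a 0-d tensor
```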
Key changes:
- Added `import math` for NaN detection
- Refactored the clip function logic to better handle mixed scalar/tensor signatures
- Added a special case to return NaN-filled tensors when NaN bounds are provided as scalars (see the sketch below)
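A minimal sketch of that special case, not the PR's exact code: the `Array` annotations and the surrounding refactor are elided, and `torch.full_like` assumes a floating-point input.

```python
import math

import torch


def clip(x, /, min=None, max=None):
    # Detect NaN bounds passed as Python scalars; tensor bounds are
    # left to torch.clamp itself.
    def is_scalar_nan(bound):
        return isinstance(bound, float) and math.isnan(bound)

    if is_scalar_nan(min) or is_scalar_nan(max):
        # A NaN bound poisons every element, so return a NaN-filled
        # tensor directly instead of relying on torch.clamp's
        # inconsistent scalar path.
        return torch.full_like(x, math.nan)
    return torch.clamp(x, min=min, max=max)
```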
```python
    min: int | float | Array | None = None,
    max: int | float | Array | None = None,
    **kwargs
) -> Array:
```
Removing `**kwargs` from the function signature breaks support for the `out=` parameter. The `test_clip_out` test in `tests/test_common.py` expects torch's `clip` to support `out=`, which `torch.clamp` does support. The `**kwargs` should be retained in the signature and passed to the `torch.clamp` calls at lines 865 and 877 to maintain backward compatibility (see the sketch below).
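A sketch of the suggested shape of the signature, with `**kwargs` retained and forwarded; the surrounding logic is elided, and `Array` comes from the module under review:

```python
def clip(
    x: Array,
    /,
    min: int | float | Array | None = None,
    max: int | float | Array | None = None,
    **kwargs,
) -> Array:
    ...
    # Forward kwargs (e.g. out=) so torch.clamp keeps supporting them,
    # as test_clip_out in tests/test_common.py expects.
    return torch.clamp(x, min=min, max=max, **kwargs)
```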
Merging as a follow-up to gh-353 to fix an edge case which surfaced in more thorough testing.
Work around pytorch/pytorch#172067:
Tested locally (yes, $10^5$ examples), and it passes.
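The actual test snippet is not shown above; a hypothetical property test in the same spirit, assuming hypothesis and the `array_api_compat.torch` namespace:

```python
import math

import torch
from hypothesis import given, settings, strategies as st

from array_api_compat.torch import clip  # assumed import path


@settings(max_examples=100_000)  # the 10^5 examples mentioned above
@given(st.floats(allow_nan=True, allow_infinity=True))
def test_clip_nan_scalar_bound(bound):
    x = torch.tensor([0.0, 1.0, 2.0])
    result = clip(x, max=bound)
    if math.isnan(bound):
        # A NaN scalar bound should yield a NaN-filled result.
        assert result.isnan().all()
```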