Skip to content

Commit 00d0389

Browse files
committed
chore: addressing PR AIs
1 parent e16b612 commit 00d0389

7 files changed

Lines changed: 38 additions & 33 deletions

File tree

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ jobs:
141141
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
142142
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_*
143143
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
144-
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/
145144
popd
146145
147146
L0-py-core-tests:
@@ -237,6 +236,7 @@ jobs:
237236
cd tests/py/dynamo
238237
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
239238
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_*
239+
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/
240240
241241
popd
242242

.github/workflows/build-test-linux-x86_64_rtx.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ jobs:
142142
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
143143
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
144144
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
145-
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/
146145
popd
147146
148147
L0-py-core-tests:
@@ -205,6 +204,7 @@ jobs:
205204
pushd .
206205
cd tests/py/dynamo
207206
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
207+
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/
208208
popd
209209
210210
L1-dynamo-compile-tests:

.github/workflows/build-test-windows.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ jobs:
140140
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
141141
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_*
142142
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
143-
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/
144143
popd
145144
146145
L0-py-core-tests:
@@ -227,6 +226,7 @@ jobs:
227226
cd tests/py/dynamo
228227
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
229228
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_*
229+
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/
230230
popd
231231
232232
L1-dynamo-compile-tests:

.github/workflows/build-test-windows_rtx.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ jobs:
144144
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
145145
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
146146
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
147-
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/
148147
popd
149148
150149
L0-py-core-tests:
@@ -201,6 +200,7 @@ jobs:
201200
pushd .
202201
cd tests/py/dynamo
203202
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
203+
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/
204204
popd
205205
206206
L1-dynamo-compile-tests:

py/torch_tensorrt/dynamo/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -838,9 +838,9 @@ def copy_metadata(match_and_replacements: List[Any]) -> None:
838838
"""
839839
for match_and_replacement in match_and_replacements:
840840
anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor]
841-
assert len(match_and_replacement.replacements) == 1, (
842-
"Found more than 1 replacements for the anchor node."
843-
)
841+
assert (
842+
len(match_and_replacement.replacements) == 1
843+
), "Found more than 1 replacements for the anchor node."
844844
replacement_node = match_and_replacement.replacements[0]
845845
replacement_node.meta = anchor_node.meta
846846

tests/py/dynamo/hlo/test_complex_graph_break.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -123,44 +123,45 @@ def test_unsupported_op_gets_complexify_wrap() -> None:
123123
nodes_by_target.setdefault(n.target, []).append(n)
124124

125125
# view_as_complex must be present (inserted by the fallback wrapper)
126-
assert torch.ops.aten.view_as_complex.default in nodes_by_target, (
127-
"Expected view_as_complex to be inserted before cumsum, but it was not found"
128-
)
126+
assert (
127+
torch.ops.aten.view_as_complex.default in nodes_by_target
128+
), "Expected view_as_complex to be inserted before cumsum, but it was not found"
129129

130130
# cumsum must still be present (it was NOT removed)
131-
assert torch.ops.aten.cumsum.default in nodes_by_target, (
132-
"cumsum should remain in the graph (runs as PyTorch fallback)"
133-
)
131+
assert (
132+
torch.ops.aten.cumsum.default in nodes_by_target
133+
), "cumsum should remain in the graph (runs as PyTorch fallback)"
134134

135135
# The view_as_complex output feeds directly into cumsum
136136
vc_node = nodes_by_target[torch.ops.aten.view_as_complex.default][0]
137137
cumsum_node = nodes_by_target[torch.ops.aten.cumsum.default][0]
138-
assert cumsum_node.args[0] is vc_node, (
139-
f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}"
140-
)
138+
assert (
139+
cumsum_node.args[0] is vc_node
140+
), f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}"
141141

142142
# The view_as_complex input is a real-layout (is_complex_layout) node
143143
vc_input = vc_node.args[0]
144144
assert isinstance(vc_input, torch.fx.Node), "view_as_complex input must be a Node"
145-
assert vc_input.meta.get("is_complex_layout", False), (
146-
"view_as_complex input should be a real-layout complex node (is_complex_layout=True)"
147-
)
145+
assert vc_input.meta.get(
146+
"is_complex_layout", False
147+
), "view_as_complex input should be a real-layout complex node (is_complex_layout=True)"
148148

149149
# view_as_real must follow cumsum
150-
assert torch.ops.aten.view_as_real.default in nodes_by_target, (
151-
"Expected view_as_real to be inserted after cumsum, but it was not found"
152-
)
150+
assert (
151+
torch.ops.aten.view_as_real.default in nodes_by_target
152+
), "Expected view_as_real to be inserted after cumsum, but it was not found"
153153
vr_node = nodes_by_target[torch.ops.aten.view_as_real.default][0]
154-
assert vr_node.args[0] is cumsum_node, (
155-
f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}"
156-
)
154+
assert (
155+
vr_node.args[0] is cumsum_node
156+
), f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}"
157157

158158
# After metadata propagation, cumsum receives a complex-dtype tensor
159159
vc_val = vc_node.meta.get("val")
160160
if vc_val is not None:
161-
assert vc_val.dtype in (torch.complex64, torch.complex128), (
162-
f"view_as_complex output should be complex, got {vc_val.dtype}"
163-
)
161+
assert vc_val.dtype in (
162+
torch.complex64,
163+
torch.complex128,
164+
), f"view_as_complex output should be complex, got {vc_val.dtype}"
164165

165166

166167
# ===========================================================================
@@ -221,9 +222,10 @@ def test_complex_partial_lowering_with_graph_break() -> None:
221222
if n.target == torch.ops.aten.cumsum.default:
222223
vc_val = n.args[0].meta.get("val")
223224
if vc_val is not None:
224-
assert vc_val.dtype in (torch.complex64, torch.complex128), (
225-
f"cumsum should receive a complex tensor, got {vc_val.dtype}"
226-
)
225+
assert vc_val.dtype in (
226+
torch.complex64,
227+
torch.complex128,
228+
), f"cumsum should receive a complex tensor, got {vc_val.dtype}"
227229
break
228230

229231
# End-to-end: compile and verify numerical correctness

tests/py/dynamo/lowering/test_complex_rewrite.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,11 @@ def _export_and_lower(
6262
def _real_inputs(inputs: Tuple[Any, ...]) -> Tuple[Any, ...]:
6363
"""Convert complex tensors to [..., 2] real layout."""
6464
return tuple(
65-
torch.view_as_real(x).contiguous() if isinstance(x, torch.Tensor) and x.is_complex()
66-
else x
65+
(
66+
torch.view_as_real(x).contiguous()
67+
if isinstance(x, torch.Tensor) and x.is_complex()
68+
else x
69+
)
6770
for x in inputs
6871
)
6972

0 commit comments

Comments
 (0)