@@ -123,44 +123,45 @@ def test_unsupported_op_gets_complexify_wrap() -> None:
123123 nodes_by_target .setdefault (n .target , []).append (n )
124124
125125 # view_as_complex must be present (inserted by the fallback wrapper)
126- assert torch . ops . aten . view_as_complex . default in nodes_by_target , (
127- "Expected view_as_complex to be inserted before cumsum, but it was not found"
128- )
126+ assert (
127+ torch . ops . aten . view_as_complex . default in nodes_by_target
128+ ), "Expected view_as_complex to be inserted before cumsum, but it was not found"
129129
130130 # cumsum must still be present (it was NOT removed)
131- assert torch . ops . aten . cumsum . default in nodes_by_target , (
132- " cumsum should remain in the graph (runs as PyTorch fallback)"
133- )
131+ assert (
132+ torch . ops . aten . cumsum . default in nodes_by_target
133+ ), "cumsum should remain in the graph (runs as PyTorch fallback)"
134134
135135 # The view_as_complex output feeds directly into cumsum
136136 vc_node = nodes_by_target [torch .ops .aten .view_as_complex .default ][0 ]
137137 cumsum_node = nodes_by_target [torch .ops .aten .cumsum .default ][0 ]
138- assert cumsum_node . args [ 0 ] is vc_node , (
139- f"cumsum's first arg should be the view_as_complex node, got { cumsum_node .args [0 ]} "
140- )
138+ assert (
139+ cumsum_node .args [0 ] is vc_node
140+ ), f"cumsum's first arg should be the view_as_complex node, got { cumsum_node . args [ 0 ] } "
141141
142142 # The view_as_complex input is a real-layout (is_complex_layout) node
143143 vc_input = vc_node .args [0 ]
144144 assert isinstance (vc_input , torch .fx .Node ), "view_as_complex input must be a Node"
145- assert vc_input .meta .get ("is_complex_layout" , False ), (
146- "view_as_complex input should be a real-layout complex node ( is_complex_layout=True)"
147- )
145+ assert vc_input .meta .get (
146+ "is_complex_layout" , False
147+ ), "view_as_complex input should be a real-layout complex node (is_complex_layout=True)"
148148
149149 # view_as_real must follow cumsum
150- assert torch . ops . aten . view_as_real . default in nodes_by_target , (
151- "Expected view_as_real to be inserted after cumsum, but it was not found"
152- )
150+ assert (
151+ torch . ops . aten . view_as_real . default in nodes_by_target
152+ ), "Expected view_as_real to be inserted after cumsum, but it was not found"
153153 vr_node = nodes_by_target [torch .ops .aten .view_as_real .default ][0 ]
154- assert vr_node . args [ 0 ] is cumsum_node , (
155- f"view_as_real's arg should be the cumsum node, got { vr_node .args [0 ]} "
156- )
154+ assert (
155+ vr_node .args [0 ] is cumsum_node
156+ ), f"view_as_real's arg should be the cumsum node, got { vr_node . args [ 0 ] } "
157157
158158 # After metadata propagation, cumsum receives a complex-dtype tensor
159159 vc_val = vc_node .meta .get ("val" )
160160 if vc_val is not None :
161- assert vc_val .dtype in (torch .complex64 , torch .complex128 ), (
162- f"view_as_complex output should be complex, got { vc_val .dtype } "
163- )
161+ assert vc_val .dtype in (
162+ torch .complex64 ,
163+ torch .complex128 ,
164+ ), f"view_as_complex output should be complex, got { vc_val .dtype } "
164165
165166
166167# ===========================================================================
@@ -221,9 +222,10 @@ def test_complex_partial_lowering_with_graph_break() -> None:
221222 if n .target == torch .ops .aten .cumsum .default :
222223 vc_val = n .args [0 ].meta .get ("val" )
223224 if vc_val is not None :
224- assert vc_val .dtype in (torch .complex64 , torch .complex128 ), (
225- f"cumsum should receive a complex tensor, got { vc_val .dtype } "
226- )
225+ assert vc_val .dtype in (
226+ torch .complex64 ,
227+ torch .complex128 ,
228+ ), f"cumsum should receive a complex tensor, got { vc_val .dtype } "
227229 break
228230
229231 # End-to-end: compile and verify numerical correctness
0 commit comments