@@ -59,7 +59,8 @@ def worker_fn():
5959 device = get_world_group ().device )
6060 tensor = torch .ones (16 , 1024 , 1024 ,
6161 dtype = torch .float32 ).cuda (pynccl_comm .rank )
62- tensor = pynccl_comm .all_reduce (tensor )
62+ with pynccl_comm .change_state (enable = True ):
63+ tensor = pynccl_comm .all_reduce (tensor )
6364 torch .cuda .synchronize ()
6465 assert torch .all (tensor == pynccl_comm .world_size ).cpu ().item ()
6566
@@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
8081 group = groups [0 ] if torch .distributed .get_rank () in [0 , 1 ] else groups [1 ]
8182 pynccl_comm = PyNcclCommunicator (group = group , device = device )
8283 tensor = torch .ones (16 , 1024 , 1024 , dtype = torch .float32 , device = device )
83- # two groups can communicate independently
84- if torch .distributed .get_rank () in [0 , 1 ]:
85- tensor = pynccl_comm .all_reduce (tensor )
86- tensor = pynccl_comm .all_reduce (tensor )
87- torch .cuda .synchronize ()
88- assert torch .all (tensor == 4 ).cpu ().item ()
89- else :
90- tensor = pynccl_comm .all_reduce (tensor )
91- torch .cuda .synchronize ()
92- assert torch .all (tensor == 2 ).cpu ().item ()
84+ with pynccl_comm .change_state (enable = True ):
85+ # two groups can communicate independently
86+ if torch .distributed .get_rank () in [0 , 1 ]:
87+ tensor = pynccl_comm .all_reduce (tensor )
88+ tensor = pynccl_comm .all_reduce (tensor )
89+ torch .cuda .synchronize ()
90+ assert torch .all (tensor == 4 ).cpu ().item ()
91+ else :
92+ tensor = pynccl_comm .all_reduce (tensor )
93+ torch .cuda .synchronize ()
94+ assert torch .all (tensor == 2 ).cpu ().item ()
9395
9496
9597@pytest .mark .skipif (torch .cuda .device_count () < 4 ,
@@ -135,7 +137,9 @@ def worker_fn_with_cudagraph():
135137 # run something in the default stream to initialize torch engine
136138 a = torch .ones ((4 , 4 ), device = f'cuda:{ pynccl_comm .rank } ' )
137139 torch .cuda .synchronize ()
138- with torch .cuda .graph (graph ):
140+ with torch .cuda .graph (
141+ graph , stream = pynccl_comm .stream ), pynccl_comm .change_state (
142+ enable = True ):
139143 a_out = pynccl_comm .all_reduce (a )
140144 torch .cuda .synchronize ()
141145 graph .replay ()
@@ -164,7 +168,8 @@ def all_gather_worker_fn():
164168 for r in range (world_size )
165169 ]).to (device )
166170
167- pynccl_comm .all_gather (result , tensor )
171+ with pynccl_comm .change_state (enable = True ):
172+ pynccl_comm .all_gather (result , tensor )
168173 torch .cuda .synchronize ()
169174 torch .testing .assert_close (result , expected , rtol = 1e-5 , atol = 1e-8 )
170175
@@ -201,7 +206,8 @@ def reduce_scatter_worker_fn():
201206 expected = sum (tensor [rank * scattered_size :(rank + 1 ) * scattered_size ]
202207 for tensor in all_tensors ).to (device )
203208
204- pynccl_comm .reduce_scatter (result , tensor )
209+ with pynccl_comm .change_state (enable = True ):
210+ pynccl_comm .reduce_scatter (result , tensor )
205211 torch .cuda .synchronize ()
206212 torch .testing .assert_close (result , expected , rtol = 1e-5 , atol = 1e-8 )
207213
@@ -228,13 +234,15 @@ def send_recv_worker_fn():
228234 else :
229235 tensor = torch .empty (16 , 1024 , 1024 ,
230236 dtype = torch .float32 ).cuda (pynccl_comm .rank )
231-
232- if pynccl_comm .rank == 0 :
233- pynccl_comm .send (tensor ,
234- dst = (pynccl_comm .rank + 1 ) % pynccl_comm .world_size )
235- else :
236- pynccl_comm .recv (tensor ,
237- src = (pynccl_comm .rank - 1 ) % pynccl_comm .world_size )
237+ with pynccl_comm .change_state (enable = True ):
238+ if pynccl_comm .rank == 0 :
239+ pynccl_comm .send (tensor ,
240+ dst = (pynccl_comm .rank + 1 ) %
241+ pynccl_comm .world_size )
242+ else :
243+ pynccl_comm .recv (tensor ,
244+ src = (pynccl_comm .rank - 1 ) %
245+ pynccl_comm .world_size )
238246 torch .cuda .synchronize ()
239247 assert torch .all (tensor == 1 ).cpu ().item ()
240248
@@ -265,12 +273,15 @@ def multiple_send_recv_worker_fn():
265273 1024 ,
266274 dtype = torch .float32 ,
267275 device = device )
268- if torch .distributed .get_rank () in [0 , 1 ]:
269- pynccl_comm .send (tensor ,
270- dst = (pynccl_comm .rank + 1 ) % pynccl_comm .world_size )
271- else :
272- pynccl_comm .recv (tensor ,
273- src = (pynccl_comm .rank - 1 ) % pynccl_comm .world_size )
276+ with pynccl_comm .change_state (enable = True ):
277+ if torch .distributed .get_rank () in [0 , 1 ]:
278+ pynccl_comm .send (tensor ,
279+ dst = (pynccl_comm .rank + 1 ) %
280+ pynccl_comm .world_size )
281+ else :
282+ pynccl_comm .recv (tensor ,
283+ src = (pynccl_comm .rank - 1 ) %
284+ pynccl_comm .world_size )
274285 torch .cuda .synchronize ()
275286 if torch .distributed .get_rank () in [0 , 2 ]:
276287 assert torch .all (tensor == 1 ).cpu ().item ()
0 commit comments