File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -433,29 +433,30 @@ def run(
433433 torch .cuda .synchronize ()
434434
435435 for c in self .connections :
436- flad_m = False
437- if A_Minus != None and ((isinstance (A_Minus , float )) or (c in A_Minus )):
438- if A_MD :
439- kwargs ["a_minus" ] = A_Minus [c ]
440- else :
441- kwargs ["a_minus" ] = A_Minus
442- flad_m = True
436+ with stream ():
437+ flad_m = False
438+ if A_Minus != None and ((isinstance (A_Minus , float )) or (c in A_Minus )):
439+ if A_MD :
440+ kwargs ["a_minus" ] = A_Minus [c ]
441+ else :
442+ kwargs ["a_minus" ] = A_Minus
443+ flad_m = True
443444
444- flad_p = False
445- if A_Plus != None and ((isinstance (A_Plus , float )) or (c in A_Plus )):
446- if A_PD :
447- kwargs ["a_plus" ] = A_Plus [c ]
448- else :
449- kwargs ["a_plus" ] = A_Plus
450- flad_p = True
451-
452- self .connections [c ].update (
453- mask = masks .get (c , None ), learning = self .learning , ** kwargs
454- )
455- if flad_m :
456- kwargs .pop ("a_minus" )
457- if flad_p :
458- kwargs .pop ("a_plus" )
445+ flad_p = False
446+ if A_Plus != None and ((isinstance (A_Plus , float )) or (c in A_Plus )):
447+ if A_PD :
448+ kwargs ["a_plus" ] = A_Plus [c ]
449+ else :
450+ kwargs ["a_plus" ] = A_Plus
451+ flad_p = True
452+
453+ self .connections [c ].update (
454+ mask = masks .get (c , None ), learning = self .learning , ** kwargs
455+ )
456+ if flad_m :
457+ kwargs .pop ("a_minus" )
458+ if flad_p :
459+ kwargs .pop ("a_plus" )
459460
460461 # # Get input to all layers.
461462 # current_inputs.update(self._get_inputs())
You can’t perform that action at this time.
0 commit comments