@@ -25,7 +25,7 @@ def __init__(self, config=None):
2525 self .model_runnable_predicator = self ._make_model_runnable_predicator (
2626 self .config
2727 )
28- self .num_successful_handled_models = 0
28+ self .num_handled_models = 0
2929
3030 def _make_data_input_predicator (self , config ):
3131 module = load_module (config ["data_input_predicator_filepath" ])
@@ -51,7 +51,7 @@ def _make_config(
5151 model_path_prefix = "" ,
5252 resume = False ,
5353 last_model_log_file = None ,
54- limits_successfully_handled_models = None ,
54+ limits_handled_models = None ,
5555 ):
5656 if data_input_predicator_config is None :
5757 data_input_predicator_config = {}
@@ -72,7 +72,7 @@ def _make_config(
7272 "dimension_generalizer_class_name" : dimension_generalizer_class_name ,
7373 "dimension_generalizer_config" : dimension_generalizer_config ,
7474 "last_model_log_file" : last_model_log_file ,
75- "limits_successfully_handled_models " : limits_successfully_handled_models ,
75+ "limits_handled_models " : limits_handled_models ,
7676 }
7777
7878 def __call__ (self , model_path ):
@@ -125,16 +125,15 @@ def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
125125 )
126126 self ._save_dyn_dim_cstr (dyn_dim_cstr , model_path )
127127 self ._save_dim_gen_pass_names (dim_gen_pass_names , model_path )
128- if len (dyn_dim_cstr .symbols ) > 0 :
129- self .num_successful_handled_models += 1
130- limits = self .config ["limits_successfully_handled_models" ]
131- if limits is not None :
132- if self .num_successful_handled_models > limits :
133- print (
134- "`num_successful_handled_models` exceeds config `limits_successfully_handled_models`" ,
135- file = sys .stderr ,
136- )
137- sys .exit (0 )
128+ self .num_handled_models += 1
129+ limits = self .config ["limits_handled_models" ]
130+ if limits is not None :
131+ if self .num_handled_models >= limits :
132+ print (
133+ "`num_handled_models` exceeds config `limits_handled_models`" ,
134+ file = sys .stderr ,
135+ )
136+ sys .exit (0 )
138137
139138 def get_dimension_generalizer (self ):
140139 if hasattr (self , "_dim_generalizer" ):
@@ -159,6 +158,7 @@ def get_model(self, model_path):
159158 def _try_dimension_generalization (self , dim_axes_pairs , model_path , inputs ):
160159 logging .warning ("enter _try_dimension_generalization" )
161160 if self .config ["dimension_generalizer_filepath" ] is None :
161+ self ._save_model_to_log_file (model_path )
162162 yield model_path , ()
163163 return
164164 model = self .get_model (model_path )
@@ -168,6 +168,7 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
168168 need_rewrite = dim_gen_pass .need_rewrite (inputs )
169169 logging .warning ("after need_rewrite" )
170170 if not need_rewrite :
171+ self ._save_model_to_log_file (model_path )
171172 yield model_path , ()
172173 return
173174
@@ -177,11 +178,14 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
177178 with tempfile .TemporaryDirectory () as tmp_dir :
178179 shutil .copytree (Path (model_path ), Path (tmp_dir ), dirs_exist_ok = True )
179180 dim_gen_pass .save_graph_module (graph_module , tmp_dir )
180- if self .config ["last_model_log_file" ] is not None :
181- log_file = Path (self .config ["last_model_log_file" ])
182- shutil .copy (Path (tmp_dir ) / "model.py" , log_file )
181+ self ._save_model_to_log_file (tmp_dir )
183182 yield tmp_dir , dim_gen_pass .get_pass_names ()
184183
184+ def _save_model_to_log_file (self , model_path ):
185+ if self .config ["last_model_log_file" ] is not None :
186+ log_file = Path (self .config ["last_model_log_file" ])
187+ shutil .copy (Path (model_path ) / "model.py" , log_file )
188+
185189 def _save_dim_gen_pass_names (self , dim_gen_pass_names , model_path ):
186190 from graph_net .graph_net_json_file_util import kDimensionGeneralizationPasses
187191
@@ -324,7 +328,7 @@ def append_dim_gen_pass_names(dim_gen_pass_names):
324328 )
325329
326330 for i , picked_dim in enumerate (unique_dims ):
327- logging .warning (f"{ i = } { picked_dim = } " )
331+ logging .warning (f"{ i = } { picked_dim = } { dim2axes [ picked_dim ] = } " )
328332 cur_dyn_dim_cstr = copy .deepcopy (dyn_dim_cstr )
329333
330334 def filter_fn (input_name , input_idx , axis , dim ):
0 commit comments