55import cantera as ct
66
77device_main = "cuda:0"
8- device_list = range (torch .cuda .device_count ())
8+ device_list = [ 0 ] # range(torch.cuda.device_count())
99
1010torch .set_printoptions (precision = 10 )
1111
@@ -95,16 +95,18 @@ def forward(self, x):
9595 Ymu2 = Ymu0
9696 Ystd2 = Ystd0
9797
98+
99+ """
98100 #load model
99101 layers = [n_species +2, 1600, 800, 400, 1]
100-
102+
101103 model0list = []
102104 for i in range(n_species-1):
103105 model0list.append(NN_MLP(layers))
104106
105107 for i in range(n_species-1):
106108 model0list[i].load_state_dict(state_dict[f'net{i}'])
107-
109+
108110
109111 for i in range(n_species-1):
110112 model0list[i].eval()
@@ -113,7 +115,23 @@ def forward(self, x):
113115 if len(device_ids) > 1:
114116 for i in range(n_species-1):
115117 model0list[i] = torch.nn.DataParallel(model0list[i], device_ids=device_ids)
118+ """
119+
120+ #load model
121+ layers = [2 + n_species ]+ [400 ]* 4 + [ n_species - 1 ]
122+ # layers = [2+n_species]+[800,400,200,100]+[n_species-1]
123+
124+
125+ model = NN_MLP (layers )
126+
127+ model .load_state_dict (state_dict ['net' ])
128+
129+ model .eval ()
130+ model .to (device = device )
116131
132+ if len (device_ids ) > 1 :
133+ model = torch .nn .DataParallel (model , device_ids = device_ids )
134+
117135except Exception as e :
118136 print (e .args )
119137
@@ -126,6 +144,8 @@ def inference(vec0):
126144 '''
127145 vec0 = np .abs (np .reshape (vec0 , (- 1 , 3 + n_species ))) # T, P, Yi(7), Rho
128146 vec0 [:,1 ] *= 101325
147+ # vec0[:,1] *= 0
148+ # vec0[:,1] += 101325
129149 mask = vec0 [:,0 ] > frozenTemperature
130150 vec0_input = vec0 [mask , :]
131151 print (f'real inference points number: { vec0_input .shape [0 ]} ' )
@@ -148,9 +168,11 @@ def inference(vec0):
148168 #inference
149169
150170 output0_normalized = []
151- for i in range (n_species - 1 ):
152- output0_normalized .append (model0list [i ](input0_normalized ))
153- output0_normalized = torch .cat (output0_normalized , dim = 1 )
171+
172+ #for i in range(n_species-1):
173+ # output0_normalized.append(model0list[i](input0_normalized))
174+ #output0_normalized = torch.cat(output0_normalized, dim=1)
175+ output0_normalized = model (input0_normalized )
154176
155177 # post_processing
156178 output0_bct = output0_normalized * Ystd0 + Ymu0 + input0_bct [:, 2 :- 1 ]
0 commit comments