Skip to content

Commit 95b4e21

Browse files
committed
Add gemd loss into trainer
1 parent f27f2b7 commit 95b4e21

File tree

8 files changed

+548
-22
lines changed

8 files changed

+548
-22
lines changed

configs/resources/maxsub.json

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
{
2+
"1": [1],
3+
"2": [1,2],
4+
"3": [1,3,4,5],
5+
"4": [1,4],
6+
"5": [1,3,4,5],
7+
"6": [1,6,7,8],
8+
"7": [1,7,9],
9+
"8": [1,6,7,8,9],
10+
"9": [1,7,9],
11+
"10": [2,3,6,10,11,12,13],
12+
"11": [2,4,6,11,14],
13+
"12": [2,5,8,10,11,12,13,14,15],
14+
"13": [2,3,7,13,14,15],
15+
"14": [2,4,7,14],
16+
"15": [2,5,9,13,14,15],
17+
"16": [3,16,17,21,22],
18+
"17": [3,4,17,18,20],
19+
"18": [3,4,18,19],
20+
"19": [4,19],
21+
"20": [4,5,17,18,19,20],
22+
"21": [3,5,16,17,18,20,21,23,24],
23+
"22": [5,20,21,22],
24+
"23": [5,16,18,23],
25+
"24": [5,17,19,24],
26+
"25": [3,6,25,26,27,28,35,38,39,42],
27+
"26": [4,6,7,26,29,31,36],
28+
"27": [3,7,27,30,37],
29+
"28": [3,6,7,28,29,30,31,32,40,41],
30+
"29": [4,7,29,33],
31+
"30": [3,7,30,34],
32+
"31": [4,6,7,31,33],
33+
"32": [3,7,32,33,34],
34+
"33": [4,7,33],
35+
"34": [3,7,34,43],
36+
"35": [3,8,25,28,32,35,36,37,44,45,46],
37+
"36": [4,8,9,26,29,31,33,36],
38+
"37": [3,9,27,30,34,37],
39+
"38": [5,6,8,25,26,30,31,38,40,44,46],
40+
"39": [5,7,8,26,27,28,29,39,41,45,46],
41+
"40": [5,6,9,28,31,33,34,40],
42+
"41": [5,7,9,29,30,32,33,41],
43+
"42": [5,8,35,36,37,38,39,40,41,42],
44+
"43": [5,9,43],
45+
"44": [5,8,25,31,34,44],
46+
"45": [5,9,27,29,32,45],
47+
"46": [5,8,9,26,28,30,33,46],
48+
"47": [10,16,25,47,49,51,65,67,69],
49+
"48": [13,16,34,48,70],
50+
"49": [10,13,16,27,28,49,50,53,54,66,68],
51+
"50": [13,16,30,32,48,50,52],
52+
"51": [10,11,13,17,25,26,28,51,53,54,55,57,59,63,64],
53+
"52": [13,14,17,30,33,34,52],
54+
"53": [10,13,14,17,28,30,31,52,53,58,60],
55+
"54": [13,14,17,27,29,32,52,54,56,60],
56+
"55": [10,14,18,26,32,55,58,62],
57+
"56": [13,14,18,27,33,56],
58+
"57": [11,13,14,18,26,28,29,57,60,61,62],
59+
"58": [10,14,18,31,34,58],
60+
"59": [11,13,18,25,31,56,59,62],
61+
"60": [13,14,18,29,30,33,60],
62+
"61": [14,19,29,61],
63+
"62": [11,14,19,26,31,33,62],
64+
"63": [11,12,15,20,36,38,40,51,52,57,58,59,60,62,63],
65+
"64": [12,14,15,20,36,39,41,53,54,55,56,57,60,61,62,64],
66+
"65": [10,12,21,35,38,47,50,51,53,55,59,63,65,66,71,72,74],
67+
"66": [10,15,21,37,40,48,49,52,53,56,58,66],
68+
"67": [12,13,21,35,39,49,51,54,57,64,67,68,72,73,74],
69+
"68": [13,15,21,37,41,50,52,54,60,68],
70+
"69": [12,22,42,63,64,65,66,67,68,69],
71+
"70": [15,22,43,70],
72+
"71": [12,23,44,47,48,58,59,71],
73+
"72": [12,15,23,45,46,49,50,55,56,57,60,72],
74+
"73": [15,24,45,54,61,73],
75+
"74": [12,15,24,44,46,51,52,53,62,74],
76+
"75": [3,75,77,79],
77+
"76": [4,76,78],
78+
"77": [3,76,77,78,80],
79+
"78": [4,76,78],
80+
"79": [5,75,77,79],
81+
"80": [5,76,78,80],
82+
"81": [3,81,82],
83+
"82": [5,81,82],
84+
"83": [10,75,81,83,84,85,87],
85+
"84": [10,77,81,84,86],
86+
"85": [13,75,81,85,86],
87+
"86": [13,77,81,86,88],
88+
"87": [12,79,82,83,84,85,86,87],
89+
"88": [15,80,82,88],
90+
"89": [16,21,75,89,90,93,97],
91+
"90": [18,21,75,90,94],
92+
"91": [17,20,76,91,92,95],
93+
"92": [19,20,76,92,96],
94+
"93": [16,21,77,91,93,94,95,98],
95+
"94": [18,21,77,92,94,96],
96+
"95": [17,20,78,91,95,96],
97+
"96": [19,20,78,92,96],
98+
"97": [22,23,79,89,90,93,94,97],
99+
"98": [22,24,80,91,92,95,96,98],
100+
"99": [25,35,75,99,100,101,103,105,107,108],
101+
"100": [32,35,75,100,102,104,106],
102+
"101": [27,35,77,101,105,106],
103+
"102": [34,35,77,102,109,110],
104+
"103": [27,37,75,103,104],
105+
"104": [34,37,75,104],
106+
"105": [25,37,77,101,102,105],
107+
"106": [32,37,77,106],
108+
"107": [42,44,79,99,102,104,105,107],
109+
"108": [42,45,79,100,101,103,106,108],
110+
"109": [43,44,80,109],
111+
"110": [43,45,80,110],
112+
"111": [16,35,81,111,112,115,117,119,120],
113+
"112": [16,37,81,112,116,118],
114+
"113": [18,35,81,113,114],
115+
"114": [18,37,81,114],
116+
"115": [21,25,81,111,113,115,116,121],
117+
"116": [21,27,81,112,114,116],
118+
"117": [21,32,81,117,118],
119+
"118": [21,34,81,118,122],
120+
"119": [22,44,82,115,118,119],
121+
"120": [22,45,82,116,117,120],
122+
"121": [23,42,82,111,112,113,114,121],
123+
"122": [24,43,82,122],
124+
"123": [47,65,83,89,99,111,115,123,124,125,127,129,131,132,139,140],
125+
"124": [49,66,83,89,103,112,116,124,126,128,130],
126+
"125": [50,67,85,89,100,111,117,125,126,133,134],
127+
"126": [48,68,85,89,104,112,118,126],
128+
"127": [55,65,83,90,100,113,117,127,128,135,136],
129+
"128": [58,66,83,90,104,114,118,128],
130+
"129": [59,67,85,90,99,113,115,129,130,137,138],
131+
"130": [56,68,85,90,103,114,116,130],
132+
"131": [47,66,84,93,105,112,115,131,132,134,136,138],
133+
"132": [49,65,84,93,101,111,116,131,132,133,135,137],
134+
"133": [50,68,86,93,106,112,117,133],
135+
"134": [48,67,86,93,102,111,118,134,141,142],
136+
"135": [55,66,84,94,106,114,117,135],
137+
"136": [58,65,84,94,102,113,118,136],
138+
"137": [59,68,86,94,105,114,115,137],
139+
"138": [56,67,86,94,101,113,116,138],
140+
"139": [69,71,87,97,107,119,121,123,126,128,129,131,134,136,137,139],
141+
"140": [69,72,87,97,108,120,121,124,125,127,130,132,133,135,138,140],
142+
"141": [70,74,88,98,109,119,122,141],
143+
"142": [70,73,88,98,110,120,122,142],
144+
"143": [1,143,144,145,146],
145+
"144": [1,144,145],
146+
"145": [1,144,145],
147+
"146": [1,143,144,145,146],
148+
"147": [2,143,147,148],
149+
"148": [2,146,147,148],
150+
"149": [5,143,149,150,151,153,155],
151+
"150": [5,143,149,150,152,154],
152+
"151": [5,144,151,152,153],
153+
"152": [5,144,151,152,154],
154+
"153": [5,145,151,153,154],
155+
"154": [5,145,152,153,154],
156+
"155": [5,146,150,152,154,155],
157+
"156": [8,143,156,157,158],
158+
"157": [8,143,156,157,159,160],
159+
"158": [9,143,158,159],
160+
"159": [9,143,158,159,161],
161+
"160": [8,146,156,160,161],
162+
"161": [9,146,158,161],
163+
"162": [12,147,149,157,162,163,164,166],
164+
"163": [15,147,149,159,163,165,167],
165+
"164": [12,147,150,156,162,164,165],
166+
"165": [15,147,150,158,163,165],
167+
"166": [12,148,155,160,164,166,167],
168+
"167": [15,148,155,161,165,167],
169+
"168": [3,143,168,171,172,173],
170+
"169": [4,144,169,170],
171+
"170": [4,145,169,170],
172+
"171": [3,145,169,171,172],
173+
"172": [3,144,170,171,172],
174+
"173": [4,143,169,170,173],
175+
"174": [6,143,174],
176+
"175": [10,147,168,174,175,176],
177+
"176": [11,147,173,174,176],
178+
"177": [21,149,150,168,177,180,181,182],
179+
"178": [20,151,152,169,178,179],
180+
"179": [20,153,154,170,178,179],
181+
"180": [21,153,154,171,178,180,181],
182+
"181": [21,151,152,172,179,180,181],
183+
"182": [20,149,150,173,178,179,182],
184+
"183": [35,156,157,168,183,184,185,186],
185+
"184": [37,158,159,168,184],
186+
"185": [36,157,158,173,185,186],
187+
"186": [36,156,159,173,185,186],
188+
"187": [38,149,156,174,187,188,189],
189+
"188": [40,149,158,174,188,190],
190+
"189": [38,150,157,174,187,189,190],
191+
"190": [40,150,159,174,188,190],
192+
"191": [65,162,164,175,177,183,187,189,191,192,193,194],
193+
"192": [66,163,165,175,177,184,188,190,192],
194+
"193": [63,162,165,176,182,185,188,189,193,194],
195+
"194": [63,163,164,176,182,186,187,190,193,194],
196+
"195": [16,146,196,197,199],
197+
"196": [22,146,195,198],
198+
"197": [23,146,195],
199+
"198": [19,146],
200+
"199": [24,146,198],
201+
"200": [47,148,195,202,204,206],
202+
"201": [48,148,195,203],
203+
"202": [69,148,196,200,201,205],
204+
"203": [70,148,196],
205+
"204": [71,148,197,200,201],
206+
"205": [61,148,198],
207+
"206": [73,148,199,205],
208+
"207": [89,155,195,209,211],
209+
"208": [93,155,195,210,214],
210+
"209": [97,155,196,207,208],
211+
"210": [98,155,196,212,213],
212+
"211": [97,155,197,207,208],
213+
"212": [96,155,198],
214+
"213": [92,155,198],
215+
"214": [98,155,199,212,213],
216+
"215": [111,160,195,216,217,219],
217+
"216": [119,160,196,215],
218+
"217": [121,160,197,215,218],
219+
"218": [112,161,195,220],
220+
"219": [120,161,196,218],
221+
"220": [122,161,199],
222+
"221": [123,166,200,207,215,225,226,229],
223+
"222": [126,167,201,207,218],
224+
"223": [131,167,200,208,218,230],
225+
"224": [134,166,201,208,215,227,228],
226+
"225": [139,166,202,209,216,221,224],
227+
"226": [140,167,202,209,219,222,223],
228+
"227": [141,166,203,210,216],
229+
"228": [142,167,203,210,219],
230+
"229": [139,166,204,211,217,221,222,223,224],
231+
"230": [142,167,206,214,220]
232+
}

configs/simulator.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@ scaler_range: [1.0, 1.0]
3636
wl_range: [0.6199, 0.6199]
3737

3838
# --- Noise Parameter Ranges ---
39-
proportional_noise_range: [0.002, 0.02]
40-
constant_noise_range: [2, 5]
39+
# Noise is applied dynamically at train time (see the trainer config), so both simulator noise ranges below are set to zero for now.
40+
proportional_noise_range: [0.0, 0.0]
41+
constant_noise_range: [0.0, 0.0]

configs/trainer.docker.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ model:
6060
pooling_type: "average"
6161
final_pool: true
6262
use_batchnorm: false
63-
activation: "leaky_relu"
63+
activation: "gelu"
6464
output_type: "flatten"
6565

6666
heads:
@@ -75,16 +75,16 @@ model:
7575
num_lp_outputs: 6
7676

7777
lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
78-
lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
78+
lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0]
7979
bound_lp_with_sigmoid: true
8080

8181
loss:
8282
lambda_cs: 1.0
8383
lambda_sg: 1.0
8484
lambda_lp: 1.0
8585

86-
gemd_mu: 0.0
87-
gemd_distance_matrix_path: null
86+
gemd_mu: 1.0
87+
gemd_distance_matrix_path: "/app/configs/resources/maxsub.json"
8888

8989
optimizer:
9090
lr: 0.0002
@@ -94,7 +94,7 @@ optimizer:
9494
gradient_clip_algorithm: "norm"
9595

9696
trainer:
97-
default_root_dir: "/outputs/convnext_paper"
97+
default_root_dir: "/outputs/convnext_full_replication"
9898
max_epochs: 100
9999
accumulate_grad_batches: 1
100100
precision: "32"
@@ -106,10 +106,10 @@ trainer:
106106

107107
logging:
108108
logger: "mlflow"
109-
csv_logger_name: "model_logs_convnext_paper"
110-
mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt"
109+
csv_logger_name: "model_logs_convnext_full_replication"
110+
mlflow_experiment_name: "AlphaDiffract_Full_Replication"
111111
mlflow_tracking_uri: "file:/outputs/mlruns"
112-
mlflow_run_name: "ConvNeXt_Paper_Run"
112+
mlflow_run_name: "ConvNeXt_Full_Replication_Run"
113113

114114
checkpointing:
115115
monitor: "val/loss"

configs/trainer.local.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ model:
6161
pooling_type: "average"
6262
final_pool: true
6363
use_batchnorm: false
64-
activation: "leaky_relu"
64+
activation: "gelu"
6565
output_type: "flatten"
6666

6767
heads:
@@ -76,16 +76,16 @@ model:
7676
num_lp_outputs: 6
7777

7878
lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
79-
lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
79+
lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0]
8080
bound_lp_with_sigmoid: true
8181

8282
loss:
8383
lambda_cs: 1.0
8484
lambda_sg: 1.0
8585
lambda_lp: 1.0
8686

87-
gemd_mu: 0.0
88-
gemd_distance_matrix_path: null
87+
gemd_mu: 1.0
88+
gemd_distance_matrix_path: "configs/resources/maxsub.json"
8989

9090
optimizer:
9191
lr: 0.0002
@@ -95,7 +95,7 @@ optimizer:
9595
gradient_clip_algorithm: "norm"
9696

9797
trainer:
98-
default_root_dir: "outputs/convnext_paper"
98+
default_root_dir: "outputs/convnext_full_replication"
9999
max_epochs: 100
100100
accumulate_grad_batches: 1
101101
precision: "32" # match OG (AMP disabled)
@@ -107,10 +107,10 @@ trainer:
107107

108108
logging:
109109
logger: "mlflow"
110-
csv_logger_name: "model_logs_convnext_paper"
111-
mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt"
110+
csv_logger_name: "model_logs_convnext_full_replication"
111+
mlflow_experiment_name: "AlphaDiffract_Full_Replication"
112112
mlflow_tracking_uri: null
113-
mlflow_run_name: "ConvNeXt_Paper_Run"
113+
mlflow_run_name: "ConvNeXt_Full_Replication_Run"
114114

115115
checkpointing:
116116
monitor: "val/loss"

0 commit comments

Comments
 (0)