From 614760696b0eafcefdf8d27d4322efdce854fff5 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Wed, 30 Jul 2025 17:06:55 -0400 Subject: [PATCH 1/7] Implemented Automatic Mixed Precision (AMP) for training --- XPointMLTest.py | 790 ++++++++++++++++++++++++++++-------------------- 1 file changed, 468 insertions(+), 322 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 7b45044..45bf51d 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -20,6 +20,9 @@ from timeit import default_timer as timer +# Import mixed precision training components +from torch.amp import autocast, GradScaler + def expand_xpoints_mask(binary_mask, kernel_size=9): """ Expands each X-point in a binary mask to include surrounding cells @@ -64,14 +67,14 @@ def rotate(frameData,deg): print(f"invalid rotation specified... exiting") sys.exit() psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR) - all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR) + all_data = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR) mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) return { "fnum": frameData["fnum"], "rotation": deg, "reflectionAxis": -1, # no reflection "psi": psi, - "all": all, + "all": all_data, "mask": mask, "x": frameData["x"], "y": frameData["y"], @@ -84,14 +87,14 @@ def reflect(frameData,axis): print(f"invalid reflection axis specified... exiting") sys.exit() psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) - all = torch.flip(frameData["all"], dims=(axis,)) + all_data = torch.flip(frameData["all"], dims=(axis,)) mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) return { "fnum": frameData["fnum"], "rotation": 0, "reflectionAxis": axis, "psi": psi, - "all": all, + "all": all_data, "mask": mask, "x": frameData["x"], "y": frameData["y"], @@ -100,59 +103,58 @@ def reflect(frameData,axis): } def getPgkylData(paramFile, frameNumber, verbosity): - if verbosity > 0: - print(f"=== frame {frameNumber} ===") - params = {} #Initialize dictionary to store plotting and other parameters - params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points - constrcutBandJ = 1 - #Read vector potential - var = gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead() - psi = var.data - coords = var.coords - axesNorm = var.d[ var.speciesFileIndex.index('ion') ] - if verbosity > 0: - print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}") - #Construct B and J (first and second derivatives) - [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx) - [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx) - [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx) - bx = df_dy - by = -df_dx - jz = -(d2f_dxdx + d2f_dydy) / var.mu0 - del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz - #Indicies of critical points, X points, and O points (max and min) - critPoints = auxFuncs.getCritPoints(psi) - [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) - return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] + if verbosity > 0: + print(f"=== frame {frameNumber} ===") + params = {} #Initialize dictionary to store plotting and other parameters + params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points + #Read vector potential + var = 
gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead() + psi = var.data + coords = var.coords + axesNorm = var.d[ var.speciesFileIndex.index('ion') ] + if verbosity > 0: + print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}") + #Construct B and J (first and second derivatives) + [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx) + [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx) + [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx) + bx = df_dy + by = -df_dx + jz = -(d2f_dxdx + d2f_dydy) / var.mu0 + del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz + #Indicies of critical points, X points, and O points (max and min) + critPoints = auxFuncs.getCritPoints(psi) + [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) + return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] def cachedPgkylDataExists(cacheDir, frameNumber, fieldName): - if cacheDir == None: - return False - else: - cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" - return cachedFrame.exists(); + if cacheDir is None: + return False + else: + cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" + return cachedFrame.exists() def loadPgkylDataFromCache(cacheDir, frameNumber, fields): - outFields = {} - if cacheDir != None: - for name in fields.keys(): - if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: - outFields[name] = file.read().rstrip() - else: - outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") - return outFields - else: - return None + outFields = {} + if cacheDir is not None: + for name in fields.keys(): + if name == "fileName": + with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: + outFields[name] = file.read().rstrip() + else: + outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") + return outFields + else: + return None def writePgkylDataToCache(cacheDir, frameNumber, fields): - if cacheDir != None: - for name, field in fields.items(): - if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: - text_file.write(f"{field}") - else: - np.save(cacheDir / f"{frameNumber}_{name}.npy",field) + if cacheDir is not None: + for name, field in fields.items(): + if name == "fileName": + with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: + text_file.write(f"{field}") + else: + np.save(cacheDir / f"{frameNumber}_{name}.npy",field) # DATASET DEFINITION class XPointDataset(Dataset): @@ -164,10 +166,10 @@ class XPointDataset(Dataset): - Returns (psiTensor, maskTensor) as a PyTorch (float) pair. """ def __init__(self, paramFile, fnumList, xptCacheDir=None, - rotateAndReflect=False, verbosity=0): + rotateAndReflect=False, verbosity=0): """ paramFile: Path to parameter file (string). - fnumList: List of frames to iterate. + fnumList: List of frames to iterate. 
""" super().__init__() self.paramFile = paramFile @@ -182,7 +184,6 @@ def __init__(self, paramFile, fnumList, xptCacheDir=None, self.params["upperLimits"] = [1e6, 1e6, 0.e6, 1.e6, 1.e6] self.params["restFrame"] = 1 self.params["polyOrderOverride"] = 0 - self.params["plotContours"] = 1 self.params["colorContours"] = 'k' self.params["numContours"] = 50 @@ -190,18 +191,17 @@ def __init__(self, paramFile, fnumList, xptCacheDir=None, self.params["symBar"] = 1 self.params["colormap"] = 'bwr' - # load all the data self.data = [] for fnum in fnumList: frameData = self.load(fnum) self.data.append(frameData) if rotateAndReflect: - self.data.append(rotate(frameData,90)) - self.data.append(rotate(frameData,180)) - self.data.append(rotate(frameData,270)) - self.data.append(reflect(frameData,0)) - self.data.append(reflect(frameData,1)) + self.data.append(rotate(frameData,90)) + self.data.append(rotate(frameData,180)) + self.data.append(rotate(frameData,270)) + self.data.append(reflect(frameData,0)) + self.data.append(reflect(frameData,1)) def __len__(self): return len(self.data) @@ -213,38 +213,29 @@ def load(self, fnum): t0 = timer() # check if cache exists - if self.xptCacheDir != None: - if not self.xptCacheDir.is_dir(): - print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting") - sys.exit() + if self.xptCacheDir is not None: + if not self.xptCacheDir.is_dir(): + print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting") + sys.exit() + t2 = timer() - - fields = {"psi":None, - "critPts":None, - "xpts":None, - "optsMax":None, - "optsMin":None, - "axesNorm":None, - "coords":None, - "fileName":None, - "Bx":None, "By":None, - "Jz":None} + fields = {"psi":None, "critPts":None, "xpts":None, "optsMax":None, "optsMin":None, "axesNorm":None, "coords":None, "fileName":None, "Bx":None, "By":None, "Jz":None} # Indicies of critical points, X points, and O points (max and min) - if self.xptCacheDir != None and cachedPgkylDataExists(self.xptCacheDir, fnum, "psi"): - fields = loadPgkylDataFromCache(self.xptCacheDir, fnum, fields) + if self.xptCacheDir is not None and cachedPgkylDataExists(self.xptCacheDir, fnum, "psi"): + fields = loadPgkylDataFromCache(self.xptCacheDir, fnum, fields) else: - [fileName, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] = getPgkylData(self.paramFile, fnum, verbosity=self.verbosity) - fields = {"psi":psi, "critPts":critPoints, "xpts":xpts, - "optsMax":optsMax, "optsMin":optsMin, - "axesNorm": axesNorm, "coords": coords, - "fileName": fileName, - "Bx":bx, "By":by, "Jz":jz} - writePgkylDataToCache(self.xptCacheDir, fnum, fields) + [fileName, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] = getPgkylData(self.paramFile, fnum, verbosity=self.verbosity) + fields = {"psi":psi, "critPts":critPoints, "xpts":xpts, + "optsMax":optsMax, "optsMin":optsMin, + "axesNorm": axesNorm, "coords": coords, + "fileName": fileName, + "Bx":bx, "By":by, "Jz":jz} + writePgkylDataToCache(self.xptCacheDir, fnum, fields) self.params["axesNorm"] = fields["axesNorm"] if self.verbosity > 0: - print("time (s) to find X and O points: " + str(timer()-t2)) + print("time (s) to find X and O points: " + str(timer()-t2)) # Create array of 0s with 1s only at X points binaryMap = np.zeros(np.shape(fields["psi"])) @@ -253,30 +244,89 @@ def load(self, fnum): binaryMap = expand_xpoints_mask(binaryMap, kernel_size=9) # -------------- 6) Convert to Torch Tensors -------------- - psi_torch = torch.from_numpy(fields["psi"]).float().unsqueeze(0) # [1, 
Nx, Ny] + psi_torch = torch.from_numpy(fields["psi"]).float().unsqueeze(0) # [1, Nx, Ny] bx_torch = torch.from_numpy(fields["Bx"]).float().unsqueeze(0) by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0) jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0) + all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny] mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny] if self.verbosity > 0: - print("time (s) to load and process gkyl frame: " + str(timer()-t0)) + print("time (s) to load and process gkyl frame: " + str(timer()-t0)) return { "fnum": fnum, "rotation": 0, "reflectionAxis": -1, # no reflection - "psi": psi_torch, # shape [1, Nx, Ny] - "all": all_torch, # shape [4, Nx, Ny] - "mask": mask_torch, # shape [1, Nx, Ny] + "psi": psi_torch, # shape [1, Nx, Ny] + "all": all_torch, # shape [4, Nx, Ny] + "mask": mask_torch, # shape [1, Nx, Ny] "x": fields["coords"][0], "y": fields["coords"][1], "filenameBase": fields["fileName"], "params": dict(self.params) # copy of the params for local plotting } +class XPointPatchDataset(Dataset): + """On‑the‑fly square crops, balancing positive / background patches.""" + def __init__(self, base_ds, patch=64, pos_ratio=0.6, retries=20): + self.base_ds = base_ds + self.patch = patch + self.pos_ratio = pos_ratio + self.retries = retries + self.rng = np.random.default_rng() + # Precompute some statistics for normalization + self.compute_normalization_stats() + + def compute_normalization_stats(self): + """Compute global mean and std for normalization""" + # Sample a few frames to compute statistics + n_samples = min(10, len(self.base_ds)) + all_values = [] + + for i in range(n_samples): + frame = self.base_ds[i] + all_values.append(frame["all"].numpy()) + + all_values = np.concatenate([v.flatten() for v in all_values]) + self.global_mean = np.mean(all_values) + self.global_std = np.std(all_values) + + # Prevent division by zero + if self.global_std == 0: + self.global_std = 1.0 + + print(f"Computed normalization stats: mean={self.global_mean:.4f}, std={self.global_std:.4f}") + def __len__(self): + # give each full frame K random crops per epoch (K=16 by default) + return len(self.base_ds) * 16 + + def _crop(self, arr, top, left): + return arr[..., top:top+self.patch, left:left+self.patch] + + def __getitem__(self, _): + frame = self.base_ds[self.rng.integers(len(self.base_ds))] + H, W = frame["mask"].shape[-2:] + + # comments on the logic + for attempt in range(self.retries): + y0 = self.rng.integers(0, H - self.patch + 1) + x0 = self.rng.integers(0, W - self.patch + 1) + crop_mask = self._crop(frame["mask"], y0, x0) + has_pos = crop_mask.sum() > 0 + want_pos = (attempt / self.retries) < self.pos_ratio + + if has_pos == want_pos or attempt == self.retries - 1: + crop_all = self._crop(frame["all"], y0, x0) + # Apply global normalization + crop_all = (crop_all - self.global_mean) / self.global_std + + return { + "all" : crop_all, + "mask": crop_mask + } # 2) U-NET ARCHITECTURE class UNet(nn.Module): @@ -285,7 +335,7 @@ class UNet(nn.Module): in: (N, 1, H, W) ++++ BX, BY, JZ out: (N, 1, H, W) """ - def __init__(self, input_channels=1, base_channels=16): + def __init__(self, input_channels=4, base_channels=64): super().__init__() self.enc1 = self.double_conv(input_channels, base_channels) # -> base_channels self.enc2 = self.double_conv(base_channels, base_channels*2) # -> 2*base_channels @@ -306,6 +356,15 @@ def __init__(self, input_channels=1, base_channels=16): self.dec1 = 
self.double_conv(base_channels*2, base_channels) self.out_conv = nn.Conv2d(base_channels, 1, kernel_size=1) + + # Initialize weights for better stability + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) def double_conv(self, in_ch, out_ch): return nn.Sequential( @@ -317,63 +376,144 @@ def double_conv(self, in_ch, out_ch): def forward(self, x): # Encoder - e1 = self.enc1(x) # shape: [N, base_channels, H, W] - p1 = self.pool(e1) # half spatial dims + e1 = self.enc1(x) # shape: [N, base_channels, H, W] + p1 = self.pool(e1) # half spatial dims - e2 = self.enc2(p1) # [N, 2*base_channels, H/2, W/2] + e2 = self.enc2(p1) # [N, 2*base_channels, H/2, W/2] p2 = self.pool(e2) - e3 = self.enc3(p2) # [N, 4*base_channels, H/4, W/4] - p3 = self.pool(e3) # [N, 4*base_channels, H/8, W/8] + e3 = self.enc3(p2) # [N, 4*base_channels, H/4, W/4] + p3 = self.pool(e3) # [N, 4*base_channels, H/8, W/8] # Bottleneck - b = self.bottleneck(p3) # [N, 8*base_channels, H/8, W/8] + b = self.bottleneck(p3) # [N, 8*base_channels, H/8, W/8] # Decoder - u3 = self.up3(b) # -> shape ~ e3 + u3 = self.up3(b) # -> shape ~ e3 cat3 = torch.cat([u3, e3], dim=1) # skip connection d3 = self.dec3(cat3) - u2 = self.up2(d3) # -> shape ~ e2 + u2 = self.up2(d3) # -> shape ~ e2 cat2 = torch.cat([u2, e2], dim=1) d2 = self.dec2(cat2) - u1 = self.up1(d2) # -> shape ~ e1 + u1 = self.up1(d2) # -> shape ~ e1 cat1 = torch.cat([u1, e1], dim=1) d1 = self.dec1(cat1) out = self.out_conv(d1) return out # We'll apply sigmoid in the loss or after - # TRAIN & VALIDATION UTILS -def train_one_epoch(model, loader, criterion, optimizer, device): +def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None, use_amp=False, amp_dtype=torch.float16): model.train() running_loss = 0.0 - for batch in loader: - all, mask = batch["all"].to(device), batch["mask"].to(device) - pred = model(all) - - loss = criterion(pred, mask) - - optimizer.zero_grad() - loss.backward() - optimizer.step() + num_batches = 0 + num_skipped = 0 + + for batch_idx, batch in enumerate(loader): + all_data, mask = batch["all"].to(device), batch["mask"].to(device) + + if use_amp: + # Clear gradients + optimizer.zero_grad() + + # Use autocast for forward pass + with autocast(device_type='cuda', dtype=amp_dtype): + pred = model(all_data) + loss = criterion(pred, mask) + + # Check if loss is valid + if not torch.isfinite(loss): + print(f"Warning: Non-finite loss detected in batch {batch_idx}, skipping...") + num_skipped += 1 + continue + + # For bfloat16, we don't use GradScaler + if amp_dtype == torch.bfloat16: + loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + if not torch.isfinite(grad_norm): + print(f"Warning: Non-finite gradients detected in batch {batch_idx}, skipping...") + num_skipped += 1 + optimizer.zero_grad() + continue + + optimizer.step() + else: + # Use GradScaler for float16 + scaled_loss = scaler.scale(loss) + scaled_loss.backward() + + # Unscale gradients before clipping + scaler.unscale_(optimizer) + + # Clip gradients + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Check gradient norm + if not torch.isfinite(grad_norm): + print(f"Warning: Non-finite gradients detected in batch {batch_idx}, skipping...") + num_skipped += 1 + optimizer.zero_grad() # Clear the invalid gradients + 
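                        # Note (aside, not part of the original patch): once scaler.unscale_()
                        # has been called for this iteration, GradScaler still expects a matching
                        # scaler.update() even though optimizer.step() is being skipped; the call
                        # below keeps the loss-scale bookkeeping consistent and lets the scaler
                        # shrink the scale if the unscaled gradients contained inf/NaN.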
scaler.update() # Update scaler state + continue + + # Optimizer step and scaler update + scaler.step(optimizer) + scaler.update() + + else: + # Standard training without AMP + optimizer.zero_grad() + pred = model(all_data) + loss = criterion(pred, mask) + + if not torch.isfinite(loss): + print(f"Warning: Non-finite loss detected in batch {batch_idx}, skipping...") + num_skipped += 1 + continue + + loss.backward() + + # Gradient clipping + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + if not torch.isfinite(grad_norm): + print(f"Warning: Non-finite gradients detected in batch {batch_idx}, skipping...") + num_skipped += 1 + optimizer.zero_grad() + continue + + optimizer.step() + running_loss += loss.item() - return running_loss / len(loader) + num_batches += 1 + + if num_skipped > 0: + print(f" Skipped {num_skipped} batches due to numerical issues") + + return running_loss / max(num_batches, 1) -def validate_one_epoch(model, loader, criterion, device): +def validate_one_epoch(model, loader, criterion, device, use_amp=False, amp_dtype=torch.float16): model.eval() val_loss = 0.0 with torch.no_grad(): for batch in loader: - all, mask = batch["all"].to(device), batch["mask"].to(device) - pred = model(all) - loss = criterion(pred, mask) + all_data, mask = batch["all"].to(device), batch["mask"].to(device) + + if use_amp: + with autocast(device_type='cuda', dtype=amp_dtype): + pred = model(all_data) + loss = criterion(pred, mask) + else: + pred = model(all_data) + loss = criterion(pred, mask) + val_loss += loss.item() return val_loss / len(loader) - class FocalLoss(nn.Module): def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'): """ @@ -418,7 +558,6 @@ def forward(self, inputs, targets): return focal_loss.sum() else: return focal_loss - class DiceLoss(nn.Module): def __init__(self, smooth=1.0, eps=1e-7): @@ -440,6 +579,9 @@ def forward(self, inputs, targets): """ # Apply sigmoid to get probabilities inputs = torch.sigmoid(inputs) + + # Ensure inputs are in valid range to prevent NaN + inputs = torch.clamp(inputs, min=self.eps, max=1.0 - self.eps) inputs = inputs.view(-1) targets = targets.view(-1) @@ -456,13 +598,13 @@ def forward(self, inputs, targets): # PLOTTING FUNCTION def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, - reflectionAxis, filenameBase, interpFac, - xpoint_mask=None, + reflectionAxis, filenameBase, interpFac, + xpoint_mask=None, titleExtra="", - outDir="plots", + outDir="plots", saveFig=True): """ - Plots the vector potential 'psi_np' as contours, + Plots the vector potential 'psi_np' as contours, then overlays X-points from xpoint_mask (if provided, shape [Nx,Ny]). 
The figure is saved to outDir """ @@ -473,7 +615,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, cs = plt.contour( x / params["axesNorm"], y / params["axesNorm"], - np.transpose(psi_np), + np.transpose(psi_np), params["numContours"], colors=params["colorContours"], linewidths=0.75 @@ -498,7 +640,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, 'xk' ) - # Save the figure if needed (could be removed as we save anyway) + # Save the figure if needed if saveFig: basename = os.path.basename(filenameBase) saveFilename = os.path.join( @@ -510,14 +652,10 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, plt.close() -def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, - outDir="plots", saveFig=True): +def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, + outDir="plots", saveFig=True): """ - Visualize model performance comparing predictions with ground truth: - - True Positives (green) - - False Positives (red) - - False Negatives (yellow) - - Background shows psi contours + Visualize model performance comparing predictions with ground truth. """ plt.figure(figsize=(12, 8)) @@ -527,7 +665,7 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi cs = plt.contour( x / params["axesNorm"], y / params["axesNorm"], - np.transpose(psi_np), + np.transpose(psi_np), params["numContours"], colors='k', linewidths=0.75 @@ -547,25 +685,16 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi fn_rows, fn_cols = np.where(fn_mask) if len(tp_rows) > 0: - plt.plot( - x[tp_rows] / params["axesNorm"], - y[tp_cols] / params["axesNorm"], - 'o', color='green', markersize=8, label="True Positives" - ) + plt.plot(x[tp_rows] / params["axesNorm"], y[tp_cols] / params["axesNorm"], + 'o', color='green', markersize=8, label="True Positives") if len(fp_rows) > 0: - plt.plot( - x[fp_rows] / params["axesNorm"], - y[fp_cols] / params["axesNorm"], - 'o', color='red', markersize=8, label="False Positives" - ) + plt.plot(x[fp_rows] / params["axesNorm"], y[fp_cols] / params["axesNorm"], + 'o', color='red', markersize=8, label="False Positives") if len(fn_rows) > 0: - plt.plot( - x[fn_rows] / params["axesNorm"], - y[fn_cols] / params["axesNorm"], - 'o', color='yellow', markersize=8, label="False Negatives" - ) + plt.plot(x[fn_rows] / params["axesNorm"], y[fn_cols] / params["axesNorm"], + 'o', color='yellow', markersize=8, label="False Negatives") plt.xlabel(r"$x/d_i$") plt.ylabel(r"$y/d_i$") @@ -598,17 +727,16 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'): """ Plots training and validation losses across epochs. 
- - Parameters: - train_losses (list): List of training losses for each epoch - val_losses (list): List of validation losses for each epoch - save_path (str): Path to save the resulting plot """ plt.figure(figsize=(10, 6)) epochs = range(1, len(train_losses) + 1) - plt.plot(epochs, train_losses, 'b-', label='Training Loss') - plt.plot(epochs, val_losses, 'r-', label='Validation Loss') + # Filter out NaN values for plotting + train_losses_clean = [loss if not np.isnan(loss) else None for loss in train_losses] + val_losses_clean = [loss if not np.isnan(loss) else None for loss in val_losses] + + plt.plot(epochs, train_losses_clean, 'b-', label='Training Loss') + plt.plot(epochs, val_losses_clean, 'r-', label='Validation Loss') plt.title('Training and Validation Loss') plt.xlabel('Epochs') @@ -618,9 +746,12 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi plt.grid(True, linestyle='--', alpha=0.7) # Add some padding to y-axis to make visualization clearer - ymin = min(min(train_losses), min(val_losses)) * 0.9 - ymax = max(max(train_losses), max(val_losses)) * 1.1 - plt.ylim(ymin, ymax) + # Handle case where all values might be NaN + valid_losses = [loss for loss in train_losses + val_losses if loss is not None and not np.isnan(loss)] + if valid_losses: + ymin = min(valid_losses) * 0.9 + ymax = max(valid_losses) * 1.1 + plt.ylim(ymin, ymax) plt.savefig(save_path, dpi=300) print(f"Training history plot saved to {save_path}") @@ -628,88 +759,67 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi def parseCommandLineArgs(): parser = argparse.ArgumentParser(description='ML-based reconnection classifier') - parser.add_argument('--learningRate', type=float, default=1e-5, - help='specify the learning rate') - parser.add_argument('--batchSize', type=int, default=1, - help='specify the batch size') - parser.add_argument('--epochs', type=int, default=2000, - help='specify the number of epochs') - parser.add_argument('--trainFrameFirst', type=int, default=1, - help='specify the number of the first frame used for training') - parser.add_argument('--trainFrameLast', type=int, default=140, - help='specify the number of the last frame (exclusive) used for training') - parser.add_argument('--validationFrameFirst', type=int, default=141, - help='specify the number of the first frame used for validation') - parser.add_argument('--validationFrameLast', type=int, default=150, - help='specify the number of the last frame (exclusive) used for validation') - parser.add_argument('--minTrainingLoss', type=int, default=3, - help=''' - minimum reduction in training loss in orders of magnitude, - set to 0 to disable the check - ''') - parser.add_argument('--checkPointFrequency', type=int, default=10, - help='number of epochs between checkpoints') - parser.add_argument('--paramFile', type=Path, default=None, - help=''' - specify the path to the parameter txt file, the parent - directory of that file must contain the gkyl input training data - ''') - parser.add_argument('--xptCacheDir', type=Path, default=None, - help=''' - specify the path to a directory that will be used to cache - the outputs of the analytic Xpoint finder - ''') - parser.add_argument('--plot', action=argparse.BooleanOptionalAction, - help='create figures of the ground truth X-points and model identified X-points') - parser.add_argument('--plotDir', type=Path, default="./plots", - help='directory where figures are written') + parser.add_argument('--learningRate', type=float, default=1e-5, 
help='specify the learning rate') + parser.add_argument('--batchSize', type=int, default=1, help='specify the batch size') + parser.add_argument('--epochs', type=int, default=2000, help='specify the number of epochs') + parser.add_argument('--trainFrameFirst', type=int, default=1, help='specify the number of the first frame used for training') + parser.add_argument('--trainFrameLast', type=int, default=140, help='specify the number of the last frame (exclusive) used for training') + parser.add_argument('--validationFrameFirst', type=int, default=141, help='specify the number of the first frame used for validation') + parser.add_argument('--validationFrameLast', type=int, default=150, help='specify the number of the last frame (exclusive) used for validation') + parser.add_argument('--minTrainingLoss', type=int, default=3, help='''minimum reduction in training loss in orders of magnitude, set to 0 to disable the check''') + parser.add_argument('--checkPointFrequency', type=int, default=10, help='number of epochs between checkpoints') + parser.add_argument('--paramFile', type=Path, default=None, help='''specify the path to the parameter txt file, the parent directory of that file must contain the gkyl input training data''') + parser.add_argument('--xptCacheDir', type=Path, default=None, help='''specify the path to a directory that will be used to cache the outputs of the analytic Xpoint finder''') + parser.add_argument('--plot', action=argparse.BooleanOptionalAction, help='create figures of the ground truth X-points and model identified X-points') + parser.add_argument('--plotDir', type=Path, default="./plots", help='directory where figures are written') + parser.add_argument('--use-amp', action='store_true', help='use automatic mixed precision training') + parser.add_argument('--amp-dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type for mixed precision (float16 or bfloat16)') args = parser.parse_args() return args def checkCommandLineArgs(args): - if args.xptCacheDir != None: - if not args.xptCacheDir.is_dir(): - print(f"Xpoint cache directory {args.xptCacheDir} does not exist. " - "Please create the directory... exiting") - sys.exit() - - if args.paramFile == None: - print(f"parameter file required but not set... exiting") - sys.exit() + if args.xptCacheDir is not None: + if not args.xptCacheDir.is_dir(): + print(f"Xpoint cache directory {args.xptCacheDir} does not exist. Please create the directory... exiting") + sys.exit() + + if args.paramFile is None: + print(f"parameter file required but not set... exiting") + sys.exit() if args.paramFile.is_dir(): - print(f"parameter file {args.paramFile} is a directory ... exiting") - sys.exit() + print(f"parameter file {args.paramFile} is a directory ... exiting") + sys.exit() if not args.paramFile.exists(): - print(f"parameter file {args.paramFile} does not exist... exiting") - sys.exit() + print(f"parameter file {args.paramFile} does not exist... exiting") + sys.exit() if args.trainFrameFirst == 0 or args.validationFrameFirst == 0: - print(f"frame 0 does not contain valid data... exiting") - sys.exit() + print(f"frame 0 does not contain valid data... exiting") + sys.exit() if args.trainFrameLast <= args.trainFrameFirst: - print(f"training frame range isn't valid... exiting") - sys.exit() + print(f"training frame range isn't valid... exiting") + sys.exit() if args.validationFrameLast <= args.validationFrameFirst: - print(f"validation frame range isn't valid... 
exiting") - sys.exit() + print(f"validation frame range isn't valid... exiting") + sys.exit() if args.learningRate <= 0: - print(f"learningRate must be > 0... exiting") - sys.exit() + print(f"learningRate must be > 0... exiting") + sys.exit() if args.batchSize < 1: - print(f"batchSize must be >= 1... exiting") - sys.exit() + print(f"batchSize must be >= 1... exiting") + sys.exit() if args.minTrainingLoss < 0: - print(f"minTrainingLoss must be >= 0... exiting") - sys.exit() + print(f"minTrainingLoss must be >= 0... exiting") + sys.exit() if args.checkPointFrequency < 0: - print(f"checkPointFrequency must be >= 0... exiting") - sys.exit() + print(f"checkPointFrequency must be >= 0... exiting") + sys.exit() def printCommandLineArgs(args): print("Config {") @@ -718,23 +828,14 @@ def printCommandLineArgs(args): print("}") # Function to save model checkpoint -def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints"): +def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints", scaler=None): """ Save model checkpoint including model state, optimizer state, and training metrics - - Parameters: - model: The neural network model - optimizer: The optimizer used for training - train_loss: List of training losses - val_loss: List of validation losses - epoch: Current epoch number - checkpoint_dir: Directory to save checkpoints """ os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir, f"xpoint_model_epoch_{epoch}.pt") - # Create checkpoint dictionary checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), @@ -743,7 +844,9 @@ def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpo 'val_loss': val_loss } - # Save checkpoint + if scaler is not None: + checkpoint['scaler_state_dict'] = scaler.state_dict() + torch.save(checkpoint, checkpoint_path) print(f"Model checkpoint saved at epoch {epoch} to {checkpoint_path}") @@ -760,25 +863,13 @@ def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpo raise e # Function to load model checkpoint -def load_model_checkpoint(model, optimizer, checkpoint_path): +def load_model_checkpoint(model, optimizer, checkpoint_path, scaler=None): """ Load model checkpoint - - Parameters: - model: The neural network model to load weights into - optimizer: The optimizer to load state into - checkpoint_path: Path to the checkpoint file - - Returns: - model: Updated model with loaded weights - optimizer: Updated optimizer with loaded state - epoch: Last saved epoch number - train_loss: List of training losses - val_loss: List of validation losses """ if not os.path.exists(checkpoint_path): print(f"No checkpoint found at {checkpoint_path}") - return model, optimizer, 0, [], [] + return model, optimizer, 0, [], [], scaler print(f"Loading checkpoint from {checkpoint_path}") checkpoint = torch.load(checkpoint_path) @@ -790,9 +881,11 @@ def load_model_checkpoint(model, optimizer, checkpoint_path): train_loss = checkpoint['train_loss'] val_loss = checkpoint['val_loss'] + if scaler is not None and 'scaler_state_dict' in checkpoint: + scaler.load_state_dict(checkpoint['scaler_state_dict']) + print(f"Loaded checkpoint from epoch {epoch}") - return model, optimizer, epoch, train_loss, val_loss - + return model, optimizer, epoch, train_loss, val_loss, scaler def main(): args = parseCommandLineArgs() @@ -811,20 +904,62 @@ def main(): xptCacheDir=args.xptCacheDir, rotateAndReflect=True) val_dataset = 
XPointDataset(args.paramFile, val_fnums, xptCacheDir=args.xptCacheDir) + + train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.6, retries=20) + val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.6, retries=20) t1 = timer() print("time (s) to create gkyl data loader: " + str(t1-t0)) print(f"number of training frames (original + augmented): {len(train_dataset)}") print(f"number of validation frames: {len(val_dataset)}") - train_loader = DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False) - val_loader = DataLoader(val_dataset, batch_size=args.batchSize, shuffle=False) + train_loader = DataLoader(train_crop, batch_size=args.batchSize, shuffle=False) + val_loader = DataLoader(val_crop, batch_size=args.batchSize, shuffle=False) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = UNet(input_channels=4, base_channels=64).to(device) criterion = DiceLoss(smooth=1.0) - optimizer = optim.Adam(model.parameters(), lr=args.learningRate) + + # Reduce learning rate for mixed precision training + effective_lr = args.learningRate + if args.use_amp: + # Less aggressive reduction for AMP + effective_lr = args.learningRate * 0.5 # Reduce by 2x instead of 10x + print(f"Adjusting learning rate for AMP: {args.learningRate} -> {effective_lr}") + + optimizer = optim.Adam(model.parameters(), lr=effective_lr, eps=1e-4) # Higher epsilon for stability + + # Add learning rate scheduler + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True) + + # Initialize GradScaler for mixed precision if enabled + scaler = None + amp_dtype = torch.float16 # default + + if args.use_amp and torch.cuda.is_available(): + if args.amp_dtype == 'bfloat16': + if torch.cuda.is_bf16_supported(): + amp_dtype = torch.bfloat16 + print("Using bfloat16 for mixed precision (no GradScaler needed)") + else: + print("Warning: bfloat16 not supported on this GPU, falling back to float16") + amp_dtype = torch.float16 + + # Only use GradScaler with float16 + if amp_dtype == torch.float16: + # Initialize with very conservative settings for stability + scaler = GradScaler( + device='cuda', + init_scale=2.**4, # Much smaller initial scale (16 instead of 256) + growth_factor=1.5, # Slower growth + backoff_factor=0.5, + growth_interval=200, # Wait longer before increasing scale + enabled=True + ) + print("Initialized GradScaler with very conservative settings for stability") + + print(f"Using Automatic Mixed Precision with {amp_dtype}") checkpoint_dir = "checkpoints" os.makedirs(checkpoint_dir, exist_ok=True) @@ -834,8 +969,8 @@ def main(): val_loss = [] if os.path.exists(latest_checkpoint_path): - model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint( - model, optimizer, latest_checkpoint_path + model, optimizer, start_epoch, train_loss, val_loss, scaler = load_model_checkpoint( + model, optimizer, latest_checkpoint_path, scaler ) print(f"Resuming training from epoch {start_epoch+1}") else: @@ -843,103 +978,114 @@ def main(): t2 = timer() print("time (s) to prepare model: " + str(t2-t1)) - - train_loss = [] - val_loss = [] + if args.use_amp: + print(f"Using Automatic Mixed Precision (AMP) training with {amp_dtype}") num_epochs = args.epochs for epoch in range(start_epoch, num_epochs): - train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device)) - val_loss.append(validate_one_epoch(model, val_loader, criterion, device)) + train_loss.append(train_one_epoch(model, train_loader, criterion, 
optimizer, device, scaler, args.use_amp, amp_dtype)) + val_loss.append(validate_one_epoch(model, val_loader, criterion, device, args.use_amp, amp_dtype)) + + # Update learning rate based on validation loss + if not np.isnan(val_loss[-1]): + scheduler.step(val_loss[-1]) + print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}") # Save model checkpoint after each epoch if epoch % args.checkPointFrequency == 0: - save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir) + save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir, scaler) plot_training_history(train_loss, val_loss) print("time (s) to train model: " + str(timer()-t2)) requiredLossDecreaseMagnitude = args.minTrainingLoss - if np.log10(abs(train_loss[0]/train_loss[-1])) < requiredLossDecreaseMagnitude: - print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: " - f"initial {train_loss[0]} final {train_loss[-1]} ... exiting") - return 1; + if len(train_loss) > 0 and train_loss[-1] > 0 and not np.isnan(train_loss[0]) and not np.isnan(train_loss[-1]): + if np.log10(abs(train_loss[0]/train_loss[-1])) < requiredLossDecreaseMagnitude: + print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: " + f"initial {train_loss[0]} final {train_loss[-1]} ... exiting") + return 1 + else: + print("Warning: Unable to check training loss reduction due to NaN or zero values") # (D) Plotting after training model.eval() # switch to inference mode outDir = "plots" interpFac = 1 - # Evaluate on combined set for demonstration. Exam this part to see if save to remove - full_fnums = list(train_fnums) + list(val_fnums) + # Evaluate on combined set for demonstration full_dataset = [train_dataset, val_dataset] t4 = timer() with torch.no_grad(): - for set in full_dataset: - for item in set: - # item is a dict with keys: fnum, psi, mask, psi_np, mask_np, x, y, tmp, params - fnum = item["fnum"] - rotation = item["rotation"] - reflectionAxis = item["reflectionAxis"] - psi_np = np.array(item["psi"])[0] - mask_gt = np.array(item["mask"])[0] - x = item["x"] - y = item["y"] - filenameBase = item["filenameBase"] - params = item["params"] - - # Get CNN prediction - all_torch = item["all"].unsqueeze(0).to(device) - pred_mask = model(all_torch) - pred_mask_np = pred_mask[0,0].cpu().numpy() - # Binarize - pred_bin = (pred_mask_np > 0.5).astype(np.float32) - - pred_prob = torch.sigmoid(pred_mask) - pred_prob_np = (pred_prob > 0.5).float().cpu().numpy() - - pred_mask_bin = (pred_prob_np > 0.5).astype(np.float32) # Thresholding at 0.5, can be fine tune - - print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") - print(f"psi shape: {psi_np.shape}, min: {psi_np.min()}, max: {psi_np.max()}") - print(f"pred_bin shape: {pred_bin.shape}, min: {pred_bin.min()}, max: {pred_bin.max()}") - print(f" Logits - min: {pred_mask_np.min():.5f}, max: {pred_mask_np.max():.5f}, mean: {pred_mask_np.mean():.5f}") - print(f" Probabilities (after sigmoid) - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") - print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") - print(f" Binary Mask (X_points) - shape: {pred_mask_bin.shape}, min: {pred_mask_bin.min()}, max: {pred_mask_bin.max()}") - - if args.plot : - # Plot GROUND TRUTH - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, 
interpFac, - xpoint_mask=mask_gt, - titleExtra="GTXpoints", - outDir=outDir, - saveFig=True - ) - - # Plot CNN PREDICTIONS - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, - xpoint_mask=np.squeeze(pred_mask_bin), - titleExtra="CNNXpoints", - outDir=outDir, - saveFig=True - ) - - pred_prob_np_full = pred_prob.cpu().numpy() - plot_model_performance( - psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase, - outDir=outDir, - saveFig=True - ) + for dataset in full_dataset: + for item in dataset: + fnum = item["fnum"] + rotation = item["rotation"] + reflectionAxis = item["reflectionAxis"] + psi_np = item["psi"].numpy()[0] + mask_gt = item["mask"].numpy()[0] + x = item["x"] + y = item["y"] + filenameBase = item["filenameBase"] + params = item["params"] + + # Get CNN prediction + all_torch = item["all"].unsqueeze(0).to(device) + + if args.use_amp: + with autocast(device_type='cuda', dtype=amp_dtype): + pred_mask = model(all_torch) + else: + pred_mask = model(all_torch) + + pred_mask_np = pred_mask[0,0].cpu().numpy() + # Binarize + pred_bin = (pred_mask_np > 0.5).astype(np.float32) + + pred_prob = torch.sigmoid(pred_mask) + pred_prob_np = (pred_prob > 0.5).float().cpu().numpy() + + pred_mask_bin = (pred_prob_np > 0.5).astype(np.float32) + + print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") + print(f"psi shape: {psi_np.shape}, min: {psi_np.min()}, max: {psi_np.max()}") + print(f"pred_bin shape: {pred_bin.shape}, min: {pred_bin.min()}, max: {pred_bin.max()}") + print(f" Logits - min: {pred_mask_np.min():.5f}, max: {pred_mask_np.max():.5f}, mean: {pred_mask_np.mean():.5f}") + print(f" Probabilities (after sigmoid) - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") + print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") + print(f" Binary Mask (X_points) - shape: {pred_mask_bin.shape}, min: {pred_mask_bin.min()}, max: {pred_mask_bin.max()}") + + if args.plot: + # Plot GROUND TRUTH + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=mask_gt, + titleExtra="GTXpoints", + outDir=outDir, + saveFig=True + ) + + # Plot CNN PREDICTIONS + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=np.squeeze(pred_mask_bin), + titleExtra="CNNXpoints", + outDir=outDir, + saveFig=True + ) + + pred_prob_np_full = pred_prob.cpu().numpy() + plot_model_performance( + psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase, + outDir=outDir, + saveFig=True + ) t5 = timer() print("time (s) to apply model: " + str(t5-t4)) print("total time (s): " + str(t5-t0)) if __name__ == "__main__": - main() + main() \ No newline at end of file From 527a8747ae05b686049475f8fb9082009596e8ae Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Wed, 13 Aug 2025 22:52:11 -0400 Subject: [PATCH 2/7] Major architecture upgrade: ResNet-style U-Net + modern training --- StrinkedXPoint.py | 440 ------------------- XPointMLTest.py | 1027 +++++++++++++++++++++++---------------------- 2 files changed, 515 insertions(+), 952 deletions(-) delete mode 100644 StrinkedXPoint.py diff --git a/StrinkedXPoint.py b/StrinkedXPoint.py deleted file mode 100644 index e34300c..0000000 --- a/StrinkedXPoint.py +++ /dev/null @@ -1,440 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -import os, errno, 
sys, argparse -from pathlib import Path -from timeit import default_timer as timer -import sys -import argparse - -from utils import gkData -from utils import auxFuncs -from utils import plotParams - -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import v2 # rotate tensor - -def expand_xpoints_mask(binary_mask, kernel_size=9): - """ - Expands each X-point in a binary mask to include surrounding cells - in a square grid of size kernel_size x kernel_size. - - Parameters: - binary_mask : numpy.ndarray - 2D binary mask with 1s at X-point locations - kernel_size : int - Size of the square grid (must be odd number) - - Returns: - numpy.ndarray - Expanded binary mask with 1s in kernel_size×kernel_size regions around X-points - """ - - # Get shape of the original mask - h, w = binary_mask.shape - - # Create a copy to avoid modifying the original - expanded_mask = np.zeros_like(binary_mask) - - # Find coordinates of all X-points - x_points = np.argwhere(binary_mask > 0) - - # For each X-point, set a kernel_size×kernel_size area to 1 - half_size = kernel_size // 2 - for point in x_points: - # Get the corner coordinates for the square centered at the X-point - x_min = max(0, point[0] - half_size) - x_max = min(h, point[0] + half_size + 1) - y_min = max(0, point[1] - half_size) - y_max = min(w, point[1] + half_size + 1) - - # Set the square area to 1 - expanded_mask[x_min:x_max, y_min:y_max] = 1 - - return expanded_mask - -def rotate(frameData,deg): - if deg not in [90, 180, 270]: - print(f"invalid rotation specified... exiting") - sys.exit() - psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR) - all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR) - mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) - return { - "fnum": frameData["fnum"], - "rotation": deg, - "reflectionAxis": -1, # no reflection - "psi": psi, - "all": all, - "mask": mask, - "x": frameData["x"], - "y": frameData["y"], - "filenameBase": frameData["filenameBase"], - "params": frameData["params"] - } - -def reflect(frameData,axis): - if axis not in [0,1]: - print(f"invalid reflection axis specified... 
exiting") - sys.exit() - psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) - all = torch.flip(frameData["all"], dims=(axis,)) - mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) - return { - "fnum": frameData["fnum"], - "rotation": 0, - "reflectionAxis": axis, - "psi": psi, - "all": all, - "mask": mask, - "x": frameData["x"], - "y": frameData["y"], - "filenameBase": frameData["filenameBase"], - "params": frameData["params"] - } - -def getPgkylData(paramFile, frameNumber, verbosity): - if verbosity > 0: - print(f"=== frame {frameNumber} ===") - params = {} #Initialize dictionary to store plotting and other parameters - params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points - constrcutBandJ = 1 - #Read vector potential - var = gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead() - psi = var.data - coords = var.coords - axesNorm = var.d[ var.speciesFileIndex.index('ion') ] - if verbosity > 0: - print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}") - #Construct B and J (first and second derivatives) - [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx) - [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx) - [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx) - bx = df_dy - by = -df_dx - jz = -(d2f_dxdx + d2f_dydy) / var.mu0 - del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz - #Indicies of critical points, X points, and O points (max and min) - critPoints = auxFuncs.getCritPoints(psi) - [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) - return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] - -def cachedPgkylDataExists(cacheDir, frameNumber, fieldName): - if cacheDir == None: - return False - else: - cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" - return cachedFrame.exists(); - -def loadPgkylDataFromCache(cacheDir, frameNumber, fields): - outFields = {} - if cacheDir != None: - for name in fields.keys(): - if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: - outFields[name] = file.read().rstrip() - else: - outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") - return outFields - else: - return None - -def writePgkylDataToCache(cacheDir, frameNumber, fields): - if cacheDir != None: - for name, field in fields.items(): - if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: - text_file.write(f"{field}") - else: - np.save(cacheDir / f"{frameNumber}_{name}.npy",field) - - -class XPointDataset(Dataset): - """ - Dataset that processes frames in [fnumList]. For each frame (fnum): - - Sets up "params" according to your snippet. - - Reads psi from gkData (varid='psi') - - Finds X-points -> builds a 2D binary mask. - - Returns (psiTensor, maskTensor) as a PyTorch (float) pair. - """ - def __init__(self, paramFile, fnumList, xptCacheDir=None, - rotateAndReflect=False, verbosity=0): - """ - paramFile: Path to parameter file (string). - fnumList: List of frames to iterate. 
- """ - super().__init__() - self.paramFile = paramFile - self.fnumList = list(fnumList) # ensure indexable - self.xptCacheDir = xptCacheDir - self.verbosity = verbosity - - # We'll store a base 'params' once here, and then customize in __getitem__: - self.params = {} - # Default snippet-based constants: - self.params["lowerLimits"] = [-1e6, -1e6, -0.e6, -1.e6, -1.e6] - self.params["upperLimits"] = [1e6, 1e6, 0.e6, 1.e6, 1.e6] - self.params["restFrame"] = 1 - self.params["polyOrderOverride"] = 0 - - self.params["plotContours"] = 1 - self.params["colorContours"] = 'k' - self.params["numContours"] = 50 - self.params["axisEqual"] = 1 - self.params["symBar"] = 1 - self.params["colormap"] = 'bwr' - - - # load all the data - self.data = [] - for fnum in fnumList: - frameData = self.load(fnum) - self.data.append(frameData) - if rotateAndReflect: - self.data.append(rotate(frameData,90)) - self.data.append(rotate(frameData,180)) - self.data.append(rotate(frameData,270)) - self.data.append(reflect(frameData,0)) - self.data.append(reflect(frameData,1)) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - return self.data[idx] - - def load(self, fnum): - t0 = timer() - - # check if cache exists - if self.xptCacheDir != None: - if not self.xptCacheDir.is_dir(): - print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting") - sys.exit() - t2 = timer() - - fields = {"psi":None, - "critPts":None, - "xpts":None, - "optsMax":None, - "optsMin":None, - "axesNorm":None, - "coords":None, - "fileName":None, - "Bx":None, "By":None, - "Jz":None} - - # Indicies of critical points, X points, and O points (max and min) - if self.xptCacheDir != None and cachedPgkylDataExists(self.xptCacheDir, fnum, "psi"): - fields = loadPgkylDataFromCache(self.xptCacheDir, fnum, fields) - else: - [fileName, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] = getPgkylData(self.paramFile, fnum, verbosity=self.verbosity) - fields = {"psi":psi, "critPts":critPoints, "xpts":xpts, - "optsMax":optsMax, "optsMin":optsMin, - "axesNorm": axesNorm, "coords": coords, - "fileName": fileName, - "Bx":bx, "By":by, "Jz":jz} - writePgkylDataToCache(self.xptCacheDir, fnum, fields) - self.params["axesNorm"] = fields["axesNorm"] - - if self.verbosity > 0: - print("time (s) to find X and O points: " + str(timer()-t2)) - - # Create array of 0s with 1s only at X points - binaryMap = np.zeros(np.shape(fields["psi"])) - binaryMap[fields["xpts"][:, 0], fields["xpts"][:, 1]] = 1 - - binaryMap = expand_xpoints_mask(binaryMap, kernel_size=9) - - # -------------- 6) Convert to Torch Tensors -------------- - psi_torch = torch.from_numpy(fields["psi"]).float().unsqueeze(0) # [1, Nx, Ny] - bx_torch = torch.from_numpy(fields["Bx"]).float().unsqueeze(0) - by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0) - jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0) - all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny] - mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny] - - if self.verbosity > 0: - print("time (s) to load and process gkyl frame: " + str(timer()-t0)) - - return { - "fnum": fnum, - "rotation": 0, - "reflectionAxis": -1, # no reflection - "psi": psi_torch, # shape [1, Nx, Ny] - "all": all_torch, # shape [4, Nx, Ny] - "mask": mask_torch, # shape [1, Nx, Ny] - "x": fields["coords"][0], - "y": fields["coords"][1], - "filenameBase": fields["fileName"], - "params": dict(self.params) # copy of the params for local plotting - } - - - 
-class XPointPatchDataset(Dataset): - """On‑the‑fly square crops, balancing positive / background patches.""" - def __init__(self, base_ds, patch=64, pos_ratio=0.6, retries=20): - self.base_ds = base_ds - self.patch = patch - self.pos_ratio = pos_ratio - self.retries = retries - self.rng = np.random.default_rng() - - def __len__(self): - # give each full frame K random crops per epoch (K=16 by default) - return len(self.base_ds) * 16 - - def _crop(self, arr, top, left): - return arr[..., top:top+self.patch, left:left+self.patch] - - def __getitem__(self, _): - frame = self.base_ds[self.rng.integers(len(self.base_ds))] - H, W = frame["mask"].shape[-2:] - - for attempt in range(self.retries): - y0 = self.rng.integers(0, H - self.patch + 1) - x0 = self.rng.integers(0, W - self.patch + 1) - crop_mask = self._crop(frame["mask"], y0, x0) - has_pos = crop_mask.sum() > 0 - want_pos = (attempt / self.retries) < self.pos_ratio - - if has_pos == want_pos or attempt == self.retries - 1: - return { - "all" : self._crop(frame["all"], y0, x0), - "mask": crop_mask - } - - -class UNet(nn.Module): - def __init__(self, input_channels=4, base=64): - super().__init__() - self.enc1 = self._dbl(input_channels, base, dilation=1) - self.enc2 = self._dbl(base, base*2, dilation=1) - self.enc3 = self._dbl(base*2, base*4, dilation=2) # ← dilated - self.pool = nn.MaxPool2d(2, 2) - self.bottleneck = self._dbl(base*4, base*8, dilation=4) # ← dilated - self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, 2) - self.dec3 = self._dbl(base*8, base*4, dilation=1) - self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, 2) - self.dec2 = self._dbl(base*4, base*2, dilation=1) - self.up1 = nn.ConvTranspose2d(base*2, base, 2, 2) - self.dec1 = self._dbl(base*2, base, dilation=1) - self.out = nn.Conv2d(base, 1, 1) - - @staticmethod - def _dbl(inp, out, dilation=1): - pad = dilation - return nn.Sequential( - nn.Conv2d(inp, out, 3, padding=pad, dilation=dilation), - nn.ReLU(inplace=True), - nn.Conv2d(out, out, 3, padding=pad, dilation=dilation), - nn.ReLU(inplace=True) - ) - - def forward(self, x): - e1 = self.enc1(x) - e2 = self.enc2(self.pool(e1)) - e3 = self.enc3(self.pool(e2)) - b = self.bottleneck(self.pool(e3)) - d3 = self.dec3(torch.cat([self.up3(b), e3], 1)) - d2 = self.dec2(torch.cat([self.up2(d3), e2], 1)) - d1 = self.dec1(torch.cat([self.up1(d2), e1], 1)) - return self.out(d1) - - - -class DiceLoss(nn.Module): - def __init__(self, smooth=1.): - super().__init__() - self.smooth = smooth - - def forward(self, logits, targets): - probs = torch.sigmoid(logits) - probs = probs.view(-1); targets = targets.view(-1) - inter = (probs * targets).sum() - union = probs.sum() + targets.sum() - dice = (2*inter + self.smooth) / (union + self.smooth) - return 1 - dice - - -def make_criterion(pos_weight): - bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]).float()) - dice = DiceLoss() - def _loss(logits, target): - return 0.5 * bce(logits, target) + 0.5 * dice(logits, target) - return _loss - - - -@torch.no_grad() -def evaluate(model, loader, criterion, device): - model.eval(); loss = 0 - for batch in loader: - x, y = batch["all"].to(device), batch["mask"].to(device) - logits = model(x) - loss += criterion(logits, y).item() - return loss / len(loader) - - -def train_epoch(model, loader, criterion, opt, device): - model.train(); loss = 0 - for batch in loader: - x, y = batch["all"].to(device), batch["mask"].to(device) - opt.zero_grad() - l = criterion(model(x), y) - l.backward(); opt.step(); loss += l.item() - return loss / len(loader) - 
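For reference, the removed make_criterion mixes class-weighted BCE-with-logits and Dice in equal parts. With the rough class balance assumed in main() below (about 30 X-points dilated to 9x9 pixels on a 1024x1024 grid), the weight works out to pos_px = 30*81 = 2430 and neg_px = 1024*1024 - 2430 = 1046146, so pos_weight = neg_px/pos_px is roughly 430. A hedged usage sketch, with names and shapes following the surrounding code:

criterion = make_criterion(pos_weight=1046146 / 2430)   # ~430.5
logits = model(x)                  # [N, 1, H, W] raw scores; sigmoid is applied inside the losses
loss = criterion(logits, y)        # y is the binary X-point mask of the same shape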
- - - -def main(): - p = argparse.ArgumentParser() - p.add_argument("--paramFile", type=Path, required=True) - p.add_argument("--xptCacheDir", type=Path) - p.add_argument("--epochs", type=int, default=60) - p.add_argument("--batch", type=int, default=8) - p.add_argument("--lr", type=float, default=3e-5) - args = p.parse_args() - - # frame splits hard‑coded for demo - train_fnums = range(1, 141) - val_fnums = range(141, 150) - - train_full = XPointDataset(args.paramFile, train_fnums, - xptCacheDir=args.xptCacheDir, - rotateAndReflect=True) - val_full = XPointDataset(args.paramFile, val_fnums, - xptCacheDir=args.xptCacheDir) - - train_ds = XPointPatchDataset(train_full, patch=64, pos_ratio=0.8) # 64 x 64 cropping - val_ds = XPointPatchDataset(val_full, patch=64, pos_ratio=0.5) - - loader_tr = DataLoader(train_ds, batch_size=args.batch, shuffle=True) - loader_va = DataLoader(val_ds, batch_size=args.batch, shuffle=False) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = UNet().to(device) - - # estimate pos/neg for weight (rough): - pos_px = 30 * 9 * 9 - neg_px = 1024*1024 - pos_px - criterion = make_criterion(pos_weight=neg_px/pos_px) - - opt = optim.Adam(model.parameters(), lr=args.lr) - - for ep in range(1, args.epochs+1): - tr_loss = train_epoch(model, loader_tr, criterion, opt, device) - va_loss = evaluate(model, loader_va, criterion, device) - print(f"Epoch {ep:03d}: train {tr_loss:.4f} | val {va_loss:.4f}") - - # quick checkpoint - if ep % 10 == 0: - torch.save(model.state_dict(), f"chkpt_ep{ep}.pt") - -if __name__ == "__main__": - main() diff --git a/XPointMLTest.py b/XPointMLTest.py index 45bf51d..033a45b 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -67,14 +67,14 @@ def rotate(frameData,deg): print(f"invalid rotation specified... exiting") sys.exit() psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR) - all_data = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR) + all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR) mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) return { "fnum": frameData["fnum"], "rotation": deg, "reflectionAxis": -1, # no reflection "psi": psi, - "all": all_data, + "all": all, "mask": mask, "x": frameData["x"], "y": frameData["y"], @@ -87,14 +87,14 @@ def reflect(frameData,axis): print(f"invalid reflection axis specified... 
exiting") sys.exit() psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) - all_data = torch.flip(frameData["all"], dims=(axis,)) + all = torch.flip(frameData["all"], dims=(axis,)) mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) return { "fnum": frameData["fnum"], "rotation": 0, "reflectionAxis": axis, "psi": psi, - "all": all_data, + "all": all, "mask": mask, "x": frameData["x"], "y": frameData["y"], @@ -103,58 +103,59 @@ def reflect(frameData,axis): } def getPgkylData(paramFile, frameNumber, verbosity): - if verbosity > 0: - print(f"=== frame {frameNumber} ===") - params = {} #Initialize dictionary to store plotting and other parameters - params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points - #Read vector potential - var = gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead() - psi = var.data - coords = var.coords - axesNorm = var.d[ var.speciesFileIndex.index('ion') ] - if verbosity > 0: - print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}") - #Construct B and J (first and second derivatives) - [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx) - [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx) - [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx) - bx = df_dy - by = -df_dx - jz = -(d2f_dxdx + d2f_dydy) / var.mu0 - del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz - #Indicies of critical points, X points, and O points (max and min) - critPoints = auxFuncs.getCritPoints(psi) - [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) - return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] + if verbosity > 0: + print(f"=== frame {frameNumber} ===") + params = {} #Initialize dictionary to store plotting and other parameters + params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points + constrcutBandJ = 1 + #Read vector potential + var = gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead() + psi = var.data + coords = var.coords + axesNorm = var.d[ var.speciesFileIndex.index('ion') ] + if verbosity > 0: + print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}") + #Construct B and J (first and second derivatives) + [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx) + [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx) + [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx) + bx = df_dy + by = -df_dx + jz = -(d2f_dxdx + d2f_dydy) / var.mu0 + del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz + #Indicies of critical points, X points, and O points (max and min) + critPoints = auxFuncs.getCritPoints(psi) + [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) + return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] def cachedPgkylDataExists(cacheDir, frameNumber, fieldName): - if cacheDir is None: - return False - else: - cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" - return cachedFrame.exists() + if cacheDir == None: + return False + else: + cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" + return cachedFrame.exists(); def loadPgkylDataFromCache(cacheDir, frameNumber, fields): - outFields = {} - if cacheDir is not None: - for name in fields.keys(): - if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: - outFields[name] = 
file.read().rstrip() - else: - outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") - return outFields - else: - return None + outFields = {} + if cacheDir != None: + for name in fields.keys(): + if name == "fileName": + with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: + outFields[name] = file.read().rstrip() + else: + outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") + return outFields + else: + return None def writePgkylDataToCache(cacheDir, frameNumber, fields): - if cacheDir is not None: - for name, field in fields.items(): - if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: - text_file.write(f"{field}") - else: - np.save(cacheDir / f"{frameNumber}_{name}.npy",field) + if cacheDir != None: + for name, field in fields.items(): + if name == "fileName": + with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: + text_file.write(f"{field}") + else: + np.save(cacheDir / f"{frameNumber}_{name}.npy",field) # DATASET DEFINITION class XPointDataset(Dataset): @@ -166,10 +167,10 @@ class XPointDataset(Dataset): - Returns (psiTensor, maskTensor) as a PyTorch (float) pair. """ def __init__(self, paramFile, fnumList, xptCacheDir=None, - rotateAndReflect=False, verbosity=0): + rotateAndReflect=False, verbosity=0): """ paramFile: Path to parameter file (string). - fnumList: List of frames to iterate. + fnumList: List of frames to iterate. """ super().__init__() self.paramFile = paramFile @@ -184,6 +185,7 @@ def __init__(self, paramFile, fnumList, xptCacheDir=None, self.params["upperLimits"] = [1e6, 1e6, 0.e6, 1.e6, 1.e6] self.params["restFrame"] = 1 self.params["polyOrderOverride"] = 0 + self.params["plotContours"] = 1 self.params["colorContours"] = 'k' self.params["numContours"] = 50 @@ -191,17 +193,18 @@ def __init__(self, paramFile, fnumList, xptCacheDir=None, self.params["symBar"] = 1 self.params["colormap"] = 'bwr' + # load all the data self.data = [] for fnum in fnumList: frameData = self.load(fnum) self.data.append(frameData) if rotateAndReflect: - self.data.append(rotate(frameData,90)) - self.data.append(rotate(frameData,180)) - self.data.append(rotate(frameData,270)) - self.data.append(reflect(frameData,0)) - self.data.append(reflect(frameData,1)) + self.data.append(rotate(frameData,90)) + self.data.append(rotate(frameData,180)) + self.data.append(rotate(frameData,270)) + self.data.append(reflect(frameData,0)) + self.data.append(reflect(frameData,1)) def __len__(self): return len(self.data) @@ -213,29 +216,38 @@ def load(self, fnum): t0 = timer() # check if cache exists - if self.xptCacheDir is not None: - if not self.xptCacheDir.is_dir(): - print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting") - sys.exit() - + if self.xptCacheDir != None: + if not self.xptCacheDir.is_dir(): + print(f"Xpoint cache directory {self.xptCacheDir} does not exist... 
exiting") + sys.exit() t2 = timer() - fields = {"psi":None, "critPts":None, "xpts":None, "optsMax":None, "optsMin":None, "axesNorm":None, "coords":None, "fileName":None, "Bx":None, "By":None, "Jz":None} + + fields = {"psi":None, + "critPts":None, + "xpts":None, + "optsMax":None, + "optsMin":None, + "axesNorm":None, + "coords":None, + "fileName":None, + "Bx":None, "By":None, + "Jz":None} # Indicies of critical points, X points, and O points (max and min) - if self.xptCacheDir is not None and cachedPgkylDataExists(self.xptCacheDir, fnum, "psi"): - fields = loadPgkylDataFromCache(self.xptCacheDir, fnum, fields) + if self.xptCacheDir != None and cachedPgkylDataExists(self.xptCacheDir, fnum, "psi"): + fields = loadPgkylDataFromCache(self.xptCacheDir, fnum, fields) else: - [fileName, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] = getPgkylData(self.paramFile, fnum, verbosity=self.verbosity) - fields = {"psi":psi, "critPts":critPoints, "xpts":xpts, - "optsMax":optsMax, "optsMin":optsMin, - "axesNorm": axesNorm, "coords": coords, - "fileName": fileName, - "Bx":bx, "By":by, "Jz":jz} - writePgkylDataToCache(self.xptCacheDir, fnum, fields) + [fileName, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] = getPgkylData(self.paramFile, fnum, verbosity=self.verbosity) + fields = {"psi":psi, "critPts":critPoints, "xpts":xpts, + "optsMax":optsMax, "optsMin":optsMin, + "axesNorm": axesNorm, "coords": coords, + "fileName": fileName, + "Bx":bx, "By":by, "Jz":jz} + writePgkylDataToCache(self.xptCacheDir, fnum, fields) self.params["axesNorm"] = fields["axesNorm"] if self.verbosity > 0: - print("time (s) to find X and O points: " + str(timer()-t2)) + print("time (s) to find X and O points: " + str(timer()-t2)) # Create array of 0s with 1s only at X points binaryMap = np.zeros(np.shape(fields["psi"])) @@ -243,25 +255,30 @@ def load(self, fnum): binaryMap = expand_xpoints_mask(binaryMap, kernel_size=9) + # Normalize input features for better training stability + psi_norm = (fields["psi"] - fields["psi"].mean()) / (fields["psi"].std() + 1e-8) + bx_norm = (fields["Bx"] - fields["Bx"].mean()) / (fields["Bx"].std() + 1e-8) + by_norm = (fields["By"] - fields["By"].mean()) / (fields["By"].std() + 1e-8) + jz_norm = (fields["Jz"] - fields["Jz"].mean()) / (fields["Jz"].std() + 1e-8) + # -------------- 6) Convert to Torch Tensors -------------- - psi_torch = torch.from_numpy(fields["psi"]).float().unsqueeze(0) # [1, Nx, Ny] - bx_torch = torch.from_numpy(fields["Bx"]).float().unsqueeze(0) - by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0) - jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0) - + psi_torch = torch.from_numpy(psi_norm).float().unsqueeze(0) # [1, Nx, Ny] + bx_torch = torch.from_numpy(bx_norm).float().unsqueeze(0) + by_torch = torch.from_numpy(by_norm).float().unsqueeze(0) + jz_torch = torch.from_numpy(jz_norm).float().unsqueeze(0) all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny] mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny] if self.verbosity > 0: - print("time (s) to load and process gkyl frame: " + str(timer()-t0)) + print("time (s) to load and process gkyl frame: " + str(timer()-t0)) return { "fnum": fnum, "rotation": 0, "reflectionAxis": -1, # no reflection - "psi": psi_torch, # shape [1, Nx, Ny] - "all": all_torch, # shape [4, Nx, Ny] - "mask": mask_torch, # shape [1, Nx, Ny] + "psi": torch.from_numpy(fields["psi"]).float().unsqueeze(0), # Keep original for plotting + "all": 
all_torch, # Normalized for training + "mask": mask_torch, # shape [1, Nx, Ny] "x": fields["coords"][0], "y": fields["coords"][1], "filenameBase": fields["fileName"], @@ -270,38 +287,16 @@ def load(self, fnum): class XPointPatchDataset(Dataset): """On‑the‑fly square crops, balancing positive / background patches.""" - def __init__(self, base_ds, patch=64, pos_ratio=0.6, retries=20): + def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30): self.base_ds = base_ds self.patch = patch self.pos_ratio = pos_ratio self.retries = retries self.rng = np.random.default_rng() - # Precompute some statistics for normalization - self.compute_normalization_stats() - - def compute_normalization_stats(self): - """Compute global mean and std for normalization""" - # Sample a few frames to compute statistics - n_samples = min(10, len(self.base_ds)) - all_values = [] - - for i in range(n_samples): - frame = self.base_ds[i] - all_values.append(frame["all"].numpy()) - - all_values = np.concatenate([v.flatten() for v in all_values]) - self.global_mean = np.mean(all_values) - self.global_std = np.std(all_values) - - # Prevent division by zero - if self.global_std == 0: - self.global_std = 1.0 - - print(f"Computed normalization stats: mean={self.global_mean:.4f}, std={self.global_std:.4f}") def __len__(self): - # give each full frame K random crops per epoch (K=16 by default) - return len(self.base_ds) * 16 + # give each full frame K random crops per epoch (K=32 for more samples) + return len(self.base_ds) * 32 def _crop(self, arr, top, left): return arr[..., top:top+self.patch, left:left+self.patch] @@ -310,7 +305,14 @@ def __getitem__(self, _): frame = self.base_ds[self.rng.integers(len(self.base_ds))] H, W = frame["mask"].shape[-2:] - # comments on the logic + # Ensure we have enough space for cropping + if H < self.patch or W < self.patch: + # Return padded version if image is too small + return { + "all": F.pad(frame["all"], (0, max(0, self.patch - W), 0, max(0, self.patch - H))), + "mask": F.pad(frame["mask"], (0, max(0, self.patch - W), 0, max(0, self.patch - H))) + } + for attempt in range(self.retries): y0 = self.rng.integers(0, H - self.patch + 1) x0 = self.rng.integers(0, W - self.patch + 1) @@ -319,246 +321,120 @@ def __getitem__(self, _): want_pos = (attempt / self.retries) < self.pos_ratio if has_pos == want_pos or attempt == self.retries - 1: - crop_all = self._crop(frame["all"], y0, x0) - # Apply global normalization - crop_all = (crop_all - self.global_mean) / self.global_std - return { - "all" : crop_all, + "all" : self._crop(frame["all"], y0, x0), "mask": crop_mask } -# 2) U-NET ARCHITECTURE -class UNet(nn.Module): +# 2) IMPROVED U-NET ARCHITECTURE WITH RESIDUAL CONNECTIONS +class ResidualBlock(nn.Module): + """Residual block with two convolutions and skip connection""" + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + # Skip connection if dimensions don't match + self.skip = nn.Identity() + if in_channels != out_channels: + self.skip = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1), + nn.BatchNorm2d(out_channels) + ) + + def forward(self, x): + residual = self.skip(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = 
self.bn2(out) + + out += residual + out = self.relu(out) + + return out + +class ImprovedUNet(nn.Module): """ - A simplified U-Net for binary segmentation: - in: (N, 1, H, W) ++++ BX, BY, JZ - out: (N, 1, H, W) + Improved U-Net with residual blocks and better normalization """ - def __init__(self, input_channels=4, base_channels=64): + def __init__(self, input_channels=4, base_channels=32): super().__init__() - self.enc1 = self.double_conv(input_channels, base_channels) # -> base_channels - self.enc2 = self.double_conv(base_channels, base_channels*2) # -> 2*base_channels - self.enc3 = self.double_conv(base_channels*2, base_channels*4) # -> 4*base_channels + + # Encoder + self.enc1 = ResidualBlock(input_channels, base_channels) + self.enc2 = ResidualBlock(base_channels, base_channels*2) + self.enc3 = ResidualBlock(base_channels*2, base_channels*4) + self.enc4 = ResidualBlock(base_channels*4, base_channels*8) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.dropout = nn.Dropout2d(0.1) - self.bottleneck = self.double_conv(base_channels*4, base_channels*8) + # Bottleneck + self.bottleneck = ResidualBlock(base_channels*8, base_channels*16) # Decoder + self.up4 = nn.ConvTranspose2d(base_channels*16, base_channels*8, kernel_size=2, stride=2) + self.dec4 = ResidualBlock(base_channels*16, base_channels*8) + self.up3 = nn.ConvTranspose2d(base_channels*8, base_channels*4, kernel_size=2, stride=2) - self.dec3 = self.double_conv(base_channels*8, base_channels*4) + self.dec3 = ResidualBlock(base_channels*8, base_channels*4) self.up2 = nn.ConvTranspose2d(base_channels*4, base_channels*2, kernel_size=2, stride=2) - self.dec2 = self.double_conv(base_channels*4, base_channels*2) + self.dec2 = ResidualBlock(base_channels*4, base_channels*2) self.up1 = nn.ConvTranspose2d(base_channels*2, base_channels, kernel_size=2, stride=2) - self.dec1 = self.double_conv(base_channels*2, base_channels) + self.dec1 = ResidualBlock(base_channels*2, base_channels) self.out_conv = nn.Conv2d(base_channels, 1, kernel_size=1) - - # Initialize weights for better stability - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def double_conv(self, in_ch, out_ch): - return nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), - nn.ReLU(inplace=True) - ) def forward(self, x): # Encoder - e1 = self.enc1(x) # shape: [N, base_channels, H, W] - p1 = self.pool(e1) # half spatial dims - - e2 = self.enc2(p1) # [N, 2*base_channels, H/2, W/2] + e1 = self.enc1(x) + p1 = self.pool(e1) + + e2 = self.enc2(p1) p2 = self.pool(e2) - - e3 = self.enc3(p2) # [N, 4*base_channels, H/4, W/4] - p3 = self.pool(e3) # [N, 4*base_channels, H/8, W/8] + p2 = self.dropout(p2) + + e3 = self.enc3(p2) + p3 = self.pool(e3) + p3 = self.dropout(p3) + + e4 = self.enc4(p3) + p4 = self.pool(e4) + p4 = self.dropout(p4) # Bottleneck - b = self.bottleneck(p3) # [N, 8*base_channels, H/8, W/8] + b = self.bottleneck(p4) # Decoder - u3 = self.up3(b) # -> shape ~ e3 - cat3 = torch.cat([u3, e3], dim=1) # skip connection + u4 = self.up4(b) + cat4 = torch.cat([u4, e4], dim=1) + d4 = self.dec4(cat4) + + u3 = self.up3(d4) + cat3 = torch.cat([u3, e3], dim=1) d3 = self.dec3(cat3) - u2 = self.up2(d3) # -> shape ~ e2 + u2 = self.up2(d3) cat2 = torch.cat([u2, e2], dim=1) d2 = self.dec2(cat2) - 
u1 = self.up1(d2) # -> shape ~ e1 + u1 = self.up1(d2) cat1 = torch.cat([u1, e1], dim=1) d1 = self.dec1(cat1) out = self.out_conv(d1) - return out # We'll apply sigmoid in the loss or after - -# TRAIN & VALIDATION UTILS -def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None, use_amp=False, amp_dtype=torch.float16): - model.train() - running_loss = 0.0 - num_batches = 0 - num_skipped = 0 - - for batch_idx, batch in enumerate(loader): - all_data, mask = batch["all"].to(device), batch["mask"].to(device) - - if use_amp: - # Clear gradients - optimizer.zero_grad() - - # Use autocast for forward pass - with autocast(device_type='cuda', dtype=amp_dtype): - pred = model(all_data) - loss = criterion(pred, mask) - - # Check if loss is valid - if not torch.isfinite(loss): - print(f"Warning: Non-finite loss detected in batch {batch_idx}, skipping...") - num_skipped += 1 - continue - - # For bfloat16, we don't use GradScaler - if amp_dtype == torch.bfloat16: - loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - - if not torch.isfinite(grad_norm): - print(f"Warning: Non-finite gradients detected in batch {batch_idx}, skipping...") - num_skipped += 1 - optimizer.zero_grad() - continue - - optimizer.step() - else: - # Use GradScaler for float16 - scaled_loss = scaler.scale(loss) - scaled_loss.backward() - - # Unscale gradients before clipping - scaler.unscale_(optimizer) - - # Clip gradients - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - - # Check gradient norm - if not torch.isfinite(grad_norm): - print(f"Warning: Non-finite gradients detected in batch {batch_idx}, skipping...") - num_skipped += 1 - optimizer.zero_grad() # Clear the invalid gradients - scaler.update() # Update scaler state - continue - - # Optimizer step and scaler update - scaler.step(optimizer) - scaler.update() - - else: - # Standard training without AMP - optimizer.zero_grad() - pred = model(all_data) - loss = criterion(pred, mask) - - if not torch.isfinite(loss): - print(f"Warning: Non-finite loss detected in batch {batch_idx}, skipping...") - num_skipped += 1 - continue - - loss.backward() - - # Gradient clipping - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - - if not torch.isfinite(grad_norm): - print(f"Warning: Non-finite gradients detected in batch {batch_idx}, skipping...") - num_skipped += 1 - optimizer.zero_grad() - continue - - optimizer.step() - - running_loss += loss.item() - num_batches += 1 - - if num_skipped > 0: - print(f" Skipped {num_skipped} batches due to numerical issues") - - return running_loss / max(num_batches, 1) - -def validate_one_epoch(model, loader, criterion, device, use_amp=False, amp_dtype=torch.float16): - model.eval() - val_loss = 0.0 - with torch.no_grad(): - for batch in loader: - all_data, mask = batch["all"].to(device), batch["mask"].to(device) - - if use_amp: - with autocast(device_type='cuda', dtype=amp_dtype): - pred = model(all_data) - loss = criterion(pred, mask) - else: - pred = model(all_data) - loss = criterion(pred, mask) - - val_loss += loss.item() - return val_loss / len(loader) - -class FocalLoss(nn.Module): - def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'): - """ - Focal Loss implementation - - Parameters: - alpha (float): Weighting factor for the rare class (X-points), default=0.25 - gamma (float): Focusing parameter that reduces the loss for well-classified examples, default=2.0 - reduction (str): 'mean' or 'sum', how to reduce the loss 
over the batch - """ - super().__init__() - self.alpha = alpha - self.gamma = gamma - self.reduction = reduction - - def forward(self, inputs, targets): - """ - inputs: Model predictions (logits, before sigmoid), shape [N, 1, H, W] - targets: Ground truth binary masks, shape [N, 1, H, W] - """ - # Apply sigmoid to get probabilities - bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') - - # Get probabilities for positive class - probs = torch.sigmoid(inputs) - # For targets=1 (X-points), pt = p; for targets=0 (non-X-points), pt = 1-p - pt = torch.where(targets == 1, probs, 1 - probs) - - # Apply focusing parameter - focal_weight = (1 - pt) ** self.gamma - - # Apply alpha weighting: alpha for X-points, (1-alpha) for non-X-points - alpha_weight = torch.where(targets == 1, self.alpha, 1 - self.alpha) - - # Combine all factors - focal_loss = alpha_weight * focal_weight * bce_loss - - # Apply reduction - if self.reduction == 'mean': - return focal_loss.mean() - elif self.reduction == 'sum': - return focal_loss.sum() - else: - return focal_loss + return out +# SIMPLE DICE LOSS (from original file) class DiceLoss(nn.Module): def __init__(self, smooth=1.0, eps=1e-7): """ @@ -579,9 +455,6 @@ def forward(self, inputs, targets): """ # Apply sigmoid to get probabilities inputs = torch.sigmoid(inputs) - - # Ensure inputs are in valid range to prevent NaN - inputs = torch.clamp(inputs, min=self.eps, max=1.0 - self.eps) inputs = inputs.view(-1) targets = targets.view(-1) @@ -596,15 +469,66 @@ def forward(self, inputs, targets): # Return Dice loss (1 - Dice coefficient) return 1.0 - dice -# PLOTTING FUNCTION +# TRAIN & VALIDATION UTILS +def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype): + model.train() + running_loss = 0.0 + + for batch in loader: + all_data, mask = batch["all"].to(device), batch["mask"].to(device) + + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + pred = model(all_data) + loss = criterion(pred, mask) + + if not torch.isfinite(loss): + print(f"Warning: Non-finite loss detected (loss = {loss.item()}). 
Skipping batch.") + continue + + optimizer.zero_grad() + + if use_amp and scaler is not None: # float16 path + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + elif use_amp: # bfloat16 path (no scaler) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + else: # Standard float32 path + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + + running_loss += loss.item() + + return running_loss / len(loader) if len(loader) > 0 else 0.0 + +def validate_one_epoch(model, loader, criterion, device, use_amp, amp_dtype): + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for batch in loader: + all_data, mask = batch["all"].to(device), batch["mask"].to(device) + + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + pred = model(all_data) + loss = criterion(pred, mask) + + val_loss += loss.item() + return val_loss / len(loader) if len(loader) > 0 else 0.0 + +# PLOTTING FUNCTIONS def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, - reflectionAxis, filenameBase, interpFac, - xpoint_mask=None, + reflectionAxis, filenameBase, interpFac, + xpoint_mask=None, titleExtra="", - outDir="plots", + outDir="plots", saveFig=True): """ - Plots the vector potential 'psi_np' as contours, + Plots the vector potential 'psi_np' as contours, then overlays X-points from xpoint_mask (if provided, shape [Nx,Ny]). The figure is saved to outDir """ @@ -615,7 +539,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, cs = plt.contour( x / params["axesNorm"], y / params["axesNorm"], - np.transpose(psi_np), + np.transpose(psi_np), params["numContours"], colors=params["colorContours"], linewidths=0.75 @@ -652,10 +576,14 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, plt.close() -def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, - outDir="plots", saveFig=True): +def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, + outDir="plots", saveFig=True): """ - Visualize model performance comparing predictions with ground truth. 
+ Visualize model performance comparing predictions with ground truth: + - True Positives (green) + - False Positives (red) + - False Negatives (yellow) + - Background shows psi contours """ plt.figure(figsize=(12, 8)) @@ -665,7 +593,7 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi cs = plt.contour( x / params["axesNorm"], y / params["axesNorm"], - np.transpose(psi_np), + np.transpose(psi_np), params["numContours"], colors='k', linewidths=0.75 @@ -685,16 +613,25 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi fn_rows, fn_cols = np.where(fn_mask) if len(tp_rows) > 0: - plt.plot(x[tp_rows] / params["axesNorm"], y[tp_cols] / params["axesNorm"], - 'o', color='green', markersize=8, label="True Positives") + plt.plot( + x[tp_rows] / params["axesNorm"], + y[tp_cols] / params["axesNorm"], + 'o', color='green', markersize=8, label="True Positives" + ) if len(fp_rows) > 0: - plt.plot(x[fp_rows] / params["axesNorm"], y[fp_cols] / params["axesNorm"], - 'o', color='red', markersize=8, label="False Positives") + plt.plot( + x[fp_rows] / params["axesNorm"], + y[fp_cols] / params["axesNorm"], + 'o', color='red', markersize=8, label="False Positives" + ) if len(fn_rows) > 0: - plt.plot(x[fn_rows] / params["axesNorm"], y[fn_cols] / params["axesNorm"], - 'o', color='yellow', markersize=8, label="False Negatives") + plt.plot( + x[fn_rows] / params["axesNorm"], + y[fn_cols] / params["axesNorm"], + 'o', color='yellow', markersize=8, label="False Negatives" + ) plt.xlabel(r"$x/d_i$") plt.ylabel(r"$y/d_i$") @@ -727,16 +664,17 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'): """ Plots training and validation losses across epochs. 
+ + Parameters: + train_losses (list): List of training losses for each epoch + val_losses (list): List of validation losses for each epoch + save_path (str): Path to save the resulting plot """ plt.figure(figsize=(10, 6)) epochs = range(1, len(train_losses) + 1) - # Filter out NaN values for plotting - train_losses_clean = [loss if not np.isnan(loss) else None for loss in train_losses] - val_losses_clean = [loss if not np.isnan(loss) else None for loss in val_losses] - - plt.plot(epochs, train_losses_clean, 'b-', label='Training Loss') - plt.plot(epochs, val_losses_clean, 'r-', label='Validation Loss') + plt.plot(epochs, train_losses, 'b-', label='Training Loss') + plt.plot(epochs, val_losses, 'r-', label='Validation Loss') plt.title('Training and Validation Loss') plt.xlabel('Epochs') @@ -746,12 +684,9 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi plt.grid(True, linestyle='--', alpha=0.7) # Add some padding to y-axis to make visualization clearer - # Handle case where all values might be NaN - valid_losses = [loss for loss in train_losses + val_losses if loss is not None and not np.isnan(loss)] - if valid_losses: - ymin = min(valid_losses) * 0.9 - ymax = max(valid_losses) * 1.1 - plt.ylim(ymin, ymax) + ymin = min(min(train_losses), min(val_losses)) * 0.9 + ymax = max(max(train_losses), max(val_losses)) * 1.1 + plt.ylim(ymin, ymax) plt.savefig(save_path, dpi=300) print(f"Training history plot saved to {save_path}") @@ -759,67 +694,94 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi def parseCommandLineArgs(): parser = argparse.ArgumentParser(description='ML-based reconnection classifier') - parser.add_argument('--learningRate', type=float, default=1e-5, help='specify the learning rate') - parser.add_argument('--batchSize', type=int, default=1, help='specify the batch size') - parser.add_argument('--epochs', type=int, default=2000, help='specify the number of epochs') - parser.add_argument('--trainFrameFirst', type=int, default=1, help='specify the number of the first frame used for training') - parser.add_argument('--trainFrameLast', type=int, default=140, help='specify the number of the last frame (exclusive) used for training') - parser.add_argument('--validationFrameFirst', type=int, default=141, help='specify the number of the first frame used for validation') - parser.add_argument('--validationFrameLast', type=int, default=150, help='specify the number of the last frame (exclusive) used for validation') - parser.add_argument('--minTrainingLoss', type=int, default=3, help='''minimum reduction in training loss in orders of magnitude, set to 0 to disable the check''') - parser.add_argument('--checkPointFrequency', type=int, default=10, help='number of epochs between checkpoints') - parser.add_argument('--paramFile', type=Path, default=None, help='''specify the path to the parameter txt file, the parent directory of that file must contain the gkyl input training data''') - parser.add_argument('--xptCacheDir', type=Path, default=None, help='''specify the path to a directory that will be used to cache the outputs of the analytic Xpoint finder''') - parser.add_argument('--plot', action=argparse.BooleanOptionalAction, help='create figures of the ground truth X-points and model identified X-points') - parser.add_argument('--plotDir', type=Path, default="./plots", help='directory where figures are written') - parser.add_argument('--use-amp', action='store_true', help='use automatic mixed precision training') - 
parser.add_argument('--amp-dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type for mixed precision (float16 or bfloat16)') + parser.add_argument('--learningRate', type=float, default=1e-4, + help='specify the learning rate (default: 1e-4)') + parser.add_argument('--batchSize', type=int, default=8, + help='specify the batch size (default: 8)') + parser.add_argument('--epochs', type=int, default=100, + help='specify the number of epochs (default: 100)') + parser.add_argument('--trainFrameFirst', type=int, default=1, + help='specify the number of the first frame used for training') + parser.add_argument('--trainFrameLast', type=int, default=140, + help='specify the number of the last frame (exclusive) used for training') + parser.add_argument('--validationFrameFirst', type=int, default=141, + help='specify the number of the first frame used for validation') + parser.add_argument('--validationFrameLast', type=int, default=150, + help='specify the number of the last frame (exclusive) used for validation') + parser.add_argument('--minTrainingLoss', type=int, default=2, + help=''' + minimum reduction in training loss in orders of magnitude, + set to 0 to disable the check (default: 2) + ''') + parser.add_argument('--checkPointFrequency', type=int, default=10, + help='number of epochs between checkpoints') + parser.add_argument('--paramFile', type=Path, default=None, + help=''' + specify the path to the parameter txt file, the parent + directory of that file must contain the gkyl input training data + ''') + parser.add_argument('--xptCacheDir', type=Path, default=None, + help=''' + specify the path to a directory that will be used to cache + the outputs of the analytic Xpoint finder + ''') + parser.add_argument('--plot', action=argparse.BooleanOptionalAction, + help='create figures of the ground truth X-points and model identified X-points') + parser.add_argument('--plotDir', type=Path, default="./plots", + help='directory where figures are written') + parser.add_argument('--use-amp', action='store_true', + help='use automatic mixed precision training') + parser.add_argument('--amp-dtype', type=str, default='bfloat16', + choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)') + parser.add_argument('--patience', type=int, default=15, + help='patience for early stopping (default: 15)') args = parser.parse_args() return args def checkCommandLineArgs(args): - if args.xptCacheDir is not None: - if not args.xptCacheDir.is_dir(): - print(f"Xpoint cache directory {args.xptCacheDir} does not exist. Please create the directory... exiting") - sys.exit() - - if args.paramFile is None: - print(f"parameter file required but not set... exiting") - sys.exit() + if args.xptCacheDir != None: + if not args.xptCacheDir.is_dir(): + print(f"Xpoint cache directory {args.xptCacheDir} does not exist. " + "Please create the directory... exiting") + sys.exit() + + if args.paramFile == None: + print(f"parameter file required but not set... exiting") + sys.exit() if args.paramFile.is_dir(): - print(f"parameter file {args.paramFile} is a directory ... exiting") - sys.exit() + print(f"parameter file {args.paramFile} is a directory ... exiting") + sys.exit() if not args.paramFile.exists(): - print(f"parameter file {args.paramFile} does not exist... exiting") - sys.exit() + print(f"parameter file {args.paramFile} does not exist... 
exiting") + sys.exit() if args.trainFrameFirst == 0 or args.validationFrameFirst == 0: - print(f"frame 0 does not contain valid data... exiting") - sys.exit() + print(f"frame 0 does not contain valid data... exiting") + sys.exit() if args.trainFrameLast <= args.trainFrameFirst: - print(f"training frame range isn't valid... exiting") - sys.exit() + print(f"training frame range isn't valid... exiting") + sys.exit() if args.validationFrameLast <= args.validationFrameFirst: - print(f"validation frame range isn't valid... exiting") - sys.exit() + print(f"validation frame range isn't valid... exiting") + sys.exit() if args.learningRate <= 0: - print(f"learningRate must be > 0... exiting") - sys.exit() + print(f"learningRate must be > 0... exiting") + sys.exit() if args.batchSize < 1: - print(f"batchSize must be >= 1... exiting") - sys.exit() + print(f"batchSize must be >= 1... exiting") + sys.exit() if args.minTrainingLoss < 0: - print(f"minTrainingLoss must be >= 0... exiting") - sys.exit() + print(f"minTrainingLoss must be >= 0... exiting") + sys.exit() if args.checkPointFrequency < 0: - print(f"checkPointFrequency must be >= 0... exiting") - sys.exit() + print(f"checkPointFrequency must be >= 0... exiting") + sys.exit() def printCommandLineArgs(args): print("Config {") @@ -828,25 +790,39 @@ def printCommandLineArgs(args): print("}") # Function to save model checkpoint -def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints", scaler=None): +def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints", scaler=None, best_val_loss=None): """ Save model checkpoint including model state, optimizer state, and training metrics + + Parameters: + model: The neural network model + optimizer: The optimizer used for training + train_loss: List of training losses + val_loss: List of validation losses + epoch: Current epoch number + checkpoint_dir: Directory to save checkpoints + scaler: GradScaler instance if using AMP + best_val_loss: Best validation loss so far """ os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir, f"xpoint_model_epoch_{epoch}.pt") + # Create checkpoint dictionary checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_loss, - 'val_loss': val_loss + 'val_loss': val_loss, + 'best_val_loss': best_val_loss } + # Save scaler state if using AMP if scaler is not None: checkpoint['scaler_state_dict'] = scaler.state_dict() + # Save checkpoint torch.save(checkpoint, checkpoint_path) print(f"Model checkpoint saved at epoch {epoch} to {checkpoint_path}") @@ -866,13 +842,28 @@ def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpo def load_model_checkpoint(model, optimizer, checkpoint_path, scaler=None): """ Load model checkpoint + + Parameters: + model: The neural network model to load weights into + optimizer: The optimizer to load state into + checkpoint_path: Path to the checkpoint file + scaler: GradScaler instance if using AMP + + Returns: + model: Updated model with loaded weights + optimizer: Updated optimizer with loaded state + epoch: Last saved epoch number + train_loss: List of training losses + val_loss: List of validation losses + scaler: Updated scaler if using AMP + best_val_loss: Best validation loss from checkpoint """ if not os.path.exists(checkpoint_path): print(f"No checkpoint found at {checkpoint_path}") - return model, optimizer, 0, [], [], 
scaler + return model, optimizer, 0, [], [], scaler, float('inf') print(f"Loading checkpoint from {checkpoint_path}") - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, weights_only=False) # Need False for optimizer state model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) @@ -880,12 +871,15 @@ def load_model_checkpoint(model, optimizer, checkpoint_path, scaler=None): epoch = checkpoint['epoch'] train_loss = checkpoint['train_loss'] val_loss = checkpoint['val_loss'] + best_val_loss = checkpoint.get('best_val_loss', float('inf')) + # Load scaler state if available if scaler is not None and 'scaler_state_dict' in checkpoint: scaler.load_state_dict(checkpoint['scaler_state_dict']) print(f"Loaded checkpoint from epoch {epoch}") - return model, optimizer, epoch, train_loss, val_loss, scaler + return model, optimizer, epoch, train_loss, val_loss, scaler, best_val_loss + def main(): args = parseCommandLineArgs() @@ -900,66 +894,60 @@ def main(): train_fnums = range(args.trainFrameFirst, args.trainFrameLast) val_fnums = range(args.validationFrameFirst, args.validationFrameLast) + print(f"Loading training data from frames {args.trainFrameFirst} to {args.trainFrameLast-1}") + print(f"Loading validation data from frames {args.validationFrameFirst} to {args.validationFrameLast-1}") + train_dataset = XPointDataset(args.paramFile, train_fnums, xptCacheDir=args.xptCacheDir, rotateAndReflect=True) val_dataset = XPointDataset(args.paramFile, val_fnums, xptCacheDir=args.xptCacheDir) - train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.6, retries=20) - val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.6, retries=20) + # Use consistent pos_ratio for both training and validation + train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.5, retries=30) + val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.5, retries=30) t1 = timer() print("time (s) to create gkyl data loader: " + str(t1-t0)) print(f"number of training frames (original + augmented): {len(train_dataset)}") print(f"number of validation frames: {len(val_dataset)}") + print(f"number of training patches per epoch: {len(train_crop)}") + print(f"number of validation patches per epoch: {len(val_crop)}") - train_loader = DataLoader(train_crop, batch_size=args.batchSize, shuffle=False) - val_loader = DataLoader(val_crop, batch_size=args.batchSize, shuffle=False) + train_loader = DataLoader(train_crop, batch_size=args.batchSize, shuffle=True, num_workers=0) + val_loader = DataLoader(val_crop, batch_size=args.batchSize, shuffle=False, num_workers=0) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = UNet(input_channels=4, base_channels=64).to(device) + print(f"Using device: {device}") + + # Use the improved model + model = ImprovedUNet(input_channels=4, base_channels=32).to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_params:,}") + # USE SIMPLE DICE LOSS (from original) criterion = DiceLoss(smooth=1.0) - # Reduce learning rate for mixed precision training - effective_lr = args.learningRate - if args.use_amp: - # Less aggressive reduction for AMP - effective_lr = args.learningRate * 0.5 # Reduce by 2x instead of 10x - print(f"Adjusting learning rate for AMP: 
{args.learningRate} -> {effective_lr}") + # Use AdamW optimizer with weight decay for better generalization + optimizer = optim.AdamW(model.parameters(), lr=args.learningRate, weight_decay=1e-5) - optimizer = optim.Adam(model.parameters(), lr=effective_lr, eps=1e-4) # Higher epsilon for stability + # Learning rate scheduler with cosine annealing + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) - # Add learning rate scheduler - scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True) - - # Initialize GradScaler for mixed precision if enabled - scaler = None - amp_dtype = torch.float16 # default + # --- AMP Setup (bfloat16 aware) --- + use_amp = args.use_amp and torch.cuda.is_available() + amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' and torch.cuda.is_bf16_supported() else torch.float16 - if args.use_amp and torch.cuda.is_available(): - if args.amp_dtype == 'bfloat16': - if torch.cuda.is_bf16_supported(): - amp_dtype = torch.bfloat16 - print("Using bfloat16 for mixed precision (no GradScaler needed)") - else: - print("Warning: bfloat16 not supported on this GPU, falling back to float16") - amp_dtype = torch.float16 - - # Only use GradScaler with float16 - if amp_dtype == torch.float16: - # Initialize with very conservative settings for stability - scaler = GradScaler( - device='cuda', - init_scale=2.**4, # Much smaller initial scale (16 instead of 256) - growth_factor=1.5, # Slower growth - backoff_factor=0.5, - growth_interval=200, # Wait longer before increasing scale - enabled=True - ) - print("Initialized GradScaler with very conservative settings for stability") - - print(f"Using Automatic Mixed Precision with {amp_dtype}") + # GradScaler is ONLY needed for float16, not bfloat16 + scaler = GradScaler(enabled=(use_amp and amp_dtype == torch.float16)) + + if use_amp: + if args.amp_dtype == 'bfloat16' and not torch.cuda.is_bf16_supported(): + print("Warning: bfloat16 not supported on this GPU. 
Falling back to float16.") + print(f"Using Automatic Mixed Precision (AMP) with dtype: {amp_dtype}") checkpoint_dir = "checkpoints" os.makedirs(checkpoint_dir, exist_ok=True) @@ -967,46 +955,71 @@ def main(): start_epoch = 0 train_loss = [] val_loss = [] + best_val_loss = float('inf') if os.path.exists(latest_checkpoint_path): - model, optimizer, start_epoch, train_loss, val_loss, scaler = load_model_checkpoint( + model, optimizer, start_epoch, train_loss, val_loss, scaler, best_val_loss = load_model_checkpoint( model, optimizer, latest_checkpoint_path, scaler ) print(f"Resuming training from epoch {start_epoch+1}") + print(f"Best validation loss so far: {best_val_loss:.6f}") else: print("Starting training from scratch") t2 = timer() print("time (s) to prepare model: " + str(t2-t1)) - if args.use_amp: - print(f"Using Automatic Mixed Precision (AMP) training with {amp_dtype}") + # Early stopping setup + patience_counter = 0 + num_epochs = args.epochs for epoch in range(start_epoch, num_epochs): - train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, args.use_amp, amp_dtype)) - val_loss.append(validate_one_epoch(model, val_loader, criterion, device, args.use_amp, amp_dtype)) + train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype) + val_loss_epoch = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype) + + train_loss.append(train_loss_epoch) + val_loss.append(val_loss_epoch) + + current_lr = optimizer.param_groups[0]['lr'] + print(f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}") - # Update learning rate based on validation loss - if not np.isnan(val_loss[-1]): - scheduler.step(val_loss[-1]) + # Learning rate scheduling + scheduler.step() - print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}") + # Check for improvement + if val_loss[-1] < best_val_loss: + best_val_loss = val_loss[-1] + patience_counter = 0 + print(f" New best validation loss: {best_val_loss:.6f}") + # Save best model + torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pt")) + else: + patience_counter += 1 + + # Save checkpoint periodically + if (epoch+1) % args.checkPointFrequency == 0: + save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir, scaler, best_val_loss) - # Save model checkpoint after each epoch - if epoch % args.checkPointFrequency == 0: - save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir, scaler) + # Early stopping + if patience_counter >= args.patience: + print(f"Early stopping triggered after {epoch+1} epochs (patience={args.patience})") + break plot_training_history(train_loss, val_loss) print("time (s) to train model: " + str(timer()-t2)) - requiredLossDecreaseMagnitude = args.minTrainingLoss - if len(train_loss) > 0 and train_loss[-1] > 0 and not np.isnan(train_loss[0]) and not np.isnan(train_loss[-1]): - if np.log10(abs(train_loss[0]/train_loss[-1])) < requiredLossDecreaseMagnitude: - print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: " - f"initial {train_loss[0]} final {train_loss[-1]} ... 
exiting") - return 1 - else: - print("Warning: Unable to check training loss reduction due to NaN or zero values") + # Check training progress + if len(train_loss) > 1 and train_loss[-1] > 0 and train_loss[0] > 0: + loss_reduction = np.log10(abs(train_loss[0]/train_loss[-1])) + print(f"Training loss reduced by {loss_reduction:.2f} orders of magnitude") + if args.minTrainingLoss > 0 and loss_reduction < args.minTrainingLoss: + print(f"Warning: TrainLoss reduced by less than {args.minTrainingLoss} orders of magnitude") + + # Load best model for evaluation + best_model_path = os.path.join(checkpoint_dir, "best_model.pt") + if os.path.exists(best_model_path): + print("Loading best model for evaluation...") + model.load_state_dict(torch.load(best_model_path, weights_only=True)) # (D) Plotting after training model.eval() # switch to inference mode @@ -1014,74 +1027,64 @@ def main(): interpFac = 1 # Evaluate on combined set for demonstration - full_dataset = [train_dataset, val_dataset] + full_dataset = train_dataset.data[:6] + val_dataset.data # Use only first 6 augmented training samples for plotting t4 = timer() with torch.no_grad(): - for dataset in full_dataset: - for item in dataset: - fnum = item["fnum"] - rotation = item["rotation"] - reflectionAxis = item["reflectionAxis"] - psi_np = item["psi"].numpy()[0] - mask_gt = item["mask"].numpy()[0] - x = item["x"] - y = item["y"] - filenameBase = item["filenameBase"] - params = item["params"] - - # Get CNN prediction - all_torch = item["all"].unsqueeze(0).to(device) - - if args.use_amp: - with autocast(device_type='cuda', dtype=amp_dtype): - pred_mask = model(all_torch) - else: - pred_mask = model(all_torch) - - pred_mask_np = pred_mask[0,0].cpu().numpy() - # Binarize - pred_bin = (pred_mask_np > 0.5).astype(np.float32) - + for item in full_dataset: + # item is a dict with keys: fnum, psi, mask, psi_np, mask_np, x, y, tmp, params + fnum = item["fnum"] + rotation = item["rotation"] + reflectionAxis = item["reflectionAxis"] + psi_np = np.array(item["psi"])[0] + mask_gt = np.array(item["mask"])[0] + x = item["x"] + y = item["y"] + filenameBase = item["filenameBase"] + params = item["params"] + + # Get CNN prediction + all_torch = item["all"].unsqueeze(0).to(device) + + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + pred_mask = model(all_torch) pred_prob = torch.sigmoid(pred_mask) - pred_prob_np = (pred_prob > 0.5).float().cpu().numpy() - - pred_mask_bin = (pred_prob_np > 0.5).astype(np.float32) - - print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") - print(f"psi shape: {psi_np.shape}, min: {psi_np.min()}, max: {psi_np.max()}") - print(f"pred_bin shape: {pred_bin.shape}, min: {pred_bin.min()}, max: {pred_bin.max()}") - print(f" Logits - min: {pred_mask_np.min():.5f}, max: {pred_mask_np.max():.5f}, mean: {pred_mask_np.mean():.5f}") - print(f" Probabilities (after sigmoid) - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") - print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") - print(f" Binary Mask (X_points) - shape: {pred_mask_bin.shape}, min: {pred_mask_bin.min()}, max: {pred_mask_bin.max()}") - - if args.plot: - # Plot GROUND TRUTH - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, - xpoint_mask=mask_gt, - titleExtra="GTXpoints", - outDir=outDir, - saveFig=True - ) - - # Plot CNN PREDICTIONS - plot_psi_contours_and_xpoints( - psi_np, x, y, 
params, fnum, rotation, reflectionAxis, filenameBase, interpFac, - xpoint_mask=np.squeeze(pred_mask_bin), - titleExtra="CNNXpoints", - outDir=outDir, - saveFig=True - ) - - pred_prob_np_full = pred_prob.cpu().numpy() - plot_model_performance( - psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase, - outDir=outDir, - saveFig=True - ) + + # Convert to float32 before numpy conversion (fixes BFloat16 error) + pred_mask_np = pred_mask[0,0].float().cpu().numpy() + pred_prob_np = pred_prob.float().cpu().numpy() + + pred_mask_bin = (pred_prob_np[0,0] > 0.5).astype(np.float32) + + print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") + print(f" Probabilities - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") + print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") + + if args.plot: + # Plot GROUND TRUTH + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=mask_gt, + titleExtra="GTXpoints", + outDir=outDir, + saveFig=True + ) + + # Plot CNN PREDICTIONS + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=pred_mask_bin, + titleExtra="CNNXpoints", + outDir=outDir, + saveFig=True + ) + + plot_model_performance( + psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, + outDir=outDir, + saveFig=True + ) t5 = timer() print("time (s) to apply model: " + str(t5-t4)) From 633e11104ce89765ddcbab2adbf3808bfeaaf4b8 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Fri, 22 Aug 2025 13:09:58 -0400 Subject: [PATCH 3/7] Made changes to make implementations merged from main work --- XPointMLTest.py | 124 ++++++++++++++++++++++------------------------ ci_tests.py | 34 ++++++++++--- test_xpoint_ml.py | 22 ++++---- 3 files changed, 99 insertions(+), 81 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 0922ddd..87558c6 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -137,32 +137,32 @@ def getPgkylData(paramFile, frameNumber, verbosity): def cachedPgkylDataExists(cacheDir, frameNumber, fieldName): if cacheDir == None: - return False + return False else: - cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" - return cachedFrame.exists(); + cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy" + return cachedFrame.exists(); def loadPgkylDataFromCache(cacheDir, frameNumber, fields): outFields = {} if cacheDir != None: - for name in fields.keys(): + for name in fields.keys(): if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: - outFields[name] = file.read().rstrip() + with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file: + outFields[name] = file.read().rstrip() else: - outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") - return outFields + outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy") + return outFields else: - return None + return None def writePgkylDataToCache(cacheDir, frameNumber, fields): if cacheDir != None: - for name, field in fields.items(): + for name, field in fields.items(): if name == "fileName": - with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: - text_file.write(f"{field}") + with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file: + text_file.write(f"{field}") else: - np.save(cacheDir / f"{frameNumber}_{name}.npy",field) + np.save(cacheDir / 
f"{frameNumber}_{name}.npy",field) # DATASET DEFINITION class XPointDataset(Dataset): @@ -174,7 +174,7 @@ class XPointDataset(Dataset): - Returns (psiTensor, maskTensor) as a PyTorch (float) pair. """ def __init__(self, paramFile, fnumList, xptCacheDir=None, - rotateAndReflect=False, verbosity=0): + rotateAndReflect=False, verbosity=0): """ paramFile: Path to parameter file (string). fnumList: List of frames to iterate. @@ -207,11 +207,11 @@ def __init__(self, paramFile, fnumList, xptCacheDir=None, frameData = self.load(fnum) self.data.append(frameData) if rotateAndReflect: - self.data.append(rotate(frameData,90)) - self.data.append(rotate(frameData,180)) - self.data.append(rotate(frameData,270)) - self.data.append(reflect(frameData,0)) - self.data.append(reflect(frameData,1)) + self.data.append(rotate(frameData,90)) + self.data.append(rotate(frameData,180)) + self.data.append(rotate(frameData,270)) + self.data.append(reflect(frameData,0)) + self.data.append(reflect(frameData,1)) def __len__(self): return len(self.data) @@ -225,7 +225,7 @@ def load(self, fnum): # check if cache exists if self.xptCacheDir != None: if not self.xptCacheDir.is_dir(): - print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting") + print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting") sys.exit() t2 = timer() @@ -284,8 +284,8 @@ def load(self, fnum): "rotation": 0, "reflectionAxis": -1, # no reflection "psi": torch.from_numpy(fields["psi"]).float().unsqueeze(0), # Keep original for plotting - "all": all_torch, # Normalized for training - "mask": mask_torch, # shape [1, Nx, Ny] + "all": all_torch, # Normalized for training + "mask": mask_torch, # shape [1, Nx, Ny] "x": fields["coords"][0], "y": fields["coords"][1], "filenameBase": fields["fileName"], @@ -367,7 +367,7 @@ def forward(self, x): return out -class ImprovedUNet(nn.Module): +class UNet(nn.Module): """ Improved U-Net with residual blocks and better normalization """ @@ -529,7 +529,7 @@ def validate_one_epoch(model, loader, criterion, device, use_amp, amp_dtype): # PLOTTING FUNCTIONS def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, - reflectionAxis, filenameBase, interpFac, + reflectionAxis, filenameBase, interpFac, xpoint_mask=None, titleExtra="", outDir="plots", @@ -584,7 +584,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, plt.close() def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, - outDir="plots", saveFig=True): + outDir="plots", saveFig=True): """ Visualize model performance comparing predictions with ground truth: - True Positives (green) @@ -702,50 +702,50 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi def parseCommandLineArgs(): parser = argparse.ArgumentParser(description='ML-based reconnection classifier') parser.add_argument('--learningRate', type=float, default=1e-4, - help='specify the learning rate (default: 1e-4)') + help='specify the learning rate (default: 1e-4)') parser.add_argument('--batchSize', type=int, default=8, - help='specify the batch size (default: 8)') + help='specify the batch size (default: 8)') parser.add_argument('--epochs', type=int, default=100, - help='specify the number of epochs (default: 100)') + help='specify the number of epochs (default: 100)') parser.add_argument('--trainFrameFirst', type=int, default=1, - help='specify the number of the first frame used for training') + help='specify the number of the first frame used for training') 
parser.add_argument('--trainFrameLast', type=int, default=140, - help='specify the number of the last frame (exclusive) used for training') + help='specify the number of the last frame (exclusive) used for training') parser.add_argument('--validationFrameFirst', type=int, default=141, - help='specify the number of the first frame used for validation') + help='specify the number of the first frame used for validation') parser.add_argument('--validationFrameLast', type=int, default=150, - help='specify the number of the last frame (exclusive) used for validation') + help='specify the number of the last frame (exclusive) used for validation') parser.add_argument('--minTrainingLoss', type=int, default=2, - help=''' - minimum reduction in training loss in orders of magnitude, - set to 0 to disable the check (default: 2) - ''') + help=''' + minimum reduction in training loss in orders of magnitude, + set to 0 to disable the check (default: 2) + ''') parser.add_argument('--checkPointFrequency', type=int, default=10, - help='number of epochs between checkpoints') + help='number of epochs between checkpoints') parser.add_argument('--paramFile', type=Path, default=None, - help=''' - specify the path to the parameter txt file, the parent - directory of that file must contain the gkyl input training data - ''') + help=''' + specify the path to the parameter txt file, the parent + directory of that file must contain the gkyl input training data + ''') parser.add_argument('--xptCacheDir', type=Path, default=None, - help=''' - specify the path to a directory that will be used to cache - the outputs of the analytic Xpoint finder - ''') + help=''' + specify the path to a directory that will be used to cache + the outputs of the analytic Xpoint finder + ''') parser.add_argument('--plot', action=argparse.BooleanOptionalAction, - help='create figures of the ground truth X-points and model identified X-points') + help='create figures of the ground truth X-points and model identified X-points') parser.add_argument('--plotDir', type=Path, default="./plots", - help='directory where figures are written') + help='directory where figures are written') parser.add_argument('--use-amp', action='store_true', - help='use automatic mixed precision training') + help='use automatic mixed precision training') parser.add_argument('--amp-dtype', type=str, default='bfloat16', - choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)') + choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)') parser.add_argument('--patience', type=int, default=15, - help='patience for early stopping (default: 15)') + help='patience for early stopping (default: 15)') # CI TEST: Add smoke test flag parser.add_argument('--smoke-test', action='store_true', - help='Run a minimal smoke test for CI (overrides other parameters)') + help='Run a minimal smoke test for CI (overrides other parameters)') args = parser.parse_args() return args @@ -944,22 +944,18 @@ def main(): # Original data loading train_fnums = range(args.trainFrameFirst, args.trainFrameLast) val_fnums = range(args.validationFrameFirst, args.validationFrameLast) - - print(f"Loading training data from frames {args.trainFrameFirst} to {args.trainFrameLast-1}") - print(f"Loading validation data from frames {args.validationFrameFirst} to {args.validationFrameLast-1}") - - train_dataset = XPointDataset(args.paramFile, train_fnums, + + print(f"Loading training data from frames {args.trainFrameFirst} to 
{args.trainFrameLast-1}") + print(f"Loading validation data from frames {args.validationFrameFirst} to {args.validationFrameLast-1}") + + train_dataset = XPointDataset(args.paramFile, train_fnums, xptCacheDir=args.xptCacheDir, rotateAndReflect=True) - val_dataset = XPointDataset(args.paramFile, val_fnums, + val_dataset = XPointDataset(args.paramFile, val_fnums, xptCacheDir=args.xptCacheDir) # Use consistent pos_ratio for both training and validation train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.5, retries=30) val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.5, retries=30) - train_dataset = XPointDataset(args.paramFile, train_fnums, - xptCacheDir=args.xptCacheDir, rotateAndReflect=True) - val_dataset = XPointDataset(args.paramFile, val_fnums, - xptCacheDir=args.xptCacheDir) t1 = timer() print("time (s) to create gkyl data loader: " + str(t1-t0)) @@ -975,7 +971,7 @@ def main(): print(f"Using device: {device}") # Use the improved model - model = ImprovedUNet(input_channels=4, base_channels=32).to(device) + model = UNet(input_channels=4, base_channels=32).to(device) # Count parameters total_params = sum(p.numel() for p in model.parameters()) @@ -1001,7 +997,7 @@ def main(): if use_amp: if args.amp_dtype == 'bfloat16' and not torch.cuda.is_bf16_supported(): - print("Warning: bfloat16 not supported on this GPU. Falling back to float16.") + print("Warning: bfloat16 not supported on this GPU. Falling back to float16.") print(f"Using Automatic Mixed Precision (AMP) with dtype: {amp_dtype}") checkpoint_dir = "checkpoints" @@ -1045,7 +1041,7 @@ def main(): if val_loss[-1] < best_val_loss: best_val_loss = val_loss[-1] patience_counter = 0 - print(f" New best validation loss: {best_val_loss:.6f}") + print(f" New best validation loss: {best_val_loss:.6f}") # Save best model torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pt")) else: diff --git a/ci_tests.py b/ci_tests.py index eb31555..96d1012 100644 --- a/ci_tests.py +++ b/ci_tests.py @@ -3,6 +3,10 @@ from torch.utils.data import Dataset, DataLoader import torch.optim as optim import os +import sys +# Local import within the function itself, which is a bit clunky +# but the original code did it this way. 
+# from XPointMLTest import validate_one_epoch class SyntheticXPointDataset(Dataset): """ @@ -50,7 +54,7 @@ def _generate_frame(self, idx): #create synthetic current (Laplacian of psi) jz = -(np.gradient(np.gradient(psi, axis=0), axis=0) + - np.gradient(np.gradient(psi, axis=1), axis=1)) + np.gradient(np.gradient(psi, axis=1), axis=1)) # create X-point mask mask = np.zeros((H, W), dtype=np.float32) @@ -97,18 +101,22 @@ def test_checkpoint_functionality(model, optimizer, device, val_loader, criterio """ # Import locally to prevent circular dependency - from XPointMLTest import validate_one_epoch + from XPointMLTest import validate_one_epoch, autocast print("\n" + "="*60) print("TESTING CHECKPOINT SAVE/LOAD FUNCTIONALITY") print("="*60) - #get initial validation loss + # Get the AMP settings from the model's current state to pass to validate_one_epoch + use_amp = isinstance(scaler, torch.cuda.amp.GradScaler) + amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + + # Get initial validation loss model.eval() - initial_loss = validate_one_epoch(model, val_loader, criterion, device) + initial_loss = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype) print(f"Initial validation loss: {initial_loss:.6f}") - #saves checkpoint + # Save checkpoint with the correct AMP components test_checkpoint_path = "smoke_test_checkpoint.pt" checkpoint = { 'model_state_dict': model.state_dict(), @@ -117,24 +125,36 @@ def test_checkpoint_functionality(model, optimizer, device, val_loader, criterio 'test_value': 42 } + if scaler is not None: + checkpoint['scaler_state_dict'] = scaler.state_dict() + torch.save(checkpoint, test_checkpoint_path) print(f"Saved checkpoint to {test_checkpoint_path}") # create new model and optimizer - model2 = UNet(input_channels=4, base_channels=64).to(device) + # NOTE: The base_channels here should match the original model's base_channels (32). + # You had 64, which would cause an error later. Changed to 32. + model2 = UNet(input_channels=4, base_channels=32).to(device) optimizer2 = Adam(model2.parameters(), lr=1e-5) # load checkpoint loaded_checkpoint = torch.load(test_checkpoint_path, map_location=device, weights_only=False) model2.load_state_dict(loaded_checkpoint['model_state_dict']) optimizer2.load_state_dict(loaded_checkpoint['optimizer_state_dict']) + + # Load the scaler state if present + scaler2 = None + if 'scaler_state_dict' in loaded_checkpoint: + scaler2 = torch.cuda.amp.GradScaler() + scaler2.load_state_dict(loaded_checkpoint['scaler_state_dict']) assert loaded_checkpoint['test_value'] == 42, "Test value mismatch!" 
print("Checkpoint test value verified") #get loaded model validation loss model2.eval() - loaded_loss = validate_one_epoch(model2, val_loader, criterion, device) + # Now pass the AMP arguments to validate_one_epoch + loaded_loss = validate_one_epoch(model2, val_loader, criterion, device, use_amp, amp_dtype) print(f"Loaded model validation loss: {loaded_loss:.6f}") # check if losses match diff --git a/test_xpoint_ml.py b/test_xpoint_ml.py index 842a48d..5aba775 100644 --- a/test_xpoint_ml.py +++ b/test_xpoint_ml.py @@ -5,6 +5,7 @@ import os import pytest +# Make sure all required functions are imported from the main file from XPointMLTest import UNet, DiceLoss, expand_xpoints_mask, validate_one_epoch from ci_tests import SyntheticXPointDataset @@ -54,7 +55,7 @@ def test_dice_loss_no_match(dice_loss): def test_synthetic_dataset_integrity(synthetic_dataset): assert len(synthetic_dataset) == 2 item = synthetic_dataset[0] - expected_keys = ["fnum", "all", "mask", "psi", "x", "y"] + expected_keys = ["fnum", "all", "mask", "psi", "x", "y", "rotation", "reflectionAxis", "filenameBase", "params"] assert all(key in item for key in expected_keys) assert item['all'].shape == (4, 32, 32) assert item['mask'].shape == (1, 32, 32) @@ -88,13 +89,14 @@ def test_checkpoint_save_load(unet_model, synthetic_dataset): optimizer = optim.Adam(model.parameters(), lr=1e-5) criterion = DiceLoss() - #create a simple dataloader + # Create a simple dataloader val_loader = DataLoader(synthetic_dataset, batch_size=1, shuffle=False) - #get initial loss - initial_loss = validate_one_epoch(model, val_loader, criterion, device) + # get initial loss, passing the required AMP arguments + # we can assume no AMP for this CPU-based unit test + initial_loss = validate_one_epoch(model, val_loader, criterion, device, use_amp=False, amp_dtype=torch.float32) - #save checkpoint + # Save checkpoint test_checkpoint_path = "test_checkpoint_pytest.pt" checkpoint = { 'model_state_dict': model.state_dict(), @@ -104,7 +106,7 @@ def test_checkpoint_save_load(unet_model, synthetic_dataset): } torch.save(checkpoint, test_checkpoint_path) - #create new model and load + # Create new model and load model2 = UNet(input_channels=4, base_channels=16).to(device) optimizer2 = optim.Adam(model2.parameters(), lr=1e-5) @@ -114,14 +116,14 @@ def test_checkpoint_save_load(unet_model, synthetic_dataset): assert loaded_checkpoint['test_value'] == 42 - #get loaded model loss - loaded_loss = validate_one_epoch(model2, val_loader, criterion, device) + # Get loaded model loss, again passing the AMP arguments + loaded_loss = validate_one_epoch(model2, val_loader, criterion, device, use_amp=False, amp_dtype=torch.float32) - #check if losses match + # Check if losses match loss_diff = abs(initial_loss - loaded_loss) assert loss_diff < 1e-6, f"Loss difference too large: {loss_diff}" - #cleanup + # Cleanup if os.path.exists(test_checkpoint_path): os.remove(test_checkpoint_path) From 0054ff2c95dded50b8f1eb183e3ef2a76be9f623 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Fri, 22 Aug 2025 13:26:54 -0400 Subject: [PATCH 4/7] made changes to make implementations from main work --- XPointMLTest.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 87558c6..3294cb8 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -441,7 +441,7 @@ def forward(self, x): out = self.out_conv(d1) return out -# SIMPLE DICE LOSS (from original file) +# DICE LOSS class DiceLoss(nn.Module): def __init__(self, smooth=1.0, eps=1e-7): """ 
@@ -979,7 +979,6 @@ def main(): print(f"Total parameters: {total_params:,}") print(f"Trainable parameters: {trainable_params:,}") - # USE SIMPLE DICE LOSS (from original) criterion = DiceLoss(smooth=1.0) # Use AdamW optimizer with weight decay for better generalization From e8767ee5fd6ea593d6bfd812896b5c8b20944906 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Thu, 28 Aug 2025 15:37:51 -0400 Subject: [PATCH 5/7] Fixed bug and updated README for added flags. --- README.md | 54 +++++++++++++++++++++++- XPointMLTest.py | 106 ++++++++++++++++++++++++------------------------ 2 files changed, 105 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index aabde47..55907d3 100644 --- a/README.md +++ b/README.md @@ -80,12 +80,62 @@ run it ./runReconClass.sh ``` +## Command Line Options + +The classifier supports several command line options for training configuration: + +### Training Parameters +- `--learningRate`: Learning rate for training (default: 1e-4) +- `--batchSize`: Batch size for training (default: 8) +- `--epochs`: Number of training epochs (default: 100) +- `--minTrainingLoss`: Minimum reduction in training loss in orders of magnitude (default: 2, set to 0 to disable check) + +### Data Configuration +- `--trainFrameFirst`: First frame number for training data (default: 1) +- `--trainFrameLast`: Last frame number (exclusive) for training data (default: 140) +- `--validationFrameFirst`: First frame number for validation data (default: 141) +- `--validationFrameLast`: Last frame number (exclusive) for validation data (default: 150) +- `--paramFile`: Path to the parameter txt file containing gkyl input data +- `--xptCacheDir`: Path to directory for caching X-point finder outputs + +### Training Optimization +- `--use-amp`: Enable automatic mixed precision training for faster training on modern GPUs +- `--amp-dtype`: Data type for mixed precision (`float16` or `bfloat16`, default: `bfloat16`) +- `--patience`: Patience for early stopping (default: 15 epochs) + +### Output and Monitoring +- `--plot`: Enable creation of figures showing ground truth and model-identified X-points +- `--plotDir`: Directory where figures are written (default: `./plots`) +- `--checkPointFrequency`: Number of epochs between model checkpoints (default: 10) + +### Testing +- `--smoke-test`: Run minimal smoke test for CI (overrides other parameters for quick validation) + +### Example with Advanced Options + +For faster training with mixed precision and early stopping: + +```bash +python -u ${rcRoot}/reconClassifier/XPointMLTest.py \ +--paramFile=/path/to/params.txt \ +--xptCacheDir=/path/to/cache \ +--epochs 200 \ +--learningRate 1e-4 \ +--batchSize 16 \ +--use-amp \ +--amp-dtype bfloat16 \ +--patience 20 \ +--plot \ +--trainFrameLast 100 \ +--validationFrameLast 120 +``` + ## Resuming Development Work -The following commands should be run on `checkers` **every time you create a new shell** to resume work in the existing virtual environment. +The following commands should be run on `checkers` **every time you create a new shell** to resume work in the existing virtual environment. 
``` cd nsfCssiMlClassifier source envPyTorch.sh source pgkyl/bin/activate -``` +``` \ No newline at end of file diff --git a/XPointMLTest.py b/XPointMLTest.py index 3294cb8..296a864 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -1146,59 +1146,59 @@ def main(): t4 = timer() with torch.no_grad(): - for item in full_dataset: - # item is a dict with keys: fnum, psi, mask, psi_np, mask_np, x, y, tmp, params - fnum = item["fnum"] - rotation = item["rotation"] - reflectionAxis = item["reflectionAxis"] - psi_np = np.array(item["psi"])[0] - mask_gt = np.array(item["mask"])[0] - x = item["x"] - y = item["y"] - filenameBase = item["filenameBase"] - params = item["params"] - - # Get CNN prediction - all_torch = item["all"].unsqueeze(0).to(device) - - with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): - pred_mask = model(all_torch) - pred_prob = torch.sigmoid(pred_mask) - - # Convert to float32 before numpy conversion (fixes BFloat16 error) - pred_mask_np = pred_mask[0,0].float().cpu().numpy() - pred_prob_np = pred_prob.float().cpu().numpy() - - pred_mask_bin = (pred_prob_np[0,0] > 0.5).astype(np.float32) - - print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") - print(f" Probabilities - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") - print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") - - if args.plot: - # Plot GROUND TRUTH - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, - xpoint_mask=mask_gt, - titleExtra="GTXpoints", - outDir=outDir, - saveFig=True - ) - - # Plot CNN PREDICTIONS - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, - xpoint_mask=pred_mask_bin, - titleExtra="CNNXpoints", - outDir=outDir, - saveFig=True - ) - - plot_model_performance( - psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, - outDir=outDir, - saveFig=True - ) + for dataset in full_dataset: + for item in dataset: + fnum = item["fnum"] + rotation = item["rotation"] + reflectionAxis = item["reflectionAxis"] + psi_np = np.array(item["psi"])[0] + mask_gt = np.array(item["mask"])[0] + x = item["x"] + y = item["y"] + filenameBase = item["filenameBase"] + params = item["params"] + + # Get CNN prediction + all_torch = item["all"].unsqueeze(0).to(device) + + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + pred_mask = model(all_torch) + pred_prob = torch.sigmoid(pred_mask) + + # Convert to float32 before numpy conversion (fixes BFloat16 error) + pred_mask_np = pred_mask[0,0].float().cpu().numpy() + pred_prob_np = pred_prob.float().cpu().numpy() + + pred_mask_bin = (pred_prob_np[0,0] > 0.5).astype(np.float32) + + print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") + print(f" Probabilities - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") + print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") + + if args.plot: + # Plot GROUND TRUTH + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=mask_gt, + titleExtra="GTXpoints", + outDir=outDir, + saveFig=True + ) + + # Plot CNN PREDICTIONS + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=pred_mask_bin, + 
titleExtra="CNNXpoints", + outDir=outDir, + saveFig=True + ) + + plot_model_performance( + psi_np, pred_prob_np, mask_gt, x, y, params, fnum, filenameBase, + outDir=outDir, + saveFig=True + ) t5 = timer() print("time (s) to apply model: " + str(t5-t4)) From 265780badca6babee67394eb822d4507cf8b1d9f Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Thu, 28 Aug 2025 21:52:09 -0400 Subject: [PATCH 6/7] fixed bug - psi return value --- XPointMLTest.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 296a864..536e5e3 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -283,7 +283,7 @@ def load(self, fnum): "fnum": fnum, "rotation": 0, "reflectionAxis": -1, # no reflection - "psi": torch.from_numpy(fields["psi"]).float().unsqueeze(0), # Keep original for plotting + "psi": psi_torch, # shape [1, Nx, Ny] "all": all_torch, # Normalized for training "mask": mask_torch, # shape [1, Nx, Ny] "x": fields["coords"][0], @@ -333,7 +333,11 @@ def __getitem__(self, _): "mask": crop_mask } -# 2) IMPROVED U-NET ARCHITECTURE WITH RESIDUAL CONNECTIONS + +# Improved the U-Net architecture with residual connections +# Links to understand the residual blocks: +# https://code.likeagirl.io/u-net-vs-residual-u-net-vs-attention-u-net-vs-attention-residual-u-net-899b57c5698 +# https://notes.kvfrans.com/3-building-blocks/residual-networks.html class ResidualBlock(nn.Module): """Residual block with two convolutions and skip connection""" def __init__(self, in_channels, out_channels): From f99434fccd8c4e351d642831a14e2e596bf13cf7 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Thu, 28 Aug 2025 22:44:24 -0400 Subject: [PATCH 7/7] revert to default params --- XPointMLTest.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 536e5e3..8404d5c 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -705,12 +705,12 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi def parseCommandLineArgs(): parser = argparse.ArgumentParser(description='ML-based reconnection classifier') - parser.add_argument('--learningRate', type=float, default=1e-4, - help='specify the learning rate (default: 1e-4)') - parser.add_argument('--batchSize', type=int, default=8, - help='specify the batch size (default: 8)') - parser.add_argument('--epochs', type=int, default=100, - help='specify the number of epochs (default: 100)') + parser.add_argument('--learningRate', type=float, default=1e-5, + help='specify the learning rate') + parser.add_argument('--batchSize', type=int, default=1, + help='specify the batch size') + parser.add_argument('--epochs', type=int, default=2000, + help='specify the number of epochs') parser.add_argument('--trainFrameFirst', type=int, default=1, help='specify the number of the first frame used for training') parser.add_argument('--trainFrameLast', type=int, default=140, @@ -719,10 +719,10 @@ def parseCommandLineArgs(): help='specify the number of the first frame used for validation') parser.add_argument('--validationFrameLast', type=int, default=150, help='specify the number of the last frame (exclusive) used for validation') - parser.add_argument('--minTrainingLoss', type=int, default=2, + parser.add_argument('--minTrainingLoss', type=int, default=3, help=''' minimum reduction in training loss in orders of magnitude, - set to 0 to disable the check (default: 2) + set to 0 to disable the check (default: 3) ''') parser.add_argument('--checkPointFrequency', 
type=int, default=10, help='number of epochs between checkpoints')
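Two reference sketches follow, since the AMP-specific pieces are spread across several hunks above. First, the bfloat16 warning printed in main() implies dtype-selection logic roughly like the following; the reassignment of `amp_dtype` after the warning and the conditional GradScaler creation are not visible in the hunks, so both are assumptions.

```python
# Hedged sketch of the AMP setup implied by the --use-amp/--amp-dtype flags and
# the bfloat16 warning in main(); assumes `args` comes from parseCommandLineArgs()
# and that the fallback actually reassigns amp_dtype.
import torch
from torch.amp import GradScaler

use_amp = args.use_amp and torch.cuda.is_available()
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16

scaler = None
if use_amp:
    if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        print("Warning: bfloat16 not supported on this GPU. Falling back to float16.")
        amp_dtype = torch.float16
    if amp_dtype == torch.float16:
        scaler = GradScaler()   # loss scaling is only needed for float16
    print(f"Using Automatic Mixed Precision (AMP) with dtype: {amp_dtype}")
```

bfloat16 keeps float32's exponent range, so gradients rarely underflow and no scaler is required, which matches the help text recommending it; float16 trades range for precision and relies on GradScaler to avoid gradient underflow.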
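Second, a minimal sketch of how autocast and the scaler state handled in ci_tests.py typically combine in a training step. The function name `train_one_epoch`, its argument order, and the batch keys `"all"` and `"mask"` are assumptions modeled on validate_one_epoch and XPointPatchDataset as they appear in the hunks; the actual training loop is not shown in this part of the patch.

```python
# Hedged sketch, not the file's actual training loop: mirrors the
# validate_one_epoch(model, loader, criterion, device, use_amp, amp_dtype)
# signature and adds the optimizer/scaler handling that AMP requires.
import torch
from torch.amp import autocast

def train_one_epoch(model, loader, criterion, optimizer, device,
                    use_amp=False, amp_dtype=torch.bfloat16, scaler=None):
    model.train()
    running_loss = 0.0
    for batch in loader:
        inputs = batch["all"].to(device)    # [B, 4, H, W] normalized field channels
        targets = batch["mask"].to(device)  # [B, 1, H, W] expanded X-point mask
        optimizer.zero_grad(set_to_none=True)

        # Forward pass runs in reduced precision only when AMP is enabled
        with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
            logits = model(inputs)
            loss = criterion(logits, targets)  # sigmoid handling follows the DiceLoss in use

        if scaler is not None:
            # float16 path: scale the loss so small gradients do not underflow
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # bfloat16 (or full-precision) path: no loss scaling needed
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
    return running_loss / max(len(loader), 1)
```

Whether a scaler exists at all is tied to the dtype choice in the previous sketch, which is also why ci_tests.py only restores `scaler_state_dict` when it is present in the checkpoint.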