From 4a831835a3ed0ccdede2c105d90e3cb6897cb782 Mon Sep 17 00:00:00 2001
From: Jim Lu
Date: Fri, 27 Jun 2025 14:28:31 -0400
Subject: [PATCH 1/3] example of cropping 64x64 image

---
 StrinkedXPoint.py | 440 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 440 insertions(+)
 create mode 100644 StrinkedXPoint.py

diff --git a/StrinkedXPoint.py b/StrinkedXPoint.py
new file mode 100644
index 0000000..e34300c
--- /dev/null
+++ b/StrinkedXPoint.py
@@ -0,0 +1,440 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import os, errno, sys, argparse
+from pathlib import Path
+from timeit import default_timer as timer
+import sys
+import argparse
+
+from utils import gkData
+from utils import auxFuncs
+from utils import plotParams
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import v2 # rotate tensor
+
+def expand_xpoints_mask(binary_mask, kernel_size=9):
+    """
+    Expands each X-point in a binary mask to include surrounding cells
+    in a square grid of size kernel_size x kernel_size.
+
+    Parameters:
+    binary_mask : numpy.ndarray
+        2D binary mask with 1s at X-point locations
+    kernel_size : int
+        Size of the square grid (must be an odd number)
+
+    Returns:
+    numpy.ndarray
+        Expanded binary mask with 1s in kernel_size×kernel_size regions around X-points
+    """
+
+    # Get shape of the original mask
+    h, w = binary_mask.shape
+
+    # Create a copy to avoid modifying the original
+    expanded_mask = np.zeros_like(binary_mask)
+
+    # Find coordinates of all X-points
+    x_points = np.argwhere(binary_mask > 0)
+
+    # For each X-point, set a kernel_size×kernel_size area to 1
+    half_size = kernel_size // 2
+    for point in x_points:
+        # Get the corner coordinates for the square centered at the X-point
+        x_min = max(0, point[0] - half_size)
+        x_max = min(h, point[0] + half_size + 1)
+        y_min = max(0, point[1] - half_size)
+        y_max = min(w, point[1] + half_size + 1)
+
+        # Set the square area to 1
+        expanded_mask[x_min:x_max, y_min:y_max] = 1
+
+    return expanded_mask
+
+def rotate(frameData,deg):
+    if deg not in [90, 180, 270]:
+        print(f"invalid rotation specified... exiting")
+        sys.exit()
+    psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR)
+    all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR)
+    mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR)
+    return {
+        "fnum": frameData["fnum"],
+        "rotation": deg,
+        "reflectionAxis": -1, # no reflection
+        "psi": psi,
+        "all": all,
+        "mask": mask,
+        "x": frameData["x"],
+        "y": frameData["y"],
+        "filenameBase": frameData["filenameBase"],
+        "params": frameData["params"]
+    }
+
+def reflect(frameData,axis):
+    if axis not in [0,1]:
+        print(f"invalid reflection axis specified... exiting")
+        sys.exit()
+    psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0)
+    all = torch.flip(frameData["all"], dims=(axis,))
+    mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0)
+    return {
+        "fnum": frameData["fnum"],
+        "rotation": 0,
+        "reflectionAxis": axis,
+        "psi": psi,
+        "all": all,
+        "mask": mask,
+        "x": frameData["x"],
+        "y": frameData["y"],
+        "filenameBase": frameData["filenameBase"],
+        "params": frameData["params"]
+    }
+
+def getPgkylData(paramFile, frameNumber, verbosity):
+    if verbosity > 0:
+        print(f"=== frame {frameNumber} ===")
+    params = {} #Initialize dictionary to store plotting and other parameters
+    params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points
+    constructBandJ = 1
+    #Read vector potential
+    var = gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead()
+    psi = var.data
+    coords = var.coords
+    axesNorm = var.d[ var.speciesFileIndex.index('ion') ]
+    if verbosity > 0:
+        print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}")
+    #Construct B and J (first and second derivatives)
+    [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx)
+    [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx)
+    [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx)
+    bx = df_dy
+    by = -df_dx
+    jz = -(d2f_dxdx + d2f_dydy) / var.mu0
+    del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz
+    #Indices of critical points, X points, and O points (max and min)
+    critPoints = auxFuncs.getCritPoints(psi)
+    [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints)
+    return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz]
+
+def cachedPgkylDataExists(cacheDir, frameNumber, fieldName):
+    if cacheDir is None:
+        return False
+    else:
+        cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy"
+        return cachedFrame.exists()
+
+def loadPgkylDataFromCache(cacheDir, frameNumber, fields):
+    outFields = {}
+    if cacheDir is not None:
+        for name in fields.keys():
+            if name == "fileName":
+                with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file:
+                    outFields[name] = file.read().rstrip()
+            else:
+                outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy")
+        return outFields
+    else:
+        return None
+
+def writePgkylDataToCache(cacheDir, frameNumber, fields):
+    if cacheDir is not None:
+        for name, field in fields.items():
+            if name == "fileName":
+                with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file:
+                    text_file.write(f"{field}")
+            else:
+                np.save(cacheDir / f"{frameNumber}_{name}.npy",field)
+
+
+class XPointDataset(Dataset):
+    """
+    Dataset that processes frames in [fnumList]. For each frame (fnum):
+      - Sets up "params" with the default plotting constants below.
+      - Reads psi from gkData (varid='psi')
+      - Finds X-points -> builds a 2D binary mask.
+      - Returns (psiTensor, maskTensor) as a PyTorch (float) pair.
+    """
+    def __init__(self, paramFile, fnumList, xptCacheDir=None,
+                 rotateAndReflect=False, verbosity=0):
+        """
+        paramFile: Path to parameter file (string).
+        fnumList: List of frames to iterate.
+        """
+        super().__init__()
+        self.paramFile = paramFile
+        self.fnumList = list(fnumList) # ensure indexable
+        self.xptCacheDir = xptCacheDir
+        self.verbosity = verbosity
+
+        # We'll store a base 'params' once here, and then customize per frame in load():
+        self.params = {}
+        # Default snippet-based constants:
+        self.params["lowerLimits"] = [-1e6, -1e6, -0.e6, -1.e6, -1.e6]
+        self.params["upperLimits"] = [1e6, 1e6, 0.e6, 1.e6, 1.e6]
+        self.params["restFrame"] = 1
+        self.params["polyOrderOverride"] = 0
+
+        self.params["plotContours"] = 1
+        self.params["colorContours"] = 'k'
+        self.params["numContours"] = 50
+        self.params["axisEqual"] = 1
+        self.params["symBar"] = 1
+        self.params["colormap"] = 'bwr'
+
+
+        # load all the data
+        self.data = []
+        for fnum in fnumList:
+            frameData = self.load(fnum)
+            self.data.append(frameData)
+            if rotateAndReflect:
+                self.data.append(rotate(frameData,90))
+                self.data.append(rotate(frameData,180))
+                self.data.append(rotate(frameData,270))
+                self.data.append(reflect(frameData,0))
+                self.data.append(reflect(frameData,1))
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+    def load(self, fnum):
+        t0 = timer()
+
+        # check if cache exists
+        if self.xptCacheDir is not None:
+            if not self.xptCacheDir.is_dir():
+                print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting")
+                sys.exit()
+        t2 = timer()
+
+        fields = {"psi":None,
+                  "critPts":None,
+                  "xpts":None,
+                  "optsMax":None,
+                  "optsMin":None,
+                  "axesNorm":None,
+                  "coords":None,
+                  "fileName":None,
+                  "Bx":None, "By":None,
+                  "Jz":None}
+
+        # Indices of critical points, X points, and O points (max and min)
+        if self.xptCacheDir is not None and cachedPgkylDataExists(self.xptCacheDir, fnum, "psi"):
+            fields = loadPgkylDataFromCache(self.xptCacheDir, fnum, fields)
+        else:
+            [fileName, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] = getPgkylData(self.paramFile, fnum, verbosity=self.verbosity)
+            fields = {"psi":psi, "critPts":critPoints, "xpts":xpts,
+                      "optsMax":optsMax, "optsMin":optsMin,
+                      "axesNorm": axesNorm, "coords": coords,
+                      "fileName": fileName,
+                      "Bx":bx, "By":by, "Jz":jz}
+            writePgkylDataToCache(self.xptCacheDir, fnum, fields)
+        self.params["axesNorm"] = fields["axesNorm"]
+
+        if self.verbosity > 0:
+            print("time (s) to find X and O points: " + str(timer()-t2))
+
+        # Create array of 0s with 1s only at X points
+        binaryMap = np.zeros(np.shape(fields["psi"]))
+        binaryMap[fields["xpts"][:, 0], fields["xpts"][:, 1]] = 1
+
+        binaryMap = expand_xpoints_mask(binaryMap, kernel_size=9)
+
+        # -------------- Convert to Torch tensors --------------
+        psi_torch = torch.from_numpy(fields["psi"]).float().unsqueeze(0) # [1, Nx, Ny]
+        bx_torch = torch.from_numpy(fields["Bx"]).float().unsqueeze(0)
+        by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0)
+        jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0)
+        all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny]
+        mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny]
+
+        if self.verbosity > 0:
+            print("time (s) to load and process gkyl frame: " + str(timer()-t0))
+
+        return {
+            "fnum": fnum,
+            "rotation": 0,
+            "reflectionAxis": -1, # no reflection
+            "psi": psi_torch, # shape [1, Nx, Ny]
+            "all": all_torch, # shape [4, Nx, Ny]
+            "mask": mask_torch, # shape [1, Nx, Ny]
+            "x": fields["coords"][0],
+            "y": fields["coords"][1],
+            "filenameBase": fields["fileName"],
+            "params": dict(self.params) # copy of the params for local plotting
+        }
+
+
+class XPointPatchDataset(Dataset):
+    """On-the-fly square crops, balancing positive / background patches."""
+    def __init__(self, base_ds, patch=64, pos_ratio=0.6, retries=20):
+        self.base_ds = base_ds
+        self.patch = patch
+        self.pos_ratio = pos_ratio
+        self.retries = retries
+        self.rng = np.random.default_rng()
+
+    def __len__(self):
+        # give each full frame K random crops per epoch (K=16 by default)
+        return len(self.base_ds) * 16
+
+    def _crop(self, arr, top, left):
+        return arr[..., top:top+self.patch, left:left+self.patch]
+
+    def __getitem__(self, _):
+        frame = self.base_ds[self.rng.integers(len(self.base_ds))]
+        H, W = frame["mask"].shape[-2:]
+
+        for attempt in range(self.retries):
+            y0 = self.rng.integers(0, H - self.patch + 1)
+            x0 = self.rng.integers(0, W - self.patch + 1)
+            crop_mask = self._crop(frame["mask"], y0, x0)
+            has_pos = crop_mask.sum() > 0
+            want_pos = (attempt / self.retries) < self.pos_ratio
+
+            if has_pos == want_pos or attempt == self.retries - 1:
+                return {
+                    "all" : self._crop(frame["all"], y0, x0),
+                    "mask": crop_mask
+                }
+
+
+class UNet(nn.Module):
+    def __init__(self, input_channels=4, base=64):
+        super().__init__()
+        self.enc1 = self._dbl(input_channels, base, dilation=1)
+        self.enc2 = self._dbl(base, base*2, dilation=1)
+        self.enc3 = self._dbl(base*2, base*4, dilation=2) # ← dilated
+        self.pool = nn.MaxPool2d(2, 2)
+        self.bottleneck = self._dbl(base*4, base*8, dilation=4) # ← dilated
+        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, 2)
+        self.dec3 = self._dbl(base*8, base*4, dilation=1)
+        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, 2)
+        self.dec2 = self._dbl(base*4, base*2, dilation=1)
+        self.up1 = nn.ConvTranspose2d(base*2, base, 2, 2)
+        self.dec1 = self._dbl(base*2, base, dilation=1)
+        self.out = nn.Conv2d(base, 1, 1)
+
+    @staticmethod
+    def _dbl(inp, out, dilation=1):
+        pad = dilation
+        return nn.Sequential(
+            nn.Conv2d(inp, out, 3, padding=pad, dilation=dilation),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out, out, 3, padding=pad, dilation=dilation),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        e1 = self.enc1(x)
+        e2 = self.enc2(self.pool(e1))
+        e3 = self.enc3(self.pool(e2))
+        b = self.bottleneck(self.pool(e3))
+        d3 = self.dec3(torch.cat([self.up3(b), e3], 1))
+        d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
+        d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
+        return self.out(d1)
+
+
+
+class DiceLoss(nn.Module):
+    def __init__(self, smooth=1.):
+        super().__init__()
+        self.smooth = smooth
+
+    def forward(self, logits, targets):
+        probs = torch.sigmoid(logits)
+        probs = probs.view(-1); targets = targets.view(-1)
+        inter = (probs * targets).sum()
+        union = probs.sum() + targets.sum()
+        dice = (2*inter + self.smooth) / (union + self.smooth)
+        return 1 - dice
+
+
+def make_criterion(pos_weight):
+    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]).float())
+    dice = DiceLoss()
+    def _loss(logits, target):
+        return 0.5 * bce(logits, target) + 0.5 * dice(logits, target)
+    return _loss
+
+
+
+@torch.no_grad()
+def evaluate(model, loader, criterion, device):
+    model.eval(); loss = 0
+    for batch in loader:
+        x, y = batch["all"].to(device), batch["mask"].to(device)
+        logits = model(x)
+        loss += criterion(logits, y).item()
+    return loss / len(loader)
+
+
+def train_epoch(model, loader, criterion, opt, device):
+    model.train(); loss = 0
+    for batch in loader:
+        x, y = batch["all"].to(device), batch["mask"].to(device)
+        opt.zero_grad()
+        l = criterion(model(x), y)
+        l.backward(); opt.step(); loss += l.item()
+    return loss / len(loader)
+
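+# Illustrative usage sketch (comments only, not executed): the pieces above are
+# meant to compose roughly as follows, assuming `full_ds` is an XPointDataset
+# built from real frames and `model`, `opt`, `device` are set up as in main():
+#   patch_ds = XPointPatchDataset(full_ds, patch=64, pos_ratio=0.8)
+#   loader = DataLoader(patch_ds, batch_size=8, shuffle=True)
+#   criterion = make_criterion(pos_weight=100.0) # pos_weight here is a rough guess
+#   train_epoch(model, loader, criterion, opt, device)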
+
+
+
+def main():
+    p = argparse.ArgumentParser()
+    p.add_argument("--paramFile", type=Path, required=True)
+    p.add_argument("--xptCacheDir", type=Path)
+    p.add_argument("--epochs", type=int, default=60)
+    p.add_argument("--batch", type=int, default=8)
+    p.add_argument("--lr", type=float, default=3e-5)
+    args = p.parse_args()
+
+    # frame splits hard-coded for demo
+    train_fnums = range(1, 141)
+    val_fnums = range(141, 150)
+
+    train_full = XPointDataset(args.paramFile, train_fnums,
+                               xptCacheDir=args.xptCacheDir,
+                               rotateAndReflect=True)
+    val_full = XPointDataset(args.paramFile, val_fnums,
+                             xptCacheDir=args.xptCacheDir)
+
+    train_ds = XPointPatchDataset(train_full, patch=64, pos_ratio=0.8) # 64 x 64 cropping
+    val_ds = XPointPatchDataset(val_full, patch=64, pos_ratio=0.5)
+
+    loader_tr = DataLoader(train_ds, batch_size=args.batch, shuffle=True)
+    loader_va = DataLoader(val_ds, batch_size=args.batch, shuffle=False)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = UNet().to(device)
+
+    # estimate pos/neg for weight (rough):
+    pos_px = 30 * 9 * 9
+    neg_px = 1024*1024 - pos_px
+    criterion = make_criterion(pos_weight=neg_px/pos_px)
+
+    opt = optim.Adam(model.parameters(), lr=args.lr)
+
+    for ep in range(1, args.epochs+1):
+        tr_loss = train_epoch(model, loader_tr, criterion, opt, device)
+        va_loss = evaluate(model, loader_va, criterion, device)
+        print(f"Epoch {ep:03d}: train {tr_loss:.4f} | val {va_loss:.4f}")
+
+        # quick checkpoint
+        if ep % 10 == 0:
+            torch.save(model.state_dict(), f"chkpt_ep{ep}.pt")
+
+if __name__ == "__main__":
+    main()

From aea8b59837064a7fb1906041c5302922e87d8daa Mon Sep 17 00:00:00 2001
From: Swaroop Sridhar
Date: Wed, 16 Jul 2025 04:04:08 -0400
Subject: [PATCH 2/3] feat(ci): Add smoke tests and pytest suite

---
 XPointMLTest.py   | 145 +++++++++++++++++++++++++++++++++++++------
 ci_tests.py       | 153 ++++++++++++++++++++++++++++++++++++++++++++
 test_xpoint_ml.py | 135 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 416 insertions(+), 17 deletions(-)
 create mode 100644 ci_tests.py
 create mode 100644 test_xpoint_ml.py

diff --git a/XPointMLTest.py b/XPointMLTest.py
index 7b45044..b0c2b5d 100644
--- a/XPointMLTest.py
+++ b/XPointMLTest.py
@@ -20,6 +20,8 @@
 from timeit import default_timer as timer
 
+from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality
+
 def expand_xpoints_mask(binary_mask, kernel_size=9):
     """
     Expands each X-point in a binary mask to include surrounding cells
@@ -63,9 +65,12 @@ def rotate(frameData,deg):
     if deg not in [90, 180, 270]:
         print(f"invalid rotation specified... exiting")
         sys.exit()
+
     psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR)
     all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR)
-    mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR)
+    # For mask, use nearest neighbor interpolation to preserve binary values
+    mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.NEAREST)
+
     return {
         "fnum": frameData["fnum"],
         "rotation": deg,
@@ -83,9 +88,11 @@ def reflect(frameData,axis):
     if axis not in [0,1]:
         print(f"invalid reflection axis specified... exiting")
         sys.exit()
-    psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0)
-    all = torch.flip(frameData["all"], dims=(axis,))
-    mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0)
+
+    psi = torch.flip(frameData["psi"], dims=(axis+1,))
+    all = torch.flip(frameData["all"], dims=(axis+1,))
+    mask = torch.flip(frameData["mask"], dims=(axis+1,))
+
     return {
         "fnum": frameData["fnum"],
         "rotation": 0,
@@ -663,10 +670,19 @@ def parseCommandLineArgs():
                         help='create figures of the ground truth X-points and model identified X-points')
     parser.add_argument('--plotDir', type=Path, default="./plots",
                         help='directory where figures are written')
+
+    # CI TEST: Add smoke test flag
+    parser.add_argument('--smoke-test', action='store_true',
+                        help='Run a minimal smoke test for CI (overrides other parameters)')
+
     args = parser.parse_args()
     return args
 
 def checkCommandLineArgs(args):
+    # CI TEST: Skip file checks in smoke test mode
+    if args.smoke_test:
+        return
+
     if args.xptCacheDir != None:
         if not args.xptCacheDir.is_dir():
             print(f"Xpoint cache directory {args.xptCacheDir} does not exist. "
@@ -796,6 +812,32 @@ def load_model_checkpoint(model, optimizer, checkpoint_path):
 
 def main():
     args = parseCommandLineArgs()
+
+    # CI TEST: Override parameters for smoke test
+    if args.smoke_test:
+        print("=" * 60)
+        print("RUNNING IN SMOKE TEST MODE FOR CI")
+        print("=" * 60)
+
+        # Override with minimal parameters
+        args.epochs = 5
+        args.batchSize = 1
+        args.trainFrameFirst = 1
+        args.trainFrameLast = 11 # 10 frames for training
+        args.validationFrameFirst = 11
+        args.validationFrameLast = 12 # 1 frame for validation
+        args.plot = False # Disable plotting for CI
+        args.checkPointFrequency = 2 # Save more frequently
+        args.minTrainingLoss = 0 # Don't fail on convergence in smoke test
+
+        print("Smoke test parameters:")
+        print(f"  - Training frames: {args.trainFrameFirst} to {args.trainFrameLast-1}")
+        print(f"  - Validation frames: {args.validationFrameFirst} to {args.validationFrameLast-1}")
+        print(f"  - Epochs: {args.epochs}")
+        print(f"  - Batch size: {args.batchSize}")
+        print(f"  - Plotting disabled")
+        print("=" * 60)
+
     checkCommandLineArgs(args)
     printCommandLineArgs(args)
 
@@ -804,13 +846,22 @@ def main():
     os.makedirs(outDir, exist_ok=True)
 
     t0 = timer()
-    train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
-    val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
+
+    # CI TEST: Use synthetic data for smoke test
+    if args.smoke_test:
+        print("\nUsing synthetic data for smoke test...")
+        train_dataset = SyntheticXPointDataset(nframes=10, shape=(64, 64), nxpoints=3)
+        val_dataset = SyntheticXPointDataset(nframes=1, shape=(64, 64), nxpoints=3, seed=123)
+        print(f"Created synthetic datasets: {len(train_dataset)} train, {len(val_dataset)} val frames")
+    else:
+        # Original data loading
+        train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
+        val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
 
-    train_dataset = XPointDataset(args.paramFile, train_fnums,
-                                  xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
-    val_dataset = XPointDataset(args.paramFile, val_fnums,
-                                xptCacheDir=args.xptCacheDir)
+        train_dataset = XPointDataset(args.paramFile, train_fnums,
+                                      xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
+        val_dataset = XPointDataset(args.paramFile, val_fnums,
+                                    xptCacheDir=args.xptCacheDir)
 
     t1 = timer()
     print("time (s) to create gkyl data loader: " + str(t1-t0))
@@ -833,7 +884,7 @@ def main():
     train_loss = []
     val_loss = []
 
-    if os.path.exists(latest_checkpoint_path):
+    if os.path.exists(latest_checkpoint_path) and not args.smoke_test:
         model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
             model, optimizer, latest_checkpoint_path
         )
@@ -844,9 +895,6 @@ def main():
     t2 = timer()
     print("time (s) to prepare model: " + str(t2-t1))
 
-    train_loss = []
-    val_loss = []
-
     num_epochs = args.epochs
     for epoch in range(start_epoch, num_epochs):
         train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
@@ -860,6 +908,65 @@ def main():
         plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer()-t2))
 
+    # CI TEST: Run additional tests if in smoke test mode
+    if args.smoke_test:
+        print("\n" + "="*60)
+        print("SMOKE TEST: Running additional CI tests")
+        print("="*60)
+
+        # Test 1: Checkpoint save/load
+        checkpoint_test_passed = test_checkpoint_functionality(
+            model, optimizer, device, val_loader, criterion, None, UNet, optim.Adam
+        )
+
+        # Test 2: Inference test
+        print("Running inference test...")
+        model.eval()
+        with torch.no_grad():
+            # Get one batch
+            test_batch = next(iter(val_loader))
+            test_input = test_batch["all"].to(device)
+            test_output = model(test_input)
+
+            # Apply sigmoid to get probabilities
+            test_probs = torch.sigmoid(test_output)
+
+            print(f"Input shape: {test_input.shape}")
+            print(f"Output shape: {test_output.shape}")
+            print(f"Output range (logits): [{test_output.min():.3f}, {test_output.max():.3f}]")
+            print(f"Output range (probs): [{test_probs.min():.3f}, {test_probs.max():.3f}]")
+            print(f"Predicted X-points: {(test_probs > 0.5).sum().item()} pixels")
+
+        # Test 3: Check if model learned anything
+        initial_train_loss = train_loss[0] if train_loss else float('inf')
+        final_train_loss = train_loss[-1] if train_loss else float('inf')
+
+        print(f"\nTraining progress:")
+        print(f"Initial loss: {initial_train_loss:.6f}")
+        print(f"Final loss: {final_train_loss:.6f}")
+
+        if final_train_loss < initial_train_loss:
+            print("✓ Model showed improvement during training")
+            training_improved = True
+        else:
+            print("✗ Model did not improve during training")
+            training_improved = False
+
+        # Overall smoke test result
+        print("\n" + "="*60)
+        print("SMOKE TEST SUMMARY")
+        print("="*60)
+        print(f"Checkpoint test: {'PASSED' if checkpoint_test_passed else 'FAILED'}")
+        print(f"Training improvement: {'YES' if training_improved else 'NO'}")
+        print(f"Overall result: {'PASSED' if checkpoint_test_passed else 'FAILED'}")
+        print("="*60)
+
+        # Return appropriate exit code for CI
+        if not checkpoint_test_passed:
+            return 1
+        else:
+            return 0
+
     requiredLossDecreaseMagnitude = args.minTrainingLoss
     if np.log10(abs(train_loss[0]/train_loss[-1])) < requiredLossDecreaseMagnitude:
         print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: "
@@ -872,8 +979,12 @@ def main():
     interpFac = 1
     # Evaluate on combined set for demonstration. Examine this part to see whether it is safe to remove
-    full_fnums = list(train_fnums) + list(val_fnums)
-    full_dataset = [train_dataset, val_dataset]
+    if not args.smoke_test:
+        train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
+        val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
+        full_dataset = [train_dataset, val_dataset]
+    else:
+        full_dataset = [val_dataset] # Only use validation data for smoke test
 
     t4 = timer()
 
@@ -942,4 +1053,4 @@ def main():
     print("total time (s): " + str(t5-t0))
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/ci_tests.py b/ci_tests.py
new file mode 100644
index 0000000..eb31555
--- /dev/null
+++ b/ci_tests.py
@@ -0,0 +1,153 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+import torch.optim as optim
+import os
+
+class SyntheticXPointDataset(Dataset):
+    """
+    Synthetic dataset for CI testing that doesn't require actual simulation data.
+    Creates predictable X-point patterns for testing the model training pipeline.
+    """
+    def __init__(self, nframes=2, shape=(64, 64), nxpoints=4, seed=42):
+        """
+        nframes: Number of synthetic frames to generate
+        shape: Spatial dimensions (H, W) of each frame
+        nxpoints: Approximate number of X-points per frame
+        seed: Random seed for reproducibility
+        """
+        super().__init__()
+        self.nframes = nframes
+        self.shape = shape
+        self.nxpoints = nxpoints
+        self.rng = np.random.RandomState(seed)
+
+        # Pre-generate all frames for consistency
+        self.data = []
+        for i in range(nframes):
+            frame_data = self._generate_frame(i)
+            self.data.append(frame_data)
+
+    def _generate_frame(self, idx):
+        """Generate a single synthetic frame with X-points"""
+        H, W = self.shape
+
+        # Create synthetic psi field with some structure
+        x = np.linspace(-np.pi, np.pi, W)
+        y = np.linspace(-np.pi, np.pi, H)
+        X, Y = np.meshgrid(x, y)
+
+        # Create a field with saddle points (X-points)
+        psi = np.sin(X + 0.1*idx) * np.cos(Y + 0.1*idx) + \
+              0.5 * np.sin(2*X) * np.cos(2*Y)
+
+        # Add some noise
+        psi += 0.1 * self.rng.randn(H, W)
+
+        # Create synthetic B fields (derivatives of psi)
+        bx = np.gradient(psi, axis=0)
+        by = -np.gradient(psi, axis=1)
+
+        # Create synthetic current (Laplacian of psi)
+        jz = -(np.gradient(np.gradient(psi, axis=0), axis=0) +
+               np.gradient(np.gradient(psi, axis=1), axis=1))
+
+        # Create X-point mask
+        mask = np.zeros((H, W), dtype=np.float32)
+
+        for _ in range(self.nxpoints):
+            x_loc = self.rng.randint(10, W-10)
+            y_loc = self.rng.randint(10, H-10)
+            # Create 9x9 region around X-point
+            mask[max(0, y_loc-4):min(H, y_loc+5),
+                 max(0, x_loc-4):min(W, x_loc+5)] = 1.0
+
+        # Convert to torch tensors
+        psi_torch = torch.from_numpy(psi.astype(np.float32)).unsqueeze(0)
+        bx_torch = torch.from_numpy(bx.astype(np.float32)).unsqueeze(0)
+        by_torch = torch.from_numpy(by.astype(np.float32)).unsqueeze(0)
+        jz_torch = torch.from_numpy(jz.astype(np.float32)).unsqueeze(0)
+        all_torch = torch.cat((psi_torch, bx_torch, by_torch, jz_torch))
+        mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
+
+        x_coords = np.linspace(0, 1, W)
+        y_coords = np.linspace(0, 1, H)
+
+        params = {
+            "axesNorm": 1.0, "plotContours": 1, "colorContours": 'k',
+            "numContours": 50, "axisEqual": 1, "symBar": 1, "colormap": 'bwr'
+        }
+
+        return {
+            "fnum": idx, "rotation": 0, "reflectionAxis": -1, "psi": psi_torch,
+            "all": all_torch, "mask": mask_torch, "x": x_coords, "y": y_coords,
+            "filenameBase": f"synthetic_frame_{idx}", "params": params
+        }
+
+    def __len__(self):
+        return self.nframes
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+def test_checkpoint_functionality(model, optimizer, device, val_loader, criterion, scaler, UNet, Adam):
+    """
+    Test that the model can be saved and loaded correctly.
+    Returns True if the test passes, False otherwise.
+    """
+    # Import locally to prevent circular dependency
+    from XPointMLTest import validate_one_epoch
+
+    print("\n" + "="*60)
+    print("TESTING CHECKPOINT SAVE/LOAD FUNCTIONALITY")
+    print("="*60)
+
+    # Get initial validation loss
+    model.eval()
+    initial_loss = validate_one_epoch(model, val_loader, criterion, device)
+    print(f"Initial validation loss: {initial_loss:.6f}")
+
+    # Save checkpoint
+    test_checkpoint_path = "smoke_test_checkpoint.pt"
+    checkpoint = {
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'val_loss': initial_loss,
+        'test_value': 42
+    }
+
+    torch.save(checkpoint, test_checkpoint_path)
+    print(f"Saved checkpoint to {test_checkpoint_path}")
+
+    # Create new model and optimizer
+    model2 = UNet(input_channels=4, base_channels=64).to(device)
+    optimizer2 = Adam(model2.parameters(), lr=1e-5)
+
+    # Load checkpoint
+    loaded_checkpoint = torch.load(test_checkpoint_path, map_location=device, weights_only=False)
+    model2.load_state_dict(loaded_checkpoint['model_state_dict'])
+    optimizer2.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
+
+    assert loaded_checkpoint['test_value'] == 42, "Test value mismatch!"
+    print("Checkpoint test value verified")
+
+    # Get loaded model validation loss
+    model2.eval()
+    loaded_loss = validate_one_epoch(model2, val_loader, criterion, device)
+    print(f"Loaded model validation loss: {loaded_loss:.6f}")
+
+    # Check if losses match
+    loss_diff = abs(initial_loss - loaded_loss)
+    success = loss_diff < 1e-6
+    if success:
+        print(f"✓ Checkpoint test PASSED (loss difference: {loss_diff:.2e})")
+    else:
+        print(f"✗ Checkpoint test FAILED (loss difference: {loss_diff:.2e})")
+
+    if os.path.exists(test_checkpoint_path):
+        os.remove(test_checkpoint_path)
+        print(f"Cleaned up {test_checkpoint_path}")
+
+    print("="*60 + "\n")
+    return success
\ No newline at end of file
diff --git a/test_xpoint_ml.py b/test_xpoint_ml.py
new file mode 100644
index 0000000..842a48d
--- /dev/null
+++ b/test_xpoint_ml.py
@@ -0,0 +1,135 @@
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+import torch.optim as optim
+import os
+import pytest
+
+from XPointMLTest import UNet, DiceLoss, expand_xpoints_mask, validate_one_epoch
+from ci_tests import SyntheticXPointDataset
+
+# --- Pytest Fixtures ---
+@pytest.fixture
+def unet_model():
+    return UNet(input_channels=4, base_channels=16)
+
+@pytest.fixture
+def dice_loss():
+    return DiceLoss()
+
+@pytest.fixture
+def synthetic_dataset():
+    return SyntheticXPointDataset(nframes=2, shape=(32, 32))
+
+@pytest.fixture
+def synthetic_batch(synthetic_dataset):
+    return synthetic_dataset[0]
+
+# --- 1. Unit Tests (Utils & Loss Functions) ---
+def test_expand_xpoints_mask():
+    mask = np.zeros((20, 20))
+    mask[10, 10] = 1
+    expanded = expand_xpoints_mask(mask, kernel_size=5)
+    assert expanded.shape == (20, 20)
+    assert np.sum(expanded) == 25
+    assert expanded[10, 10] == 1
+    assert expanded[8, 8] == 1
+    assert expanded[7, 7] == 0
+
+def test_dice_loss_perfect_match(dice_loss):
+    target = torch.ones(1, 1, 10, 10)
+    logits = torch.full((1, 1, 10, 10), 10.0) # large positive logits
+    loss = dice_loss(logits, target)
+    # due to the smoothing factor, a perfect match doesn't give exactly 0
+    assert loss < 1e-4, f"Loss should be near 0, got {loss.item()}"
+
+def test_dice_loss_no_match(dice_loss):
+    target = torch.zeros(1, 1, 10, 10)
+    logits = torch.full((1, 1, 10, 10), 10.0)
+    loss = dice_loss(logits, target)
+    expected_loss = 1.0 - (1.0 / (100 + 1.0))
+    assert torch.isclose(loss, torch.tensor(expected_loss), atol=1e-3)
+
+# --- 2. Dataset Integrity Test ---
+def test_synthetic_dataset_integrity(synthetic_dataset):
+    assert len(synthetic_dataset) == 2
+    item = synthetic_dataset[0]
+    expected_keys = ["fnum", "all", "mask", "psi", "x", "y"]
+    assert all(key in item for key in expected_keys)
+    assert item['all'].shape == (4, 32, 32)
+    assert item['mask'].shape == (1, 32, 32)
+    assert item['psi'].shape == (1, 32, 32)
+    assert item['all'].dtype == torch.float32
+    assert item['mask'].dtype == torch.float32
+
+# --- 3. Model Forward/Backward Pass Test ---
+def test_model_forward_backward_pass(unet_model, synthetic_batch, dice_loss):
+    model = unet_model
+    loss_fn = dice_loss
+    input_tensor = synthetic_batch['all'].unsqueeze(0)
+    target_tensor = synthetic_batch['mask'].unsqueeze(0)
+    prediction = model(input_tensor)
+    assert prediction.shape == target_tensor.shape
+    loss = loss_fn(prediction, target_tensor)
+    assert loss.item() > 0
+    loss.backward()
+    has_grads = any(p.grad is not None for p in model.parameters())
+    assert has_grads, "No gradients were computed during the backward pass."
+    grad_sum = sum(p.grad.sum() for p in model.parameters() if p.grad is not None)
+    assert grad_sum != 0, "Gradients are all zero."
+
+# --- 4. Standalone checkpoint test for pytest ---
+def test_checkpoint_save_load(unet_model, synthetic_dataset):
+    """
+    Standalone pytest version of the checkpoint functionality test
+    """
+    device = torch.device("cpu")
+    model = unet_model.to(device)
+    optimizer = optim.Adam(model.parameters(), lr=1e-5)
+    criterion = DiceLoss()
+
+    # Create a simple dataloader
+    val_loader = DataLoader(synthetic_dataset, batch_size=1, shuffle=False)
+
+    # Get initial loss
+    initial_loss = validate_one_epoch(model, val_loader, criterion, device)
+
+    # Save checkpoint
+    test_checkpoint_path = "test_checkpoint_pytest.pt"
+    checkpoint = {
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'val_loss': initial_loss,
+        'test_value': 42
+    }
+    torch.save(checkpoint, test_checkpoint_path)
+
+    # Create new model and load
+    model2 = UNet(input_channels=4, base_channels=16).to(device)
+    optimizer2 = optim.Adam(model2.parameters(), lr=1e-5)
+
+    loaded_checkpoint = torch.load(test_checkpoint_path, map_location=device, weights_only=False)
+    model2.load_state_dict(loaded_checkpoint['model_state_dict'])
+    optimizer2.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
+
+    assert loaded_checkpoint['test_value'] == 42
+
+    # Get loaded model loss
+    loaded_loss = validate_one_epoch(model2, val_loader, criterion, device)
+
+    # Check if losses match
+    loss_diff = abs(initial_loss - loaded_loss)
+    assert loss_diff < 1e-6, f"Loss difference too large: {loss_diff}"
+
+    # Cleanup
+    if os.path.exists(test_checkpoint_path):
+        os.remove(test_checkpoint_path)
+
+def test_model_inference(unet_model, synthetic_batch):
+    model = unet_model
+    input_tensor = synthetic_batch['all'].unsqueeze(0)
+    with torch.no_grad():
+        output = model(input_tensor)
+    assert output.shape == (1, 1, 32, 32)
+    assert output.dtype == torch.float32
+    assert torch.isfinite(output).all()
\ No newline at end of file

From 2fa5452d7de2f0f1040158bc5421f4f87397c233 Mon Sep 17 00:00:00 2001
From: Swaroop Sridhar
Date: Wed, 16 Jul 2025 04:42:05 -0400
Subject: [PATCH 3/3] remove StrinkedXPoint.py

---
 StrinkedXPoint.py | 440 ----------------------------------------------
 1 file changed, 440 deletions(-)
 delete mode 100644 StrinkedXPoint.py

diff --git a/StrinkedXPoint.py b/StrinkedXPoint.py
deleted file mode 100644
index e34300c..0000000
--- a/StrinkedXPoint.py
+++ /dev/null
@@ -1,440 +0,0 @@
-import numpy as np
-import matplotlib.pyplot as plt
-import os, errno, sys, argparse
-from pathlib import Path
-from timeit import default_timer as timer
-import sys
-import argparse
-
-from utils import gkData
-from utils import auxFuncs
-from utils import plotParams
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset
-from torchvision.transforms import v2 # rotate tensor
-
-def expand_xpoints_mask(binary_mask, kernel_size=9):
-    """
-    Expands each X-point in a binary mask to include surrounding cells
-    in a square grid of size kernel_size x kernel_size.
-
-    Parameters:
-    binary_mask : numpy.ndarray
-        2D binary mask with 1s at X-point locations
-    kernel_size : int
-        Size of the square grid (must be an odd number)
-
-    Returns:
-    numpy.ndarray
-        Expanded binary mask with 1s in kernel_size×kernel_size regions around X-points
-    """
-
-    # Get shape of the original mask
-    h, w = binary_mask.shape
-
-    # Create a copy to avoid modifying the original
-    expanded_mask = np.zeros_like(binary_mask)
-
-    # Find coordinates of all X-points
-    x_points = np.argwhere(binary_mask > 0)
-
-    # For each X-point, set a kernel_size×kernel_size area to 1
-    half_size = kernel_size // 2
-    for point in x_points:
-        # Get the corner coordinates for the square centered at the X-point
-        x_min = max(0, point[0] - half_size)
-        x_max = min(h, point[0] + half_size + 1)
-        y_min = max(0, point[1] - half_size)
-        y_max = min(w, point[1] + half_size + 1)
-
-        # Set the square area to 1
-        expanded_mask[x_min:x_max, y_min:y_max] = 1
-
-    return expanded_mask
-
-def rotate(frameData,deg):
-    if deg not in [90, 180, 270]:
-        print(f"invalid rotation specified... exiting")
-        sys.exit()
-    psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR)
-    all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR)
-    mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR)
-    return {
-        "fnum": frameData["fnum"],
-        "rotation": deg,
-        "reflectionAxis": -1, # no reflection
-        "psi": psi,
-        "all": all,
-        "mask": mask,
-        "x": frameData["x"],
-        "y": frameData["y"],
-        "filenameBase": frameData["filenameBase"],
-        "params": frameData["params"]
-    }
-
-def reflect(frameData,axis):
-    if axis not in [0,1]:
-        print(f"invalid reflection axis specified... exiting")
-        sys.exit()
-    psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0)
-    all = torch.flip(frameData["all"], dims=(axis,))
-    mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0)
-    return {
-        "fnum": frameData["fnum"],
-        "rotation": 0,
-        "reflectionAxis": axis,
-        "psi": psi,
-        "all": all,
-        "mask": mask,
-        "x": frameData["x"],
-        "y": frameData["y"],
-        "filenameBase": frameData["filenameBase"],
-        "params": frameData["params"]
-    }
-
-def getPgkylData(paramFile, frameNumber, verbosity):
-    if verbosity > 0:
-        print(f"=== frame {frameNumber} ===")
-    params = {} #Initialize dictionary to store plotting and other parameters
-    params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points
-    constructBandJ = 1
-    #Read vector potential
-    var = gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead()
-    psi = var.data
-    coords = var.coords
-    axesNorm = var.d[ var.speciesFileIndex.index('ion') ]
-    if verbosity > 0:
-        print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}")
-    #Construct B and J (first and second derivatives)
-    [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx)
-    [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx)
-    [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx)
-    bx = df_dy
-    by = -df_dx
-    jz = -(d2f_dxdx + d2f_dydy) / var.mu0
-    del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz
-    #Indices of critical points, X points, and O points (max and min)
-    critPoints = auxFuncs.getCritPoints(psi)
-    [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints)
-    return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz]
-
-def cachedPgkylDataExists(cacheDir, frameNumber, fieldName):
-    if cacheDir is None:
-        return False
-    else:
-        cachedFrame = cacheDir / f"{frameNumber}_{fieldName}.npy"
-        return cachedFrame.exists()
-
-def loadPgkylDataFromCache(cacheDir, frameNumber, fields):
-    outFields = {}
-    if cacheDir is not None:
-        for name in fields.keys():
-            if name == "fileName":
-                with open(cacheDir / f"{frameNumber}_{name}.txt", "r") as file:
-                    outFields[name] = file.read().rstrip()
-            else:
-                outFields[name] = np.load(cacheDir / f"{frameNumber}_{name}.npy")
-        return outFields
-    else:
-        return None
-
-def writePgkylDataToCache(cacheDir, frameNumber, fields):
-    if cacheDir is not None:
-        for name, field in fields.items():
-            if name == "fileName":
-                with open(cacheDir / f"{frameNumber}_{name}.txt", "w") as text_file:
-                    text_file.write(f"{field}")
-            else:
-                np.save(cacheDir / f"{frameNumber}_{name}.npy",field)
-
-
-class XPointDataset(Dataset):
-    """
-    Dataset that processes frames in [fnumList]. For each frame (fnum):
-      - Sets up "params" with the default plotting constants below.
-      - Reads psi from gkData (varid='psi')
-      - Finds X-points -> builds a 2D binary mask.
-      - Returns (psiTensor, maskTensor) as a PyTorch (float) pair.
-    """
-    def __init__(self, paramFile, fnumList, xptCacheDir=None,
-                 rotateAndReflect=False, verbosity=0):
-        """
-        paramFile: Path to parameter file (string).
-        fnumList: List of frames to iterate.
-        """
-        super().__init__()
-        self.paramFile = paramFile
-        self.fnumList = list(fnumList) # ensure indexable
-        self.xptCacheDir = xptCacheDir
-        self.verbosity = verbosity
-
-        # We'll store a base 'params' once here, and then customize per frame in load():
-        self.params = {}
-        # Default snippet-based constants:
-        self.params["lowerLimits"] = [-1e6, -1e6, -0.e6, -1.e6, -1.e6]
-        self.params["upperLimits"] = [1e6, 1e6, 0.e6, 1.e6, 1.e6]
-        self.params["restFrame"] = 1
-        self.params["polyOrderOverride"] = 0
-
-        self.params["plotContours"] = 1
-        self.params["colorContours"] = 'k'
-        self.params["numContours"] = 50
-        self.params["axisEqual"] = 1
-        self.params["symBar"] = 1
-        self.params["colormap"] = 'bwr'
-
-
-        # load all the data
-        self.data = []
-        for fnum in fnumList:
-            frameData = self.load(fnum)
-            self.data.append(frameData)
-            if rotateAndReflect:
-                self.data.append(rotate(frameData,90))
-                self.data.append(rotate(frameData,180))
-                self.data.append(rotate(frameData,270))
-                self.data.append(reflect(frameData,0))
-                self.data.append(reflect(frameData,1))
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, idx):
-        return self.data[idx]
-
-    def load(self, fnum):
-        t0 = timer()
-
-        # check if cache exists
-        if self.xptCacheDir is not None:
-            if not self.xptCacheDir.is_dir():
-                print(f"Xpoint cache directory {self.xptCacheDir} does not exist... exiting")
-                sys.exit()
-        t2 = timer()
-
-        fields = {"psi":None,
-                  "critPts":None,
-                  "xpts":None,
-                  "optsMax":None,
-                  "optsMin":None,
-                  "axesNorm":None,
-                  "coords":None,
-                  "fileName":None,
-                  "Bx":None, "By":None,
-                  "Jz":None}
-
-        # Indices of critical points, X points, and O points (max and min)
-        if self.xptCacheDir is not None and cachedPgkylDataExists(self.xptCacheDir, fnum, "psi"):
-            fields = loadPgkylDataFromCache(self.xptCacheDir, fnum, fields)
-        else:
-            [fileName, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] = getPgkylData(self.paramFile, fnum, verbosity=self.verbosity)
-            fields = {"psi":psi, "critPts":critPoints, "xpts":xpts,
-                      "optsMax":optsMax, "optsMin":optsMin,
-                      "axesNorm": axesNorm, "coords": coords,
-                      "fileName": fileName,
-                      "Bx":bx, "By":by, "Jz":jz}
-            writePgkylDataToCache(self.xptCacheDir, fnum, fields)
-        self.params["axesNorm"] = fields["axesNorm"]
-
-        if self.verbosity > 0:
-            print("time (s) to find X and O points: " + str(timer()-t2))
-
-        # Create array of 0s with 1s only at X points
-        binaryMap = np.zeros(np.shape(fields["psi"]))
-        binaryMap[fields["xpts"][:, 0], fields["xpts"][:, 1]] = 1
-
-        binaryMap = expand_xpoints_mask(binaryMap, kernel_size=9)
-
-        # -------------- Convert to Torch tensors --------------
-        psi_torch = torch.from_numpy(fields["psi"]).float().unsqueeze(0) # [1, Nx, Ny]
-        bx_torch = torch.from_numpy(fields["Bx"]).float().unsqueeze(0)
-        by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0)
-        jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0)
-        all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny]
-        mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny]
-
-        if self.verbosity > 0:
-            print("time (s) to load and process gkyl frame: " + str(timer()-t0))
-
-        return {
-            "fnum": fnum,
-            "rotation": 0,
-            "reflectionAxis": -1, # no reflection
-            "psi": psi_torch, # shape [1, Nx, Ny]
-            "all": all_torch, # shape [4, Nx, Ny]
-            "mask": mask_torch, # shape [1, Nx, Ny]
-            "x": fields["coords"][0],
-            "y": fields["coords"][1],
-            "filenameBase": fields["fileName"],
-            "params": dict(self.params) # copy of the params for local plotting
-        }
-
-
-class XPointPatchDataset(Dataset):
-    """On-the-fly square crops, balancing positive / background patches."""
-    def __init__(self, base_ds, patch=64, pos_ratio=0.6, retries=20):
-        self.base_ds = base_ds
-        self.patch = patch
-        self.pos_ratio = pos_ratio
-        self.retries = retries
-        self.rng = np.random.default_rng()
-
-    def __len__(self):
-        # give each full frame K random crops per epoch (K=16 by default)
-        return len(self.base_ds) * 16
-
-    def _crop(self, arr, top, left):
-        return arr[..., top:top+self.patch, left:left+self.patch]
-
-    def __getitem__(self, _):
-        frame = self.base_ds[self.rng.integers(len(self.base_ds))]
-        H, W = frame["mask"].shape[-2:]
-
-        for attempt in range(self.retries):
-            y0 = self.rng.integers(0, H - self.patch + 1)
-            x0 = self.rng.integers(0, W - self.patch + 1)
-            crop_mask = self._crop(frame["mask"], y0, x0)
-            has_pos = crop_mask.sum() > 0
-            want_pos = (attempt / self.retries) < self.pos_ratio
-
-            if has_pos == want_pos or attempt == self.retries - 1:
-                return {
-                    "all" : self._crop(frame["all"], y0, x0),
-                    "mask": crop_mask
-                }
-
-
-class UNet(nn.Module):
-    def __init__(self, input_channels=4, base=64):
-        super().__init__()
-        self.enc1 = self._dbl(input_channels, base, dilation=1)
-        self.enc2 = self._dbl(base, base*2, dilation=1)
-        self.enc3 = self._dbl(base*2, base*4, dilation=2) # ← dilated
-        self.pool = nn.MaxPool2d(2, 2)
-        self.bottleneck = self._dbl(base*4, base*8, dilation=4) # ← dilated
-        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, 2)
-        self.dec3 = self._dbl(base*8, base*4, dilation=1)
-        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, 2)
-        self.dec2 = self._dbl(base*4, base*2, dilation=1)
-        self.up1 = nn.ConvTranspose2d(base*2, base, 2, 2)
-        self.dec1 = self._dbl(base*2, base, dilation=1)
-        self.out = nn.Conv2d(base, 1, 1)
-
-    @staticmethod
-    def _dbl(inp, out, dilation=1):
-        pad = dilation
-        return nn.Sequential(
-            nn.Conv2d(inp, out, 3, padding=pad, dilation=dilation),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(out, out, 3, padding=pad, dilation=dilation),
-            nn.ReLU(inplace=True)
-        )
-
-    def forward(self, x):
-        e1 = self.enc1(x)
-        e2 = self.enc2(self.pool(e1))
-        e3 = self.enc3(self.pool(e2))
-        b = self.bottleneck(self.pool(e3))
-        d3 = self.dec3(torch.cat([self.up3(b), e3], 1))
-        d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
-        d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
-        return self.out(d1)
-
-
-
-class DiceLoss(nn.Module):
-    def __init__(self, smooth=1.):
-        super().__init__()
-        self.smooth = smooth
-
-    def forward(self, logits, targets):
-        probs = torch.sigmoid(logits)
-        probs = probs.view(-1); targets = targets.view(-1)
-        inter = (probs * targets).sum()
-        union = probs.sum() + targets.sum()
-        dice = (2*inter + self.smooth) / (union + self.smooth)
-        return 1 - dice
-
-
-def make_criterion(pos_weight):
-    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]).float())
-    dice = DiceLoss()
-    def _loss(logits, target):
-        return 0.5 * bce(logits, target) + 0.5 * dice(logits, target)
-    return _loss
-
-
-
-@torch.no_grad()
-def evaluate(model, loader, criterion, device):
-    model.eval(); loss = 0
-    for batch in loader:
-        x, y = batch["all"].to(device), batch["mask"].to(device)
-        logits = model(x)
-        loss += criterion(logits, y).item()
-    return loss / len(loader)
-
-
-def train_epoch(model, loader, criterion, opt, device):
-    model.train(); loss = 0
-    for batch in loader:
-        x, y = batch["all"].to(device), batch["mask"].to(device)
-        opt.zero_grad()
-        l = criterion(model(x), y)
-        l.backward(); opt.step(); loss += l.item()
-    return loss / len(loader)
-
-
-
-
-def main():
-    p = argparse.ArgumentParser()
-    p.add_argument("--paramFile", type=Path, required=True)
-    p.add_argument("--xptCacheDir", type=Path)
-    p.add_argument("--epochs", type=int, default=60)
-    p.add_argument("--batch", type=int, default=8)
-    p.add_argument("--lr", type=float, default=3e-5)
-    args = p.parse_args()
-
-    # frame splits hard-coded for demo
-    train_fnums = range(1, 141)
-    val_fnums = range(141, 150)
-
-    train_full = XPointDataset(args.paramFile, train_fnums,
-                               xptCacheDir=args.xptCacheDir,
-                               rotateAndReflect=True)
-    val_full = XPointDataset(args.paramFile, val_fnums,
-                             xptCacheDir=args.xptCacheDir)
-
-    train_ds = XPointPatchDataset(train_full, patch=64, pos_ratio=0.8) # 64 x 64 cropping
-    val_ds = XPointPatchDataset(val_full, patch=64, pos_ratio=0.5)
-
-    loader_tr = DataLoader(train_ds, batch_size=args.batch, shuffle=True)
-    loader_va = DataLoader(val_ds, batch_size=args.batch, shuffle=False)
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = UNet().to(device)
-
-    # estimate pos/neg for weight (rough):
-    pos_px = 30 * 9 * 9
-    neg_px = 1024*1024 - pos_px
-    criterion = make_criterion(pos_weight=neg_px/pos_px)
-
-    opt = optim.Adam(model.parameters(), lr=args.lr)
-
-    for ep in range(1, args.epochs+1):
-        tr_loss = train_epoch(model, loader_tr, criterion, opt, device)
-        va_loss = evaluate(model, loader_va, criterion, device)
-        print(f"Epoch {ep:03d}: train {tr_loss:.4f} | val {va_loss:.4f}")
-
-        # quick checkpoint
-        if ep % 10 == 0:
-            torch.save(model.state_dict(), f"chkpt_ep{ep}.pt")
-
-if __name__ == "__main__":
-    main()
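
Notes on exercising this series (illustrative commands; they assume the repo
root is the working directory and the dependencies imported above are
installed):

    # CI smoke test on synthetic data; no simulation output is needed
    python XPointMLTest.py --smoke-test

    # unit tests for mask expansion, Dice loss, the UNet, and checkpointing
    pytest test_xpoint_ml.py

    # quick sanity check of the synthetic dataset
    python -c "from ci_tests import SyntheticXPointDataset; ds = SyntheticXPointDataset(nframes=2); print(ds[0]['all'].shape)"  # expect torch.Size([4, 64, 64])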