From e2fd0f110778a9b3ed292f729e4d624127da3daf Mon Sep 17 00:00:00 2001
From: Dev Kumar Acharya
Date: Fri, 14 Mar 2025 19:57:59 +0530
Subject: [PATCH 1/2] Add precursor_cond support in p_sample, sample, and
 train_step (update docstrings)

---
 dquartic/model/model.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/dquartic/model/model.py b/dquartic/model/model.py
index 64d223e..140ce2c 100644
--- a/dquartic/model/model.py
+++ b/dquartic/model/model.py
@@ -241,7 +241,7 @@ def q_sample(self, x_0, t, noise=None):
 
         return sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise
 
-    def p_sample(self, x_t, t, init_cond=None, attn_cond=None):
+    def p_sample(self, x_t, t, init_cond=None, attn_cond=None, precursor_cond=None):
         """
         Performs a reverse sampling step, estimating the original input x_0 or the noise epsilon.
 
@@ -253,6 +253,7 @@ def p_sample(self, x_t, t, init_cond=None, attn_cond=None):
             t (torch.Tensor): The current timestep index.
             init_cond (torch.Tensor, optional): Initial conditions to use for prediction.
             attn_cond (torch.Tensor, optional): Attention conditions to use for semantic guidance.
+            precursor_cond (torch.Tensor, optional): Precursor feature masks used as additional conditioning signals.
 
         Returns:
             torch.Tensor: The estimated previous state x_{t-1}.
@@ -268,12 +269,12 @@ def p_sample(self, x_t, t, init_cond=None, attn_cond=None):
 
         if self.pred_type == "eps":
             # Predict noise
-            eps_pred = self.model(x_t, t_tensor, init_cond, attn_cond)
+            eps_pred = self.model(x_t, t_tensor, init_cond, attn_cond, precursor_cond)
             # Compute x_0 prediction
             x0_pred = (x_t - sqrt_one_minus_alpha_bar_t * eps_pred) / sqrt_alpha_bar_t
         elif self.pred_type == "x0":
             # Predict x_0 directly
-            x0_pred = self.model(x_t, t_tensor, init_cond, attn_cond)
+            x0_pred = self.model(x_t, t_tensor, init_cond, attn_cond, precursor_cond)
             # Compute eps_pred from x0_pred
             eps_pred = (x_t - sqrt_alpha_bar_t * x0_pred) / sqrt_one_minus_alpha_bar_t
         else:
@@ -290,9 +291,10 @@ def p_sample(self, x_t, t, init_cond=None, attn_cond=None):
 
         return x_t_prev, eps_pred
 
-    def sample(self, x_t, ms2_cond=None, ms1_cond=None, num_steps=1000):
+    def sample(self, x_t, ms2_cond=None, ms1_cond=None, precursor_cond=None, num_steps=1000):
         """
-        Generates samples from the model starting from a noisy input x_t, optionally conditioned on MS2 and MS1 data.
+        Generates samples from the model starting from a noisy input x_t, optionally conditioned on MS2, MS1,
+        and precursor feature masks.
 
         This function iteratively applies the reverse diffusion process (`p_sample`) to generate samples
         moving from noisy data back towards the original data distribution.
@@ -301,6 +303,7 @@ def sample(self, x_t, ms2_cond=None, ms1_cond=None, num_steps=1000):
             x_t (torch.Tensor): The initial noisy input tensor.
             ms2_cond (torch.Tensor, optional): MS2 mixture data maps for conditioning.
             ms1_cond (torch.Tensor, optional): Clean MS1 data maps for conditioning.
+            precursor_cond (torch.Tensor, optional): Precursor feature masks used as additional conditioning signals.
             num_steps (int): The number of reverse sampling steps to perform.
 
         Returns:
@@ -309,6 +312,7 @@ def sample(self, x_t, ms2_cond=None, ms1_cond=None, num_steps=1000):
         """
         ms2_cond = self.normalize(ms2_cond) if ms2_cond is not None else None
         ms1_cond = self.normalize(ms1_cond) if ms1_cond is not None else None
+        precursor_cond = self.normalize(precursor_cond) if precursor_cond is not None else None
 
         pred_noise = None
         time_steps = torch.linspace(self.num_timesteps - 1, 0, num_steps, dtype=torch.long)
@@ -323,7 +327,7 @@ def sample(self, x_t, ms2_cond=None, ms1_cond=None, num_steps=1000):
 
         return x_t, pred_noise
 
-    def train_step(self, x_0, ms2_cond=None, ms1_cond=None, noise=None, ms1_loss_weight=0.0):
+    def train_step(self, x_0, ms2_cond=None, ms1_cond=None, precursor_cond=None, noise=None, ms1_loss_weight=0.0):
         """
         Performs a single training step using the specified input data, optionally with additional MS1 loss weighting.
 
@@ -334,6 +338,7 @@ def train_step(self, x_0, ms2_cond=None, ms1_cond=None, noise=None, ms1_loss_wei
             x_0 (torch.Tensor): The clean MS2 data maps (original data).
             ms2_cond (torch.Tensor, optional): MS2 mixture data maps for additional conditioning.
             ms1_cond (torch.Tensor, optional): Clean MS1 data maps for additional conditioning.
+            precursor_cond (torch.Tensor, optional): Precursor feature masks used as additional conditioning signals.
             noise (torch.Tensor, optional): Noise tensor to use during forward diffusion. If None, noise is sampled randomly.
             ms1_loss_weight (float): Weighting factor for an additional MS1-specific loss component.
 
@@ -356,7 +361,7 @@ def train_step(self, x_0, ms2_cond=None, ms1_cond=None, noise=None, ms1_loss_wei
 
         if self.pred_type == "eps":
             # Predict noise
-            eps_pred = self.model(x_t, t, ms2_cond, ms1_cond)
+            eps_pred = self.model(x_t, t, ms2_cond, ms1_cond, precursor_cond)
             # Compute primary loss between predicted noise and true noise
             primary_loss = F.mse_loss(eps_pred, noise)
 
@@ -371,7 +376,7 @@ def train_step(self, x_0, ms2_cond=None, ms1_cond=None, noise=None, ms1_loss_wei
             )
         elif self.pred_type == "x0":
             # Predict x0
-            x0_pred = self.model(x_t, t, ms2_cond, ms1_cond)
+            x0_pred = self.model(x_t, t, ms2_cond, ms1_cond, precursor_cond)
             # Compute primary loss between predicted x0 and true x0
             primary_loss = F.mse_loss(x0_pred, x_0)
 

From 7bd4889d9bae750edaeeb15cfd3968a2e2adf971 Mon Sep 17 00:00:00 2001
From: Dev Kumar Acharya
Date: Tue, 1 Apr 2025 22:28:09 +0530
Subject: [PATCH 2/2] Add feature mask support for conditioning (#20)

---
 dquartic/cli.py                 |  3 ++-
 dquartic/model/model.py         | 21 ++++++++-------------
 dquartic/utils/config_loader.py |  8 ++++++++
 dquartic/utils/data_loader.py   | 16 ++++++++++++----
 4 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/dquartic/cli.py b/dquartic/cli.py
index 1118b47..ea2b2c9 100644
--- a/dquartic/cli.py
+++ b/dquartic/cli.py
@@ -77,12 +77,13 @@ def train(
     parquet_directory = config['data']['parquet_directory']
     ms2_data_path = config['data']['ms2_data_path']
     ms1_data_path = config['data']['ms1_data_path']
+    feature_mask_file = config['data']['feature_mask_path']
 
     batch_size = config['model']['batch_size']
     checkpoint_path = config['model']['checkpoint_path']
     use_wandb = config['wandb']['use_wandb']
     threads = config['threads']
 
-    dataset = DIAMSDataset(parquet_directory, ms2_data_path, ms1_data_path, normalize=config['data']['normalize'])
+    dataset = DIAMSDataset(parquet_directory, ms2_data_path, ms1_data_path, feature_mask_file, normalize=config['data']['normalize'])
     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=threads)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/dquartic/model/model.py b/dquartic/model/model.py index 140ce2c..64d223e 100644 --- a/dquartic/model/model.py +++ b/dquartic/model/model.py @@ -241,7 +241,7 @@ def q_sample(self, x_0, t, noise=None): return sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise - def p_sample(self, x_t, t, init_cond=None, attn_cond=None, precursor_cond=None): + def p_sample(self, x_t, t, init_cond=None, attn_cond=None): """ Performs a reverse sampling step, estimating the original input x_0 or the noise epsilon. @@ -253,7 +253,6 @@ def p_sample(self, x_t, t, init_cond=None, attn_cond=None, precursor_cond=None): t (torch.Tensor): The current timestep index. init_cond (torch.Tensor, optional): Initial conditions to use for prediction. attn_cond (torch.Tensor, optional): Attention conditions to use for semantic guidance. - precursor_cond (torch.Tensor, optional): Precursor feature masks used as additional conditioning signals. Returns: torch.Tensor: The estimated previous state x_{t-1}. @@ -269,12 +268,12 @@ def p_sample(self, x_t, t, init_cond=None, attn_cond=None, precursor_cond=None): if self.pred_type == "eps": # Predict noise - eps_pred = self.model(x_t, t_tensor, init_cond, attn_cond, precursor_cond) + eps_pred = self.model(x_t, t_tensor, init_cond, attn_cond) # Compute x_0 prediction x0_pred = (x_t - sqrt_one_minus_alpha_bar_t * eps_pred) / sqrt_alpha_bar_t elif self.pred_type == "x0": # Predict x_0 directly - x0_pred = self.model(x_t, t_tensor, init_cond, attn_cond, precursor_cond) + x0_pred = self.model(x_t, t_tensor, init_cond, attn_cond) # Compute eps_pred from x0_pred eps_pred = (x_t - sqrt_alpha_bar_t * x0_pred) / sqrt_one_minus_alpha_bar_t else: @@ -291,10 +290,9 @@ def p_sample(self, x_t, t, init_cond=None, attn_cond=None, precursor_cond=None): return x_t_prev, eps_pred - def sample(self, x_t, ms2_cond=None, ms1_cond=None, precursor_cond=None, num_steps=1000): + def sample(self, x_t, ms2_cond=None, ms1_cond=None, num_steps=1000): """ - Generates samples from the model starting from a noisy input x_t, optionally conditioned on MS2, MS1, - and precursor feature masks. + Generates samples from the model starting from a noisy input x_t, optionally conditioned on MS2 and MS1 data. This function iteratively applies the reverse diffusion process (`p_sample`) to generate samples moving from noisy data back towards the original data distribution. @@ -303,7 +301,6 @@ def sample(self, x_t, ms2_cond=None, ms1_cond=None, precursor_cond=None, num_ste x_t (torch.Tensor): The initial noisy input tensor. ms2_cond (torch.Tensor, optional): MS2 mixture data maps for conditioning. ms1_cond (torch.Tensor, optional): Clean MS1 data maps for conditioning. - precursor_cond (torch.Tensor, optional): Precursor feature masks used as additional conditioning signals. num_steps (int): The number of reverse sampling steps to perform. 
 
         Returns:
@@ -312,7 +309,6 @@ def sample(self, x_t, ms2_cond=None, ms1_cond=None, precursor_cond=None, num_ste
         """
         ms2_cond = self.normalize(ms2_cond) if ms2_cond is not None else None
         ms1_cond = self.normalize(ms1_cond) if ms1_cond is not None else None
-        precursor_cond = self.normalize(precursor_cond) if precursor_cond is not None else None
 
         pred_noise = None
         time_steps = torch.linspace(self.num_timesteps - 1, 0, num_steps, dtype=torch.long)
@@ -327,7 +323,7 @@ def sample(self, x_t, ms2_cond=None, ms1_cond=None, precursor_cond=None, num_ste
 
         return x_t, pred_noise
 
-    def train_step(self, x_0, ms2_cond=None, ms1_cond=None, precursor_cond=None, noise=None, ms1_loss_weight=0.0):
+    def train_step(self, x_0, ms2_cond=None, ms1_cond=None, noise=None, ms1_loss_weight=0.0):
         """
         Performs a single training step using the specified input data, optionally with additional MS1 loss weighting.
 
@@ -338,7 +334,6 @@ def train_step(self, x_0, ms2_cond=None, ms1_cond=None, precursor_cond=None, noi
             x_0 (torch.Tensor): The clean MS2 data maps (original data).
             ms2_cond (torch.Tensor, optional): MS2 mixture data maps for additional conditioning.
             ms1_cond (torch.Tensor, optional): Clean MS1 data maps for additional conditioning.
-            precursor_cond (torch.Tensor, optional): Precursor feature masks used as additional conditioning signals.
             noise (torch.Tensor, optional): Noise tensor to use during forward diffusion. If None, noise is sampled randomly.
             ms1_loss_weight (float): Weighting factor for an additional MS1-specific loss component.
 
@@ -361,7 +356,7 @@ def train_step(self, x_0, ms2_cond=None, ms1_cond=None, precursor_cond=None, noi
 
         if self.pred_type == "eps":
             # Predict noise
-            eps_pred = self.model(x_t, t, ms2_cond, ms1_cond, precursor_cond)
+            eps_pred = self.model(x_t, t, ms2_cond, ms1_cond)
             # Compute primary loss between predicted noise and true noise
             primary_loss = F.mse_loss(eps_pred, noise)
 
@@ -376,7 +371,7 @@ def train_step(self, x_0, ms2_cond=None, ms1_cond=None, precursor_cond=None, noi
             )
         elif self.pred_type == "x0":
             # Predict x0
-            x0_pred = self.model(x_t, t, ms2_cond, ms1_cond, precursor_cond)
+            x0_pred = self.model(x_t, t, ms2_cond, ms1_cond)
             # Compute primary loss between predicted x0 and true x0
             primary_loss = F.mse_loss(x0_pred, x_0)
 
diff --git a/dquartic/utils/config_loader.py b/dquartic/utils/config_loader.py
index d1faba5..190c70a 100644
--- a/dquartic/utils/config_loader.py
+++ b/dquartic/utils/config_loader.py
@@ -24,6 +24,9 @@ def load_train_config(config_path: str, **kwargs):
 
     if "ms1_data_path" not in config_params["data"]:
         config_params["data"]["ms1_data_path"] = None
+
+    if "feature_mask_path" not in config_params["data"]:
+        config_params["data"]["feature_mask_path"] = None
 
     # Override the config params with the keyword arguments
     if "parquet_directory" in kwargs:
@@ -37,6 +40,10 @@ def load_train_config(config_path: str, **kwargs):
     if "ms1_data_path" in kwargs:
         if kwargs["ms1_data_path"] is not None:
             config_params["data"]["ms1_data_path"] = kwargs["ms1_data_path"]
+
+    if "feature_mask_path" in kwargs:
+        if kwargs["feature_mask_path"] is not None:
+            config_params["data"]["feature_mask_path"] = kwargs["feature_mask_path"]
 
     if "batch_size" in kwargs:
         if kwargs["batch_size"] is not None:
@@ -69,6 +76,7 @@ def generate_train_config(config_path: str):
             "parquet_directory": "data/",
             "ms2_data_path": None,
             "ms1_data_path": None,
+            "feature_mask_path": None,  # Optional path to precursor feature masks
             "normalize": "minmax",
         },
         "model": {
diff --git a/dquartic/utils/data_loader.py b/dquartic/utils/data_loader.py
index 454a987..f0a8fbf 100644
--- a/dquartic/utils/data_loader.py
+++ b/dquartic/utils/data_loader.py
@@ -30,12 +30,20 @@ class DIAMSDataset(Dataset):
         __getitem__(idx): Retrieves an item from the dataset.
     """
 
-    def __init__(self, parquet_directory=None, ms2_file=None, ms1_file=None, normalize: Literal[None, "minmax"] = None):
+    def __init__(self, parquet_directory=None, ms2_file=None, ms1_file=None, feature_mask_file=None, normalize: Literal[None, "minmax"] = None):
         if parquet_directory is None and ms1_file is not None and ms2_file is not None:
             self.ms2_data = np.load(ms2_file, mmap_mode="r")
-            self.ms1_data = np.load(ms1_file, mmap_mode="r")
-            self.data_type = "npy"
-            print(f"Info: Loaded {len(self.ms2_data)} MS2 slice samples and {len(self.ms1_data)} MS1 slice samples from NPY files.")
+            # Use feature masks instead of MS1 data if provided
+            if feature_mask_file is not None:
+                self.ms1_data = np.load(feature_mask_file, mmap_mode="r")
+                self.data_type = "npy"
+                print("Info: Using feature masks instead of MS1 data.")
+                print(f"Info: Loaded {len(self.ms2_data)} MS2 samples and {len(self.ms1_data)} conditioning samples.")
+            else:
+                self.ms1_data = np.load(ms1_file, mmap_mode="r")
+                self.data_type = "npy"
+                print(f"Info: Loaded {len(self.ms2_data)} MS2 slice samples and {len(self.ms1_data)} MS1 slice samples from NPY files.")
+
         elif parquet_directory is not None and ms1_file is None and ms2_file is None:
             self.meta_df = self.read_parquet_meta(parquet_directory)
             self.parquet_directory = parquet_directory
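
Usage note (editorial addition, not part of the patches): with PATCH 2/2 applied, feature masks ride through the existing MS1 conditioning slot. DIAMSDataset loads the mask array into self.ms1_data, so downstream code that consumed MS1 conditioning now sees the masks with no signature changes; that is also why this patch reverts the model.py changes from PATCH 1/2. Below is a minimal sketch of the resulting NPY data path; the file names are hypothetical.

# Sketch only (assumes NPY inputs): exercises the new feature_mask_file
# branch of DIAMSDataset. Note that the NPY branch still requires ms1_file
# to be non-None, even though the mask array replaces it as self.ms1_data
# whenever feature_mask_file is supplied.
from torch.utils.data import DataLoader
from dquartic.utils.data_loader import DIAMSDataset

dataset = DIAMSDataset(
    ms2_file="ms2_slices.npy",              # hypothetical path
    ms1_file="ms1_slices.npy",              # still checked, but superseded by the mask
    feature_mask_file="feature_masks.npy",  # hypothetical path
    normalize="minmax",
)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

The config-driven route is equivalent: set feature_mask_path under the data section of the training config. load_train_config defaults it to None, so existing configs keep working unchanged.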