diff --git a/dquartic/cli.py b/dquartic/cli.py index 1118b47..ea2b2c9 100644 --- a/dquartic/cli.py +++ b/dquartic/cli.py @@ -77,12 +77,13 @@ def train( parquet_directory = config['data']['parquet_directory'] ms2_data_path = config['data']['ms2_data_path'] ms1_data_path = config['data']['ms1_data_path'] + feature_mask_file = config['data']['feature_mask_path'] batch_size = config['model']['batch_size'] checkpoint_path = config['model']['checkpoint_path'] use_wandb = config['wandb']['use_wandb'] threads = config['threads'] - dataset = DIAMSDataset(parquet_directory, ms2_data_path, ms1_data_path, normalize=config['data']['normalize']) + dataset = DIAMSDataset(parquet_directory, ms2_data_path, ms1_data_path, feature_mask_file, normalize=config['data']['normalize']) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=threads) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/dquartic/utils/config_loader.py b/dquartic/utils/config_loader.py index d1faba5..190c70a 100644 --- a/dquartic/utils/config_loader.py +++ b/dquartic/utils/config_loader.py @@ -24,6 +24,9 @@ def load_train_config(config_path: str, **kwargs): if "ms1_data_path" not in config_params["data"]: config_params["data"]["ms1_data_path"] = None + + if "feature_mask_path" not in config_params["data"]: + config_params["data"]["feature_mask_path"] = None # Override the config params with the keyword arguments if "parquet_directory" in kwargs: @@ -37,6 +40,10 @@ def load_train_config(config_path: str, **kwargs): if "ms1_data_path" in kwargs: if kwargs["ms1_data_path"] is not None: config_params["data"]["ms1_data_path"] = kwargs["ms1_data_path"] + + if "feature_mask_path" in kwargs: + if kwargs["feature_mask_path"] is not None: + config_params["data"]["feature_mask_path"] = kwargs["feature_mask_path"] if "batch_size" in kwargs: if kwargs["batch_size"] is not None: @@ -69,6 +76,7 @@ def generate_train_config(config_path: str): "parquet_directory": 
"data/", "ms2_data_path": None, "ms1_data_path": None, + "feature_mask_path": None, # Add feature mask path "normalize": "minmax", }, "model": { diff --git a/dquartic/utils/data_loader.py b/dquartic/utils/data_loader.py index 454a987..f0a8fbf 100644 --- a/dquartic/utils/data_loader.py +++ b/dquartic/utils/data_loader.py @@ -30,12 +30,20 @@ class DIAMSDataset(Dataset): __getitem__(idx): Retrieves an item from the dataset. """ - def __init__(self, parquet_directory=None, ms2_file=None, ms1_file=None, normalize: Literal[None, "minmax"] = None): + def __init__(self, parquet_directory=None, ms2_file=None, ms1_file=None, feature_mask_file=None, normalize: Literal[None, "minmax"] = None): if parquet_directory is None and ms1_file is not None and ms2_file is not None: self.ms2_data = np.load(ms2_file, mmap_mode="r") - self.ms1_data = np.load(ms1_file, mmap_mode="r") - self.data_type = "npy" - print(f"Info: Loaded {len(self.ms2_data)} MS2 slice samples and {len(self.ms1_data)} MS1 slice samples from NPY files.") + # Use feature mask instead of MS1 if provided + if feature_mask_file is not None: + self.ms1_data = np.load(feature_mask_file, mmap_mode="r") + print(f"Info: Using feature masks instead of MS1 data") + self.data_type = "npy" + print(f"Info: Loaded {len(self.ms2_data)} MS2 samples and {len(self.ms1_data)} conditioning samples") + else: + self.ms1_data = np.load(ms1_file, mmap_mode="r") + self.data_type = "npy" + print(f"Info: Loaded {len(self.ms2_data)} MS2 slice samples and {len(self.ms1_data)} MS1 slice samples from NPY files.") + elif parquet_directory is not None and ms1_file is None and ms2_file is None: self.meta_df = self.read_parquet_meta(parquet_directory) self.parquet_directory = parquet_directory