Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ideeplc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""iDeepLC: A deep Learning-based retention time predictor for unseen modified peptides with a novel encoding system"""

__version__ = "1.3.1"
__version__ = "1.3.2"
112 changes: 98 additions & 14 deletions ideeplc/data_initialize.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import logging
from typing import Tuple, Union
from typing import Tuple, Union, Iterator
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
from ideeplc.utilities import df_to_matrix, reform_seq

LOGGER = logging.getLogger(__name__)


# Making the pytorch dataset
class MyDataset(Dataset):
def __init__(self, sequences: np.ndarray, retention: np.ndarray) -> None:
self.sequences = sequences
Expand All @@ -25,15 +24,14 @@ def data_initialize(
csv_path: str, **kwargs
) -> Union[Tuple[MyDataset, np.ndarray], Tuple[MyDataset, np.ndarray]]:
"""
Initialize peptides matrices based on a CSV file containing raw peptide sequences.
Initialize peptide matrices based on a CSV file containing raw peptide sequences.

:param csv_path: Path to the CSV file containing raw peptide sequences.
:return: DataLoader for prediction.
:return: Dataset for prediction or fine-tuning and x_shape.
"""

LOGGER.info(f"Loading peptides from {csv_path}")

try:
# Load peptides from CSV file
df = pd.read_csv(csv_path)
except FileNotFoundError:
LOGGER.error(f"File {csv_path} not found.")
Expand Down Expand Up @@ -63,22 +61,108 @@ def data_initialize(
LOGGER.info(
f"Loaded and reformed {len(reformed_peptides)} peptides sequences from the file."
)

try:
# Convert sequences to matrix format
sequences, tr, errors = df_to_matrix(reformed_peptides, df)
except Exception as e:
LOGGER.error(f"Error converting sequences to matrix format: {e}")
raise

if errors:
LOGGER.warning(f"Errors encountered during conversion: {errors}")

prediction_dataset = MyDataset(sequences, tr)

# Create DataLoader objects
dataloader_pred = DataLoader(prediction_dataset)
# passing the training X shape
for batch in dataloader_pred:
x_shape = batch[0].shape
break
if len(prediction_dataset) == 0:
LOGGER.error("No valid peptide entries were found in the input file.")
raise ValueError("No valid peptide entries were found in the input file.")

# Keep historical x_shape contract expected by model/tests: (batch, channels, length)
x_shape = (1,) + prediction_dataset[0][0].shape
LOGGER.info(f"Dataset initialized with data shape {x_shape}.")
return prediction_dataset, x_shape


def data_initialize_chunked(
    csv_path: str, chunk_size: int = 10000, **kwargs
) -> Iterator[Tuple[pd.DataFrame, MyDataset, np.ndarray]]:
    """
    Yield peptide matrices from a CSV file one chunk at a time.

    Each yielded item is the raw dataframe chunk, the dataset built from it,
    and the x_shape contract ``(1, channels, length)`` for that chunk.
    Chunks that produce no valid entries are skipped with a warning.

    :param csv_path: Path to the CSV file containing raw peptide sequences.
    :param chunk_size: Number of rows to load per chunk.
    :return: Iterator yielding dataframe chunk, dataset chunk, and x_shape.
    """
    LOGGER.info(f"Loading peptides from {csv_path} in chunks of {chunk_size}")

    try:
        reader = pd.read_csv(csv_path, chunksize=chunk_size)
    except FileNotFoundError:
        LOGGER.error(f"File {csv_path} not found.")
        raise
    except pd.errors.EmptyDataError:
        LOGGER.error(f"File {csv_path} is empty.")
        raise
    except Exception as e:
        LOGGER.error(f"Error reading {csv_path}: {e}")
        raise

    # Required columns paired with the description used in the error log;
    # checked in this fixed order so the first missing one raises.
    required_columns = (
        ("seq", "peptide sequences"),
        ("modifications", "peptide modifications"),
        ("tr", "retention times"),
    )

    for chunk_idx, chunk_df in enumerate(reader, start=1):
        for column, description in required_columns:
            if column not in chunk_df.columns:
                LOGGER.error(
                    f"CSV file must contain a '{column}' column with {description}."
                )
                raise ValueError(f"Missing '{column}' column in the CSV file.")

        # Fold each modification spec into its sequence string.
        reformed = [
            reform_seq(sequence, modification)
            for sequence, modification in zip(
                chunk_df["seq"], chunk_df["modifications"]
            )
        ]
        LOGGER.info(
            f"Chunk {chunk_idx}: loaded and reformed {len(reformed)} peptides sequences."
        )

        try:
            matrices, retention_times, conversion_errors = df_to_matrix(
                reformed, chunk_df
            )
        except Exception as e:
            LOGGER.error(
                f"Error converting sequences to matrix format in chunk {chunk_idx}: {e}"
            )
            raise

        if conversion_errors:
            LOGGER.warning(
                f"Errors encountered during conversion in chunk {chunk_idx}: {conversion_errors}"
            )

        chunk_dataset = MyDataset(matrices, retention_times)

        if not len(chunk_dataset):
            LOGGER.warning(f"Chunk {chunk_idx} contains no valid peptide entries.")
            continue

        # Keep historical x_shape contract expected by model/tests: (batch, channels, length)
        x_shape = (1,) + chunk_dataset[0][0].shape
        LOGGER.info(f"Chunk {chunk_idx} initialized with data shape {x_shape}.")
        yield chunk_df, chunk_dataset, x_shape


def get_input_shape_from_first_chunk(
    csv_path: str, chunk_size: int = 10000
) -> Tuple[int, ...]:
    """
    Get the input shape from the first valid chunk of a CSV file.

    :param csv_path: Path to the CSV file containing raw peptide sequences.
    :param chunk_size: Number of rows to load per chunk.
    :return: x_shape for model initialization.
    :raises ValueError: If no valid chunks are found in the input file.
    """
    chunk_gen = data_initialize_chunked(csv_path=csv_path, chunk_size=chunk_size)
    try:
        for _, _dataset_chunk, x_shape in chunk_gen:
            LOGGER.info(f"Detected input shape from first valid chunk: {x_shape}")
            return x_shape
    finally:
        # Close the generator explicitly: returning after the first yield
        # would otherwise leave pandas' chunked reader (and its open file
        # handle) alive until garbage collection.
        chunk_gen.close()

    LOGGER.error("No valid chunks found in the input file.")
    raise ValueError("No valid chunks found in the input file.")
78 changes: 52 additions & 26 deletions ideeplc/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(
validation_data=None,
validation_split=0.1,
patience=5,
num_workers=0,
pin_memory=False,
):
"""
Initialize the fine-tuner with the model and data loaders.
Expand All @@ -38,6 +40,8 @@ def __init__(
:param validation_data: Optional validation dataset.
:param validation_split: Fraction of training data to use for validation.
:param patience: Number of epochs with no improvement after which training will be stopped.
:param num_workers: Number of workers for the DataLoader.
:param pin_memory: Whether to pin memory in the DataLoader.
"""
self.model = model.to(device)
self.train_data = train_data
Expand All @@ -49,6 +53,8 @@ def __init__(
self.validation_data = validation_data
self.validation_split = validation_split
self.patience = patience
self.num_workers = num_workers
self.pin_memory = pin_memory

def _freeze_layers(self, layers_to_freeze):
"""
Expand All @@ -71,34 +77,52 @@ def prepare_data(self, data, shuffle=True):
:param shuffle: Whether to shuffle the data.
:return: DataLoader for the dataset.
"""
return DataLoader(data, batch_size=self.batch_size, shuffle=shuffle)
return DataLoader(
data,
batch_size=self.batch_size,
shuffle=shuffle,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
)

def fine_tune(self, layers_to_freeze=None):
"""
Fine-tune the iDeepLC model on the training dataset.

:param layers_to_freeze: List of layer names to freeze during fine-tuning.
:return: Best model based on validation loss.
"""
LOGGER.info("Starting fine-tuning...")

if layers_to_freeze:
self._freeze_layers(layers_to_freeze)

optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, self.model.parameters()),
lr=self.learning_rate,
)
loss_fn = self.loss_function
# Prepare DataLoader
if self.validation_data:
dataloader_train = self.prepare_data(self.train_data)

if self.validation_data is not None:
dataloader_train = self.prepare_data(self.train_data, shuffle=True)
dataloader_val = self.prepare_data(self.validation_data, shuffle=False)
else:
# Split the training data into training and validation sets
train_size = int((1 - self.validation_split) * len(self.train_data))
val_size = len(self.train_data) - train_size

if train_size == 0 or val_size == 0:
raise ValueError(
"Training dataset is too small for the requested validation split."
)

train_dataset, val_dataset = torch.utils.data.random_split(
self.train_data, [train_size, val_size]
)
dataloader_train = self.prepare_data(train_dataset)
dataloader_train = self.prepare_data(train_dataset, shuffle=True)
dataloader_val = self.prepare_data(val_dataset, shuffle=False)

LOGGER.info(f"Training on {len(dataloader_train.dataset)} samples.")
LOGGER.info(f"Validating on {len(dataloader_val.dataset)} samples.")

best_model = copy.deepcopy(self.model)
best_loss = float("inf")
Expand All @@ -107,15 +131,15 @@ def fine_tune(self, layers_to_freeze=None):
for epoch in range(self.epochs):
self.model.train()
running_loss = 0.0

for batch in dataloader_train:
inputs, target = batch
inputs, target = inputs.to(self.device), target.to(self.device)
inputs = inputs.to(self.device, non_blocking=True)
target = target.to(self.device, non_blocking=True)

# Forward pass
outputs = self.model(inputs.float())
loss = loss_fn(outputs, target.float().view(-1, 1))

# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
Expand All @@ -125,22 +149,24 @@ def fine_tune(self, layers_to_freeze=None):
avg_loss = running_loss / len(dataloader_train.dataset)
LOGGER.info(f"Epoch [{epoch + 1}/{self.epochs}], Loss: {avg_loss:.4f}")

# Validate the model after each epoch
if dataloader_val:
val_loss, _, _, _ = validate(
self.model, dataloader_val, loss_fn, self.device
val_loss, _, _, _ = validate(
self.model, dataloader_val, loss_fn, self.device
)

if val_loss < best_loss:
best_loss = val_loss
best_model = copy.deepcopy(self.model)
patience_counter = 0
LOGGER.info(f"New best validation loss: {best_loss:.4f}")
else:
patience_counter += 1
LOGGER.info(
f"No improvement in validation loss. Patience: {patience_counter}/{self.patience}"
)
if val_loss < best_loss:
best_loss = val_loss
best_model = copy.deepcopy(self.model)
patience_counter = 0
LOGGER.info(f"New best validation loss: {best_loss:.4f}")
else:
patience_counter += 1

if patience_counter >= self.patience:
LOGGER.info("Early stopping triggered.")
break

if patience_counter >= self.patience:
LOGGER.info("Early stopping triggered.")
break

LOGGER.info("Fine-tuning complete.")
return best_model
return best_model
Loading
Loading