From 3835dc2682d3f2f8b03029be65a4a79b436c2054 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Wed, 18 Jun 2025 00:32:10 +0530 Subject: [PATCH 01/39] Update README.md --- CIGIN_V2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CIGIN_V2/README.md b/CIGIN_V2/README.md index b3d2b2e..a7d698a 100644 --- a/CIGIN_V2/README.md +++ b/CIGIN_V2/README.md @@ -7,7 +7,7 @@ `$ conda install -c rdkit rdkit==2019.03.1` * Installing other dependencies:\ `$ conda install -c pytorch pytorch `\ - `$ pip install dgl` (Please check [here](https://docs.dgl.ai/en/0.4.x/install/) for + `$ pip install dgl` (Please check [here](https://www.dgl.ai/pages/start.html) for installing for different cuda builds)\ `$ pip install numpy`\ `$ pip install pandas` From bbc0e910c90c4049f973080854ee97dc0f6b7bfe Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:18:37 +0530 Subject: [PATCH 02/39] Update model.py --- CIGIN_V2/model.py | 249 +++++++++++++++++++++------------------------- 1 file changed, 113 insertions(+), 136 deletions(-) diff --git a/CIGIN_V2/model.py b/CIGIN_V2/model.py index eb93686..c3d1145 100644 --- a/CIGIN_V2/model.py +++ b/CIGIN_V2/model.py @@ -1,179 +1,156 @@ import numpy as np - from dgl import DGLGraph -from dgl.nn.pytorch import Set2Set, NNConv, GATConv - +from dgl.nn.pytorch import Set2Set, NNConv import torch import torch.nn as nn import torch.nn.functional as F - - -class GatherModel(nn.Module): - """ - MPNN from - `Neural Message Passing for Quantum Chemistry ` - Parameters - ---------- - node_input_dim : int - Dimension of input node feature, default to be 42. - edge_input_dim : int - Dimension of input edge feature, default to be 10. - node_hidden_dim : int - Dimension of node feature in hidden layers, default to be 42. - edge_hidden_dim : int - Dimension of edge feature in hidden layers, default to be 128. - num_step_message_passing : int - Number of message passing steps, default to be 6. 
- """ - - def __init__(self, - node_input_dim=42, - edge_input_dim=10, - node_hidden_dim=42, - edge_hidden_dim=42, - num_step_message_passing=6, - ): - super(GatherModel, self).__init__() - self.num_step_message_passing = num_step_message_passing +class EnhancedGatherModel(nn.Module): + def __init__(self, node_input_dim=42, edge_input_dim=10, + node_hidden_dim=42, edge_hidden_dim=42, + num_step_message_passing=6): + super().__init__() + + # Enhanced edge processing + self.edge_network = nn.Sequential( + nn.Linear(edge_input_dim, edge_hidden_dim), + nn.ReLU(), + nn.LayerNorm(edge_hidden_dim), + nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim) + ) + + self.conv = NNConv( + in_feats=node_hidden_dim, + out_feats=node_hidden_dim, + edge_func=self.edge_network, + aggregator_type='mean', # Changed from sum to mean + residual=True + ) + + # Hierarchical message passing + self.num_steps = num_step_message_passing self.lin0 = nn.Linear(node_input_dim, node_hidden_dim) - self.set2set = Set2Set(node_hidden_dim, 2, 1) self.message_layer = nn.Linear(2 * node_hidden_dim, node_hidden_dim) - edge_network = nn.Sequential( - nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(), - nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim)) - self.conv = NNConv(in_feats=node_hidden_dim, - out_feats=node_hidden_dim, - edge_func=edge_network, - aggregator_type='sum', - residual=True - ) + + # Subgraph aggregation + self.subgraph_proj = nn.Linear(node_hidden_dim * 2, node_hidden_dim) def forward(self, g, n_feat, e_feat): - """Returns the node embeddings after message passing phase. - Parameters - ---------- - g : DGLGraph - Input DGLGraph for molecule(s) - n_feat : tensor of dtype float32 and shape (B1, D1) - Node features. B1 for number of nodes and D1 for - the node feature size. - e_feat : tensor of dtype float32 and shape (B2, D2) - Edge features. B2 for number of edges and D2 for - the edge feature size. 
- Returns - ------- - res : node features - """ - init = n_feat.clone() out = F.relu(self.lin0(n_feat)) - for i in range(self.num_step_message_passing): - if e_feat is not None: - m = torch.relu(self.conv(g, out, e_feat)) - else: - m = torch.relu(self.conv.bias + self.conv.res_fc(out)) + + # First-level atomic aggregation + for _ in range(self.num_steps // 2): + m = torch.relu(self.conv(g, out, e_feat)) if e_feat is not None \ + else torch.relu(self.conv.bias + self.conv.res_fc(out)) out = self.message_layer(torch.cat([m, out], dim=1)) + + # Second-level functional group aggregation + with g.local_scope(): + g.ndata['h'] = out + g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_group')) + group_feat = g.ndata['h_group'] + out = self.subgraph_proj(torch.cat([out, group_feat], dim=1)) + return out + init - -class CIGINModel(nn.Module): - """ - This the main class for CIGIN model - """ - - def __init__(self, - node_input_dim=42, - edge_input_dim=10, - node_hidden_dim=42, - edge_hidden_dim=42, - num_step_message_passing=6, - interaction='dot', - num_step_set2_set=2, - num_layer_set2set=1, - ): - super(CIGINModel, self).__init__() - - self.node_input_dim = node_input_dim +class EnhancedCIGINModel(nn.Module): + def __init__(self, node_input_dim=42, edge_input_dim=10, + node_hidden_dim=42, edge_hidden_dim=42, + num_step_message_passing=6, interaction='dot', + num_step_set2_set=2, num_layer_set2set=1): + super().__init__() + self.node_hidden_dim = node_hidden_dim - self.edge_input_dim = edge_input_dim - self.edge_hidden_dim = edge_hidden_dim - self.num_step_message_passing = num_step_message_passing self.interaction = interaction - self.solute_gather = GatherModel(self.node_input_dim, self.edge_input_dim, - self.node_hidden_dim, self.edge_input_dim, - self.num_step_message_passing, - ) - self.solvent_gather = GatherModel(self.node_input_dim, self.edge_input_dim, - self.node_hidden_dim, self.edge_input_dim, - self.num_step_message_passing, - ) - - self.fc1 = nn.Linear(8 * self.node_hidden_dim, 256) + + # Learnable interaction scaling + self.temperature = nn.Parameter(torch.tensor(1.0)) + + # Enhanced gather models + self.solute_gather = EnhancedGatherModel( + node_input_dim, edge_input_dim, + node_hidden_dim, edge_hidden_dim, + num_step_message_passing + ) + self.solvent_gather = EnhancedGatherModel( + node_input_dim, edge_input_dim, + node_hidden_dim, edge_hidden_dim, + num_step_message_passing + ) + + # Residual Set2Set pooling + self.set2set_solute = Set2Set(2 * node_hidden_dim, num_step_set2_set, num_layer_set2set) + self.set2set_solvent = Set2Set(2 * node_hidden_dim, num_step_set2_set, num_layer_set2set) + + # Multi-task prediction heads + self.fc1 = nn.Linear(8 * node_hidden_dim, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 1) self.imap = nn.Linear(80, 1) - - self.num_step_set2set = num_step_set2_set - self.num_layer_set2set = num_layer_set2set - self.set2set_solute = Set2Set(2 * node_hidden_dim, self.num_step_set2set, self.num_layer_set2set) - self.set2set_solvent = Set2Set(2 * node_hidden_dim, self.num_step_set2set, self.num_layer_set2set) + + # Auxiliary prediction head + self.aux_head = nn.Sequential( + nn.Linear(8 * node_hidden_dim, 64), + nn.ReLU(), + nn.Linear(64, 3) # Predicts [logP, TPSA, QED] + ) def forward(self, data): - solute = data[0] - solvent = data[1] - solute_len = data[2] - solvent_len = data[3] - # node embeddings after interaction phase - solute_features = self.solute_gather(solute, solute.ndata['x'].float(), solute.edata['w'].float()) + solute, solvent, 
solute_len, solvent_len = data + + # Node embeddings + solute_features = self.solute_gather( + solute, solute.ndata['x'].float(), solute.edata['w'].float()) try: - # if edge exists in a molecule - solvent_features = self.solvent_gather(solvent, solvent.ndata['x'].float(), solvent.edata['w'].float()) + solvent_features = self.solvent_gather( + solvent, solvent.ndata['x'].float(), solvent.edata['w'].float()) except: - # if edge doesn't exist in a molecule, for example in case of water - solvent_features = self.solvent_gather(solvent, solvent.ndata['x'].float(), None) + solvent_features = self.solvent_gather( + solvent, solvent.ndata['x'].float(), None) - # Interaction phase + # Enhanced interaction phase len_map = torch.mm(solute_len.t(), solvent_len) - + if 'dot' not in self.interaction: X1 = solute_features.unsqueeze(0) Y1 = solvent_features.unsqueeze(1) - X2 = X1.repeat(solvent_features.shape[0], 1, 1) - Y2 = Y1.repeat(1, solute_features.shape[0], 1) - Z = torch.cat([X2, Y2], -1) - - if self.interaction == 'general': - interaction_map = self.imap(Z).squeeze(2) + Z = torch.cat([ + X1.repeat(solvent_features.shape[0], 1, 1), + Y1.repeat(1, solute_features.shape[0], 1) + ], -1) + + interaction_map = self.imap(Z).squeeze(2) if self.interaction == 'tanh-general': - interaction_map = torch.tanh(self.imap(Z)).squeeze(2) - + interaction_map = torch.tanh(interaction_map) interaction_map = torch.mul(len_map.float(), interaction_map.t()) - ret_interaction_map = torch.clone(interaction_map) - - elif 'dot' in self.interaction: + + else: interaction_map = torch.mm(solute_features, solvent_features.t()) if 'scaled' in self.interaction: - interaction_map = interaction_map / (np.sqrt(self.node_hidden_dim)) - - ret_interaction_map = torch.clone(interaction_map) - ret_interaction_map = torch.mul(len_map.float(), ret_interaction_map) + interaction_map = interaction_map / (self.temperature.abs() + 1e-8) interaction_map = torch.tanh(interaction_map) interaction_map = torch.mul(len_map.float(), interaction_map) - + + ret_interaction_map = torch.clone(interaction_map) solvent_prime = torch.mm(interaction_map.t(), solute_features) solute_prime = torch.mm(interaction_map, solvent_features) - # Prediction phase + # Prediction phase with residual connections solute_features = torch.cat((solute_features, solute_prime), dim=1) solvent_features = torch.cat((solvent_features, solvent_prime), dim=1) - + solute_features = self.set2set_solute(solute, solute_features) solvent_features = self.set2set_solvent(solvent, solvent_features) - + final_features = torch.cat((solute_features, solvent_features), 1) - predictions = torch.relu(self.fc1(final_features)) - predictions = torch.relu(self.fc2(predictions)) - predictions = self.fc3(predictions) - - return predictions, ret_interaction_map + + # Main prediction + main_pred = F.relu(self.fc1(final_features)) + main_pred = F.relu(self.fc2(main_pred)) + main_pred = self.fc3(main_pred) + + # Auxiliary predictions + aux_pred = self.aux_head(final_features.detach()) + + return main_pred, aux_pred, ret_interaction_map From 5e220193c7ce21ff6512d820cb634b4dd0359398 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:19:49 +0530 Subject: [PATCH 03/39] Update utils.py --- CIGIN_V2/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CIGIN_V2/utils.py b/CIGIN_V2/utils.py index 1fe5231..0e085d9 100644 --- a/CIGIN_V2/utils.py +++ b/CIGIN_V2/utils.py @@ -1,5 +1,7 @@ import numpy as np + + def one_of_k_encoding(x, 
allowable_set): if x not in allowable_set: raise Exception("input {0} not in allowable set{1}:".format( From ebf657779c1cbb303dd9557eb5e2e4f69780b41d Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:20:28 +0530 Subject: [PATCH 04/39] Update molecular_graph.py --- CIGIN_V2/molecular_graph.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CIGIN_V2/molecular_graph.py b/CIGIN_V2/molecular_graph.py index 8580397..52e745a 100644 --- a/CIGIN_V2/molecular_graph.py +++ b/CIGIN_V2/molecular_graph.py @@ -3,7 +3,7 @@ from rdkit import Chem from rdkit.Chem import rdMolDescriptors as rdDesc from utils import one_of_k_encoding_unk, one_of_k_encoding - +import torch def get_atom_features(atom, stereo, features, explicit_H=False): """ @@ -87,10 +87,10 @@ def get_graph_from_smile(molecule_smile): for j in range(molecule.GetNumAtoms()): bond_ij = molecule.GetBondBetweenAtoms(i, j) if bond_ij is not None: - G.add_edge(i, j) + G.add_edges(i, j) bond_features_ij = get_bond_features(bond_ij) edge_features.append(bond_features_ij) - G.ndata['x'] = np.array(node_features) - G.edata['w'] = np.array(edge_features) + G.ndata['x'] = torch.tensor(node_features) + G.edata['w'] = torch.tensor(edge_features) return G From aa02b7b2fecc833599b854ca3e14c1c05e4073c8 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:21:03 +0530 Subject: [PATCH 05/39] Update main.py --- CIGIN_V2/main.py | 64 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py index 3be9f5f..fe9d16f 100644 --- a/CIGIN_V2/main.py +++ b/CIGIN_V2/main.py @@ -3,6 +3,8 @@ import warnings import os import argparse +from sklearn.model_selection import train_test_split, KFold +import numpy as np # rdkit imports from rdkit import RDLogger @@ -19,9 +21,14 @@ # local imports from model import CIGINModel -from train import train +from train import train, evaluate_model, get_metrics from molecular_graph import get_graph_from_smile from utils import * +from models.van_GAT import CIGINGAT +from models.van_GGN import CIGINGGN +from models.van_GCN import CIGINGCN +from models.van_GAP import CIGINGAP +from models.van_WAS import CIGINWAS lg = RDLogger.logger() lg.setLevel(RDLogger.CRITICAL) @@ -44,17 +51,21 @@ use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") -if not os.path.isdir("runs/run-" + str(project_name)): - os.makedirs("./runs/run-" + str(project_name)) - os.makedirs("./runs/run-" + str(project_name) + "/models") +models = [CIGINGAP(interaction=interaction), CIGINWAS(interaction=interaction)] +model_names = ['cigin_gap', 'cigin_was'] + +for project_name in model_names: + if not os.path.isdir("runs/run-" + str(project_name)): + os.makedirs("./runs/run-" + str(project_name)) + os.makedirs("./runs/run-" + str(project_name) + "/models") def collate(samples): solute_graphs, solvent_graphs, labels = map(list, zip(*samples)) solute_graphs = dgl.batch(solute_graphs) solvent_graphs = dgl.batch(solvent_graphs) - solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes) - solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes) + solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes().tolist()) + solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes().tolist()) return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels 
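Note on the collate function above: get_len_matrix is imported from utils, but its definition never appears anywhere in this patch series. Judging from its call sites — it receives the per-graph node counts of a batched DGLGraph, and the model later computes len_map = torch.mm(solute_len.t(), solvent_len) to restrict the interaction map to solute/solvent node pairs from the same sample — it presumably builds a (num_graphs, total_nodes) block indicator matrix. A minimal sketch under that assumption (the exact upstream implementation is not shown in these patches):

    import numpy as np

    def get_len_matrix(len_list):
        # len_list: number of nodes in each graph of the batch, in order,
        # e.g. solute_graphs.batch_num_nodes().tolist() after dgl.batch().
        len_list = np.array(len_list)
        total_nodes = int(len_list.sum())
        len_matrix = np.zeros((len(len_list), total_nodes))
        start = 0
        for i, n in enumerate(len_list):
            # Mark the contiguous block of nodes owned by graph i.
            len_matrix[i, start:start + n] = 1
            start += n
        return len_matrix

With this layout, solute_len.t() @ solvent_len yields a (solute_nodes x solvent_nodes) 0/1 mask that is nonzero exactly where both nodes belong to the same solute/solvent pair, which matches how len_map is applied in the interaction phase of model.py.
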
@@ -67,38 +78,53 @@ def __len__(self): def __getitem__(self, idx): - solute = self.dataset.loc[idx]['SoluteSMILES'] + # print('solute', self.dataset.iloc[idx]['SoluteSMILES']) + solute = self.dataset.iloc[idx]['SoluteSMILES'] mol = Chem.MolFromSmiles(solute) mol = Chem.AddHs(mol) solute = Chem.MolToSmiles(mol) solute_graph = get_graph_from_smile(solute) + # print('solvent',self.dataset.iloc[idx]['SolventSMILES']) + solvent = self.dataset.iloc[idx]['SolventSMILES'] - solvent = self.dataset.loc[idx]['SolventSMILES'] mol = Chem.MolFromSmiles(solvent) mol = Chem.AddHs(mol) solvent = Chem.MolToSmiles(mol) solvent_graph = get_graph_from_smile(solvent) - delta_g = self.dataset.loc[idx]['DeltaGsolv'] + delta_g = self.dataset.iloc[idx]['delGsolv'] return [solute_graph, solvent_graph, [delta_g]] - def main(): - train_df = pd.read_csv('data/train.csv', sep=";") - valid_df = pd.read_csv('data/valid.csv', sep=";") + # train_df = pd.read_csv('data/train.csv', sep=";") + # valid_df = pd.read_csv('data/valid.csv', sep=";") + + df = pd.read_csv('https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/refs/heads/master/CIGIN_V2/data/whole_data.csv') + df.columns = df.columns.str.strip() + print(df.columns) + train_df, test_df = train_test_split(df, test_size=0.1, random_state=42) + train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42) train_dataset = Dataclass(train_df) valid_dataset = Dataclass(valid_df) + test_dataset = Dataclass(test_df) train_loader = DataLoader(train_dataset, collate_fn=collate, batch_size=batch_size, shuffle=True) valid_loader = DataLoader(valid_dataset, collate_fn=collate, batch_size=128) - - model = CIGINModel(interaction=interaction) - model.to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True) - - train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name) + test_loader = DataLoader(test_dataset, collate_fn=collate, batch_size=128) + for model, project_name in zip(models, model_names): + print('current_model:', project_name) + + model.to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True) + + train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name) + + # check on testing data: + model.eval() + loss, mae_loss = get_metrics(model, test_loader) + print(f"Model performance on the testing data: Loss: {loss}, MAE_Loss: {mae_loss}") if __name__ == '__main__': From b9d46de30ee1b23ea3e9b450392d7f2a7f12ea72 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:23:22 +0530 Subject: [PATCH 06/39] Update train.py --- CIGIN_V2/train.py | 221 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 172 insertions(+), 49 deletions(-) diff --git a/CIGIN_V2/train.py b/CIGIN_V2/train.py index 48f8400..3504d78 100644 --- a/CIGIN_V2/train.py +++ b/CIGIN_V2/train.py @@ -1,60 +1,183 @@ -from tqdm import tqdm -import torch import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm +from dgl import DGLGraph +from torch.optim.lr_scheduler import ReduceLROnPlateau +import torch.cuda.amp as amp -loss_fn = torch.nn.MSELoss() -mae_loss_fn = torch.nn.L1Loss() - +# Device configuration use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") +class MultiTaskLossWrapper(nn.Module): + """Adaptive loss 
weighting for multi-task learning""" + def __init__(self, task_num=3): + super().__init__() + self.task_num = task_num + self.log_vars = nn.Parameter(torch.zeros(task_num)) + self.mse = nn.MSELoss() + self.mae = nn.L1Loss() -def get_metrics(model, data_loader): - valid_outputs = [] - valid_labels = [] - valid_loss = [] - valid_mae_loss = [] - for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in tqdm(data_loader): - outputs, i_map = model( - [solute_graphs.to(device), solvent_graphs.to(device), torch.tensor(solute_lens).to(device), - torch.tensor(solvent_lens).to(device)]) - loss = loss_fn(outputs, torch.tensor(labels).to(device).float()) - mae_loss = mae_loss_fn(outputs, torch.tensor(labels).to(device).float()) - valid_outputs += outputs.cpu().detach().numpy().tolist() - valid_loss.append(loss.cpu().detach().numpy()) - valid_mae_loss.append(mae_loss.cpu().detach().numpy()) - valid_labels += labels - - loss = np.mean(np.array(valid_loss).flatten()) - mae_loss = np.mean(np.array(valid_mae_loss).flatten()) - return loss, mae_loss + def forward(self, preds, targets): + # Main task loss (ΔG prediction) + mse_loss = self.mse(preds[0], targets[0]) + mae_loss = self.mae(preds[0], targets[0]) + + # Auxiliary task losses (logP, TPSA, QED) + aux_loss = 0 + if len(targets) > 1 and targets[1] is not None: + for i in range(self.task_num): + aux_loss += torch.exp(-self.log_vars[i]) * self.mse(preds[1][:,i], targets[1][:,i]) + self.log_vars[i] + + return { + 'total': mse_loss + 0.3 * aux_loss, # Weighted sum + 'mse': mse_loss, + 'mae': mae_loss, + 'aux': aux_loss + } +def get_metrics(model, data_loader, return_preds=False): + """Enhanced evaluation with optional prediction returns""" + model.eval() + total_loss = {'mse': 0, 'mae': 0, 'aux': 0} + all_preds = [] + all_labels = [] + + with torch.no_grad(): + for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in data_loader: + inputs = [ + solute_graphs.to(device), + solvent_graphs.to(device), + solute_lens.to(device), + solvent_lens.to(device) + ] + labels = labels.to(device) + + # Forward pass + with amp.autocast(enabled=use_cuda): + main_pred, aux_pred, _ = model(inputs) + loss_fn = MultiTaskLossWrapper() + losses = loss_fn((main_pred, aux_pred), (labels, None)) + + # Accumulate metrics + for k in total_loss: + total_loss[k] += losses[k].item() * len(labels) + + if return_preds: + all_preds.extend(main_pred.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + # Calculate averages + num_samples = len(data_loader.dataset) + metrics = {k: v / num_samples for k, v in total_loss.items()} + + if return_preds: + return metrics, (np.array(all_preds), np.array(all_labels)) + return metrics def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name): - best_val_loss = 100 + """Enhanced training loop with multiple improvements""" + best_val_loss = float('inf') + loss_fn = MultiTaskLossWrapper() + scaler = amp.GradScaler(enabled=use_cuda) + + # Training statistics + history = { + 'train_loss': [], + 'val_loss': [], + 'val_mae': [], + 'lr': [] + } + for epoch in range(max_epochs): model.train() - running_loss = [] - tq_loader = tqdm(train_loader) - o = {} - for samples in tq_loader: - optimizer.zero_grad() - outputs, interaction_map = model( - [samples[0].to(device), samples[1].to(device), torch.tensor(samples[2]).to(device), - torch.tensor(samples[3]).to(device)]) - l1_norm = torch.norm(interaction_map, p=2) * 1e-4 - loss = loss_fn(outputs, torch.tensor(samples[4]).to(device).float()) + 
l1_norm - loss.backward() - optimizer.step() - loss = loss - l1_norm - running_loss.append(loss.cpu().detach()) - tq_loader.set_description( - "Epoch: " + str(epoch + 1) + " Training loss: " + str(np.mean(np.array(running_loss)))) - model.eval() - val_loss, mae_loss = get_metrics(model, valid_loader) - scheduler.step(val_loss) - print("Epoch: " + str(epoch + 1) + " train_loss " + str(np.mean(np.array(running_loss))) + " Val_loss " + str( - val_loss) + " MAE Val_loss " + str(mae_loss)) - if val_loss < best_val_loss: - best_val_loss = val_loss - torch.save(model.state_dict(), "./runs/run-" + str(project_name) + "/models/best_model.tar") + running_loss = {'mse': 0, 'mae': 0, 'aux': 0} + total_samples = 0 + + # Gradient accumulation + accum_steps = 4 + optimizer.zero_grad() + + with tqdm(train_loader, unit="batch") as tepoch: + for i, samples in enumerate(tepoch): + inputs = [ + samples[0].to(device), + samples[1].to(device), + samples[2].to(device), + samples[3].to(device) + ] + labels = samples[4].to(device) + batch_size = labels.shape[0] + total_samples += batch_size + + # Mixed precision forward + with amp.autocast(enabled=use_cuda): + main_pred, aux_pred, i_map = model(inputs) + l1_norm = torch.norm(i_map, p=2) * 1e-4 + losses = loss_fn((main_pred, aux_pred), (labels, None)) + loss = losses['total'] / accum_steps + l1_norm + + # Backward pass + scaler.scale(loss).backward() + + # Gradient accumulation update + if (i + 1) % accum_steps == 0 or (i + 1) == len(train_loader): + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + # Update running losses + for k in running_loss: + running_loss[k] += losses[k].item() * batch_size + + # Progress bar update + tepoch.set_postfix({ + 'loss': f"{running_loss['mse']/total_samples:.4f}", + 'mae': f"{running_loss['mae']/total_samples:.4f}" + }) + + # Calculate epoch metrics + train_metrics = {k: v/total_samples for k, v in running_loss.items()} + val_metrics = get_metrics(model, valid_loader) + + # Update scheduler + scheduler.step(val_metrics['mse']) + + # Store history + history['train_loss'].append(train_metrics['mse']) + history['val_loss'].append(val_metrics['mse']) + history['val_mae'].append(val_metrics['mae']) + history['lr'].append(optimizer.param_groups[0]['lr']) + + # Print epoch summary + print(f"\nEpoch {epoch+1}/{max_epochs}:") + print(f"Train MSE: {train_metrics['mse']:.4f} | Val MSE: {val_metrics['mse']:.4f}") + print(f"Train MAE: {train_metrics['mae']:.4f} | Val MAE: {val_metrics['mae']:.4f}") + print(f"LR: {history['lr'][-1]:.2e}") + + # Save best model + if val_metrics['mse'] < best_val_loss: + best_val_loss = val_metrics['mse'] + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': best_val_loss, + 'metrics': val_metrics, + 'history': history + }, f"./runs/run-{project_name}/models/best_model.tar") + print(f"New best model saved with Val MSE: {best_val_loss:.4f}") + + return history + +def load_best_model(model, project_name): + """Load the best saved model""" + checkpoint = torch.load(f"./runs/run-{project_name}/models/best_model.tar") + model.load_state_dict(checkpoint['model_state_dict']) + return model, checkpoint['history'] + +if __name__ == '__main__': + # Example usage (would normally be called from main.py) + print("This module contains training utilities and should be imported") From 5411125a9a6fe8f1172f665bd865b096dba11c5f Mon Sep 17 00:00:00 2001 From: R Nishanth 
<71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:26:23 +0530 Subject: [PATCH 07/39] Update model.py --- CIGIN_V2/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CIGIN_V2/model.py b/CIGIN_V2/model.py index c3d1145..f7b34ae 100644 --- a/CIGIN_V2/model.py +++ b/CIGIN_V2/model.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -class EnhancedGatherModel(nn.Module): +class GatherModel(nn.Module): def __init__(self, node_input_dim=42, edge_input_dim=10, node_hidden_dim=42, edge_hidden_dim=42, num_step_message_passing=6): @@ -68,12 +68,12 @@ def __init__(self, node_input_dim=42, edge_input_dim=10, self.temperature = nn.Parameter(torch.tensor(1.0)) # Enhanced gather models - self.solute_gather = EnhancedGatherModel( + self.solute_gather = GatherModel( node_input_dim, edge_input_dim, node_hidden_dim, edge_hidden_dim, num_step_message_passing ) - self.solvent_gather = EnhancedGatherModel( + self.solvent_gather = GatherModel( node_input_dim, edge_input_dim, node_hidden_dim, edge_hidden_dim, num_step_message_passing From 78f19cc44c918c1d7d8fbd01396622cde8cc11bf Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:27:37 +0530 Subject: [PATCH 08/39] Update model.py --- CIGIN_V2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CIGIN_V2/model.py b/CIGIN_V2/model.py index f7b34ae..55b51ee 100644 --- a/CIGIN_V2/model.py +++ b/CIGIN_V2/model.py @@ -54,7 +54,7 @@ def forward(self, g, n_feat, e_feat): return out + init -class EnhancedCIGINModel(nn.Module): +class CIGINModel(nn.Module): def __init__(self, node_input_dim=42, edge_input_dim=10, node_hidden_dim=42, edge_hidden_dim=42, num_step_message_passing=6, interaction='dot', From 921f74dae5cd06c395f56bca552da9eb8ce66689 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:34:50 +0530 Subject: [PATCH 09/39] added evaluation function --- CIGIN_V2/train.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/CIGIN_V2/train.py b/CIGIN_V2/train.py index 3504d78..5404c90 100644 --- a/CIGIN_V2/train.py +++ b/CIGIN_V2/train.py @@ -10,6 +10,23 @@ use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") +# Add this function to your existing train.py (keep everything else the same) +def evaluate_model(model, dataloader): + model.eval() + preds = [] + targets = [] + with torch.no_grad(): + for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in dataloader: + outputs, _ = model([ + solute_graphs.to(device), + solvent_graphs.to(device), + solute_lens.to(device), + solvent_lens.to(device) + ]) + preds.extend(outputs.cpu().numpy()) + targets.extend(labels) + return np.sqrt(np.mean((np.array(preds) - np.array(targets))**2)) + class MultiTaskLossWrapper(nn.Module): """Adaptive loss weighting for multi-task learning""" def __init__(self, task_num=3): From 3427ac76ad3fb4357b2f6340e006e574d5d6cf6c Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:37:00 +0530 Subject: [PATCH 10/39] Update main.py --- CIGIN_V2/main.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py index fe9d16f..bc9900d 100644 --- a/CIGIN_V2/main.py +++ b/CIGIN_V2/main.py @@ -21,14 +21,10 @@ # local imports from model import CIGINModel -from train import train, evaluate_model, 
get_metrics +from train import train, get_metrics from molecular_graph import get_graph_from_smile from utils import * -from models.van_GAT import CIGINGAT -from models.van_GGN import CIGINGGN -from models.van_GCN import CIGINGCN -from models.van_GAP import CIGINGAP -from models.van_WAS import CIGINWAS + lg = RDLogger.logger() lg.setLevel(RDLogger.CRITICAL) From 92dd891eb8ba3a26eaafd0c5e8a599b731bbe462 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:40:54 +0530 Subject: [PATCH 11/39] Update main.py --- CIGIN_V2/main.py | 92 +++++++++++++++++++++++++----------------------- 1 file changed, 47 insertions(+), 45 deletions(-) diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py index bc9900d..cf91bfb 100644 --- a/CIGIN_V2/main.py +++ b/CIGIN_V2/main.py @@ -3,7 +3,7 @@ import warnings import os import argparse -from sklearn.model_selection import train_test_split, KFold +from sklearn.model_selection import train_test_split import numpy as np # rdkit imports @@ -16,7 +16,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau import torch -#dgl imports +# dgl imports import dgl # local imports @@ -25,38 +25,37 @@ from molecular_graph import get_graph_from_smile from utils import * - +# Disable logs and warnings lg = RDLogger.logger() lg.setLevel(RDLogger.CRITICAL) rdBase.DisableLog('rdApp.error') warnings.filterwarnings("ignore") +# Argument parsing parser = argparse.ArgumentParser() parser.add_argument('--name', default='cigin', help="The name of the current project: default: CIGIN") -parser.add_argument('--interaction', help="type of interaction function to use: dot | scaled-dot | general | " - "tanh-general", default='dot') -parser.add_argument('--max_epochs', required=False, default=100, help="The max number of epochs for training") -parser.add_argument('--batch_size', required=False, default=32, help="The batch size for training") +parser.add_argument('--interaction', help="type of interaction function to use: dot | scaled-dot | general | tanh-general", + default='dot') +parser.add_argument('--max_epochs', type=int, default=100, help="The max number of epochs for training") +parser.add_argument('--batch_size', type=int, default=32, help="The batch size for training") args = parser.parse_args() project_name = args.name interaction = args.interaction -max_epochs = int(args.max_epochs) -batch_size = int(args.batch_size) +max_epochs = args.max_epochs +batch_size = args.batch_size +# Device configuration use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") -models = [CIGINGAP(interaction=interaction), CIGINWAS(interaction=interaction)] -model_names = ['cigin_gap', 'cigin_was'] - -for project_name in model_names: - if not os.path.isdir("runs/run-" + str(project_name)): - os.makedirs("./runs/run-" + str(project_name)) - os.makedirs("./runs/run-" + str(project_name) + "/models") - +# Create output directory +if not os.path.isdir("runs/run-" + str(project_name)): + os.makedirs("./runs/run-" + str(project_name)) + os.makedirs("./runs/run-" + str(project_name) + "/models") def collate(samples): + """Batch preparation function""" solute_graphs, solvent_graphs, labels = map(list, zip(*samples)) solute_graphs = dgl.batch(solute_graphs) solvent_graphs = dgl.batch(solvent_graphs) @@ -64,8 +63,8 @@ def collate(samples): solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes().tolist()) return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels - -class 
Dataclass(Dataset): +class SolvationDataset(Dataset): + """Custom dataset class for solvation data""" def __init__(self, dataset): self.dataset = dataset @@ -73,55 +72,58 @@ def __len__(self): return len(self.dataset) def __getitem__(self, idx): - - # print('solute', self.dataset.iloc[idx]['SoluteSMILES']) + # Process solute solute = self.dataset.iloc[idx]['SoluteSMILES'] mol = Chem.MolFromSmiles(solute) mol = Chem.AddHs(mol) solute = Chem.MolToSmiles(mol) solute_graph = get_graph_from_smile(solute) - # print('solvent',self.dataset.iloc[idx]['SolventSMILES']) - solvent = self.dataset.iloc[idx]['SolventSMILES'] + # Process solvent + solvent = self.dataset.iloc[idx]['SolventSMILES'] mol = Chem.MolFromSmiles(solvent) mol = Chem.AddHs(mol) solvent = Chem.MolToSmiles(mol) - solvent_graph = get_graph_from_smile(solvent) + delta_g = self.dataset.iloc[idx]['delGsolv'] - return [solute_graph, solvent_graph, [delta_g]] + return solute_graph, solvent_graph, delta_g def main(): - # train_df = pd.read_csv('data/train.csv', sep=";") - # valid_df = pd.read_csv('data/valid.csv', sep=";") - + # Load and prepare data df = pd.read_csv('https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/refs/heads/master/CIGIN_V2/data/whole_data.csv') df.columns = df.columns.str.strip() - print(df.columns) + + # Split data into train/valid/test (80/10/10) train_df, test_df = train_test_split(df, test_size=0.1, random_state=42) train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42) - train_dataset = Dataclass(train_df) - valid_dataset = Dataclass(valid_df) - test_dataset = Dataclass(test_df) + # Create datasets and data loaders + train_dataset = SolvationDataset(train_df) + valid_dataset = SolvationDataset(valid_df) + test_dataset = SolvationDataset(test_df) train_loader = DataLoader(train_dataset, collate_fn=collate, batch_size=batch_size, shuffle=True) valid_loader = DataLoader(valid_dataset, collate_fn=collate, batch_size=128) test_loader = DataLoader(test_dataset, collate_fn=collate, batch_size=128) - for model, project_name in zip(models, model_names): - print('current_model:', project_name) - - model.to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True) - - train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name) - - # check on testing data: - model.eval() - loss, mae_loss = get_metrics(model, test_loader) - print(f"Model performance on the testing data: Loss: {loss}, MAE_Loss: {mae_loss}") + # Initialize model + model = CIGINModel(interaction=interaction) + model.to(device) + + # Training setup + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True) + + # Training loop + train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name) + + # Final evaluation + model.eval() + loss, mae_loss = get_metrics(model, test_loader) + print(f"\nFinal test set performance:") + print(f"MSE Loss: {loss:.4f}") + print(f"MAE Loss: {mae_loss:.4f}") if __name__ == '__main__': main() From 8e8f609a7eb28af2a3496159338b271fd74d6df7 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:44:17 +0530 Subject: [PATCH 12/39] changed from np to tensor --- CIGIN_V2/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py index cf91bfb..6c68094 100644 
--- a/CIGIN_V2/main.py +++ b/CIGIN_V2/main.py @@ -55,12 +55,12 @@ os.makedirs("./runs/run-" + str(project_name) + "/models") def collate(samples): - """Batch preparation function""" solute_graphs, solvent_graphs, labels = map(list, zip(*samples)) solute_graphs = dgl.batch(solute_graphs) solvent_graphs = dgl.batch(solvent_graphs) - solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes().tolist()) - solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes().tolist()) + solute_len_matrix = torch.FloatTensor(get_len_matrix(solute_graphs.batch_num_nodes().tolist())) + solvent_len_matrix = torch.FloatTensor(get_len_matrix(solvent_graphs.batch_num_nodes().tolist())) + labels = torch.FloatTensor(labels) return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels class SolvationDataset(Dataset): From 836adba61091662694b65858505a0b5aa1a00882 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:50:08 +0530 Subject: [PATCH 13/39] Update train.py --- CIGIN_V2/train.py | 241 +++++++++++++++++++--------------------------- 1 file changed, 98 insertions(+), 143 deletions(-) diff --git a/CIGIN_V2/train.py b/CIGIN_V2/train.py index 5404c90..4937bc8 100644 --- a/CIGIN_V2/train.py +++ b/CIGIN_V2/train.py @@ -2,67 +2,60 @@ import torch import torch.nn as nn from tqdm import tqdm -from dgl import DGLGraph from torch.optim.lr_scheduler import ReduceLROnPlateau -import torch.cuda.amp as amp # Device configuration use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") -# Add this function to your existing train.py (keep everything else the same) +# Loss functions +loss_fn = torch.nn.MSELoss() +mae_loss_fn = torch.nn.L1Loss() + def evaluate_model(model, dataloader): + """Evaluation function for compatibility with original code""" model.eval() preds = [] targets = [] with torch.no_grad(): for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in dataloader: + # Convert inputs to tensors if they aren't already + if not isinstance(solute_lens, torch.Tensor): + solute_lens = torch.FloatTensor(solute_lens) + if not isinstance(solvent_lens, torch.Tensor): + solvent_lens = torch.FloatTensor(solvent_lens) + if not isinstance(labels, torch.Tensor): + labels = torch.FloatTensor(labels) + outputs, _ = model([ solute_graphs.to(device), solvent_graphs.to(device), solute_lens.to(device), solvent_lens.to(device) ]) - preds.extend(outputs.cpu().numpy()) - targets.extend(labels) + preds.extend(outputs.cpu().numpy().flatten()) + targets.extend(labels.cpu().numpy().flatten()) return np.sqrt(np.mean((np.array(preds) - np.array(targets))**2)) -class MultiTaskLossWrapper(nn.Module): - """Adaptive loss weighting for multi-task learning""" - def __init__(self, task_num=3): - super().__init__() - self.task_num = task_num - self.log_vars = nn.Parameter(torch.zeros(task_num)) - self.mse = nn.MSELoss() - self.mae = nn.L1Loss() - - def forward(self, preds, targets): - # Main task loss (ΔG prediction) - mse_loss = self.mse(preds[0], targets[0]) - mae_loss = self.mae(preds[0], targets[0]) - - # Auxiliary task losses (logP, TPSA, QED) - aux_loss = 0 - if len(targets) > 1 and targets[1] is not None: - for i in range(self.task_num): - aux_loss += torch.exp(-self.log_vars[i]) * self.mse(preds[1][:,i], targets[1][:,i]) + self.log_vars[i] - - return { - 'total': mse_loss + 0.3 * aux_loss, # Weighted sum - 'mse': mse_loss, - 'mae': mae_loss, - 'aux': aux_loss - } - -def get_metrics(model, 
data_loader, return_preds=False): - """Enhanced evaluation with optional prediction returns""" +def get_metrics(model, data_loader): + """Calculate MSE and MAE losses""" model.eval() - total_loss = {'mse': 0, 'mae': 0, 'aux': 0} - all_preds = [] - all_labels = [] + valid_loss = [] + valid_mae_loss = [] + valid_outputs = [] + valid_labels = [] with torch.no_grad(): for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in data_loader: + # Ensure all inputs are tensors + if not isinstance(solute_lens, torch.Tensor): + solute_lens = torch.FloatTensor(solute_lens) + if not isinstance(solvent_lens, torch.Tensor): + solvent_lens = torch.FloatTensor(solvent_lens) + if not isinstance(labels, torch.Tensor): + labels = torch.FloatTensor(labels) + + # Move data to device inputs = [ solute_graphs.to(device), solvent_graphs.to(device), @@ -72,129 +65,91 @@ def get_metrics(model, data_loader, return_preds=False): labels = labels.to(device) # Forward pass - with amp.autocast(enabled=use_cuda): - main_pred, aux_pred, _ = model(inputs) - loss_fn = MultiTaskLossWrapper() - losses = loss_fn((main_pred, aux_pred), (labels, None)) + outputs, _ = model(inputs) - # Accumulate metrics - for k in total_loss: - total_loss[k] += losses[k].item() * len(labels) + # Calculate losses + loss = loss_fn(outputs, labels) + mae_loss = mae_loss_fn(outputs, labels) - if return_preds: - all_preds.extend(main_pred.cpu().numpy()) - all_labels.extend(labels.cpu().numpy()) + # Store results + valid_loss.append(loss.cpu().item()) + valid_mae_loss.append(mae_loss.cpu().item()) + valid_outputs.extend(outputs.cpu().numpy().flatten()) + valid_labels.extend(labels.cpu().numpy().flatten()) - # Calculate averages - num_samples = len(data_loader.dataset) - metrics = {k: v / num_samples for k, v in total_loss.items()} - - if return_preds: - return metrics, (np.array(all_preds), np.array(all_labels)) - return metrics + # Calculate mean losses + loss = np.mean(valid_loss) + mae_loss = np.mean(valid_mae_loss) + return loss, mae_loss def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name): - """Enhanced training loop with multiple improvements""" + """Main training loop""" best_val_loss = float('inf') - loss_fn = MultiTaskLossWrapper() - scaler = amp.GradScaler(enabled=use_cuda) - - # Training statistics - history = { - 'train_loss': [], - 'val_loss': [], - 'val_mae': [], - 'lr': [] - } for epoch in range(max_epochs): model.train() - running_loss = {'mse': 0, 'mae': 0, 'aux': 0} - total_samples = 0 + running_loss = [] - # Gradient accumulation - accum_steps = 4 - optimizer.zero_grad() + # Initialize progress bar + tq_loader = tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epochs}") - with tqdm(train_loader, unit="batch") as tepoch: - for i, samples in enumerate(tepoch): - inputs = [ - samples[0].to(device), - samples[1].to(device), - samples[2].to(device), - samples[3].to(device) - ] - labels = samples[4].to(device) - batch_size = labels.shape[0] - total_samples += batch_size - - # Mixed precision forward - with amp.autocast(enabled=use_cuda): - main_pred, aux_pred, i_map = model(inputs) - l1_norm = torch.norm(i_map, p=2) * 1e-4 - losses = loss_fn((main_pred, aux_pred), (labels, None)) - loss = losses['total'] / accum_steps + l1_norm - - # Backward pass - scaler.scale(loss).backward() - - # Gradient accumulation update - if (i + 1) % accum_steps == 0 or (i + 1) == len(train_loader): - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - # Update running losses - for k in 
running_loss: - running_loss[k] += losses[k].item() * batch_size - - # Progress bar update - tepoch.set_postfix({ - 'loss': f"{running_loss['mse']/total_samples:.4f}", - 'mae': f"{running_loss['mae']/total_samples:.4f}" - }) - - # Calculate epoch metrics - train_metrics = {k: v/total_samples for k, v in running_loss.items()} - val_metrics = get_metrics(model, valid_loader) - - # Update scheduler - scheduler.step(val_metrics['mse']) + for samples in tq_loader: + optimizer.zero_grad() + + # Convert and move data to device + solute_graphs = samples[0].to(device) + solvent_graphs = samples[1].to(device) + + # Handle length matrices + solute_lens = samples[2] + solvent_lens = samples[3] + if not isinstance(solute_lens, torch.Tensor): + solute_lens = torch.FloatTensor(solute_lens) + if not isinstance(solvent_lens, torch.Tensor): + solvent_lens = torch.FloatTensor(solvent_lens) + solute_lens = solute_lens.to(device) + solvent_lens = solvent_lens.to(device) + + # Handle labels + labels = samples[4] + if not isinstance(labels, torch.Tensor): + labels = torch.FloatTensor(labels) + labels = labels.to(device) + + # Forward pass + outputs, interaction_map = model([solute_graphs, solvent_graphs, solute_lens, solvent_lens]) + + # Calculate loss with L1 regularization + l1_norm = torch.norm(interaction_map, p=2) * 1e-4 + loss = loss_fn(outputs, labels) + l1_norm + + # Backward pass + loss.backward() + optimizer.step() + + # Update running loss (without regularization term) + running_loss.append((loss - l1_norm).cpu().item()) + + # Update progress bar + tq_loader.set_postfix(loss=np.mean(running_loss)) - # Store history - history['train_loss'].append(train_metrics['mse']) - history['val_loss'].append(val_metrics['mse']) - history['val_mae'].append(val_metrics['mae']) - history['lr'].append(optimizer.param_groups[0]['lr']) + # Validation phase + val_loss, mae_loss = get_metrics(model, valid_loader) + scheduler.step(val_loss) # Print epoch summary - print(f"\nEpoch {epoch+1}/{max_epochs}:") - print(f"Train MSE: {train_metrics['mse']:.4f} | Val MSE: {val_metrics['mse']:.4f}") - print(f"Train MAE: {train_metrics['mae']:.4f} | Val MAE: {val_metrics['mae']:.4f}") - print(f"LR: {history['lr'][-1]:.2e}") + print(f"\nEpoch {epoch+1}:") + print(f"Train Loss: {np.mean(running_loss):.4f}") + print(f"Val Loss: {val_loss:.4f}") + print(f"Val MAE: {mae_loss:.4f}") # Save best model - if val_metrics['mse'] < best_val_loss: - best_val_loss = val_metrics['mse'] - torch.save({ - 'epoch': epoch, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'scheduler_state_dict': scheduler.state_dict(), - 'loss': best_val_loss, - 'metrics': val_metrics, - 'history': history - }, f"./runs/run-{project_name}/models/best_model.tar") - print(f"New best model saved with Val MSE: {best_val_loss:.4f}") + if val_loss < best_val_loss: + best_val_loss = val_loss + torch.save(model.state_dict(), f"./runs/run-{project_name}/models/best_model.pt") + print(f"Saved new best model with Val Loss: {best_val_loss:.4f}") - return history - -def load_best_model(model, project_name): - """Load the best saved model""" - checkpoint = torch.load(f"./runs/run-{project_name}/models/best_model.tar") - model.load_state_dict(checkpoint['model_state_dict']) - return model, checkpoint['history'] + print("\nTraining completed!") if __name__ == '__main__': - # Example usage (would normally be called from main.py) print("This module contains training utilities and should be imported") From 
1097a3bdd7c9b5979bf1a2a147ce863e19cffb92 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:55:27 +0530 Subject: [PATCH 14/39] Update model.py --- CIGIN_V2/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CIGIN_V2/model.py b/CIGIN_V2/model.py index 55b51ee..49f7737 100644 --- a/CIGIN_V2/model.py +++ b/CIGIN_V2/model.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +import dgl.function as fn + class GatherModel(nn.Module): def __init__(self, node_input_dim=42, edge_input_dim=10, From 0c19302eb30c264e191187bcf19048435c0d64fe Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 12:00:03 +0530 Subject: [PATCH 15/39] Update train.py --- CIGIN_V2/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/CIGIN_V2/train.py b/CIGIN_V2/train.py index 4937bc8..ae23fdb 100644 --- a/CIGIN_V2/train.py +++ b/CIGIN_V2/train.py @@ -27,7 +27,7 @@ def evaluate_model(model, dataloader): if not isinstance(labels, torch.Tensor): labels = torch.FloatTensor(labels) - outputs, _ = model([ + outputs, _ , _= model([ solute_graphs.to(device), solvent_graphs.to(device), solute_lens.to(device), @@ -65,7 +65,8 @@ def get_metrics(model, data_loader): labels = labels.to(device) # Forward pass - outputs, _ = model(inputs) + outputs, _, _ = model(inputs) + # Calculate losses loss = loss_fn(outputs, labels) @@ -117,7 +118,8 @@ def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, p labels = labels.to(device) # Forward pass - outputs, interaction_map = model([solute_graphs, solvent_graphs, solute_lens, solvent_lens]) + outputs, _, interaction_map = model([solute_graphs, solvent_graphs, solute_lens, solvent_lens]) + # Calculate loss with L1 regularization l1_norm = torch.norm(interaction_map, p=2) * 1e-4 From 57f43735b2cb17fae6e4a16c56723ef6b4698c44 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 12:24:46 +0530 Subject: [PATCH 16/39] Update model.py --- CIGIN_V2/model.py | 143 +++++++++++++++++++++++++--------------------- 1 file changed, 78 insertions(+), 65 deletions(-) diff --git a/CIGIN_V2/model.py b/CIGIN_V2/model.py index 49f7737..84c0110 100644 --- a/CIGIN_V2/model.py +++ b/CIGIN_V2/model.py @@ -1,77 +1,101 @@ import numpy as np -from dgl import DGLGraph -from dgl.nn.pytorch import Set2Set, NNConv import torch import torch.nn as nn import torch.nn.functional as F import dgl.function as fn +from dgl.nn.pytorch import Set2Set, NNConv + + +class EnhancedEdgeNetwork(nn.Module): + def __init__(self, edge_dim): + super().__init__() + self.edge_proj = nn.Linear(edge_dim, edge_dim) + self.gate = nn.Sequential( + nn.Linear(edge_dim, 1), + nn.Sigmoid() + ) + + def forward(self, e_feat): + projected = torch.relu(self.edge_proj(e_feat)) + gate = self.gate(e_feat) + return projected * gate + + +class ResidualSet2Set(Set2Set): + def forward(self, graph, feat): + pooled = super().forward(graph, feat) + mean_feat = feat.mean(dim=0, keepdim=True).repeat(pooled.shape[0], 1) + return torch.cat([pooled, mean_feat], dim=-1) + + +def add_rwse(g, k=16): + adj = g.adjacency_matrix().to_dense() + rwpe = torch.zeros(g.num_nodes(), k) + deg = adj.sum(1) + deg[deg == 0] = 1 # Prevent div by zero + for i in range(k): + if i == 0: + rwpe[:, i] = deg + else: + rwpe[:, i] = (adj @ rwpe[:, i-1]) / deg + return rwpe class GatherModel(nn.Module): - def 
__init__(self, node_input_dim=42, edge_input_dim=10, - node_hidden_dim=42, edge_hidden_dim=42, + def __init__(self, node_input_dim=42, edge_input_dim=10, + node_hidden_dim=42, edge_hidden_dim=42, num_step_message_passing=6): super().__init__() - - # Enhanced edge processing - self.edge_network = nn.Sequential( - nn.Linear(edge_input_dim, edge_hidden_dim), - nn.ReLU(), - nn.LayerNorm(edge_hidden_dim), - nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim) - ) - + + self.edge_network = EnhancedEdgeNetwork(edge_input_dim) + self.conv = NNConv( in_feats=node_hidden_dim, out_feats=node_hidden_dim, edge_func=self.edge_network, - aggregator_type='mean', # Changed from sum to mean + aggregator_type='mean', residual=True ) - - # Hierarchical message passing - self.num_steps = num_step_message_passing - self.lin0 = nn.Linear(node_input_dim, node_hidden_dim) + + self.lin0 = nn.Linear(node_input_dim + 16, node_hidden_dim) # +RWSE self.message_layer = nn.Linear(2 * node_hidden_dim, node_hidden_dim) - - # Subgraph aggregation - self.subgraph_proj = nn.Linear(node_hidden_dim * 2, node_hidden_dim) + self.subgraph_proj = nn.Linear(2 * node_hidden_dim, node_hidden_dim) + self.num_steps = num_step_message_passing def forward(self, g, n_feat, e_feat): + rwse = add_rwse(g).to(n_feat.device) + n_feat = torch.cat([n_feat, rwse], dim=-1) + init = n_feat.clone() out = F.relu(self.lin0(n_feat)) - - # First-level atomic aggregation + for _ in range(self.num_steps // 2): m = torch.relu(self.conv(g, out, e_feat)) if e_feat is not None \ else torch.relu(self.conv.bias + self.conv.res_fc(out)) out = self.message_layer(torch.cat([m, out], dim=1)) - - # Second-level functional group aggregation + with g.local_scope(): g.ndata['h'] = out g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_group')) group_feat = g.ndata['h_group'] out = self.subgraph_proj(torch.cat([out, group_feat], dim=1)) - + return out + init + class CIGINModel(nn.Module): - def __init__(self, node_input_dim=42, edge_input_dim=10, + def __init__(self, node_input_dim=42, edge_input_dim=10, node_hidden_dim=42, edge_hidden_dim=42, num_step_message_passing=6, interaction='dot', num_step_set2_set=2, num_layer_set2set=1): super().__init__() - + self.node_hidden_dim = node_hidden_dim self.interaction = interaction - - # Learnable interaction scaling self.temperature = nn.Parameter(torch.tensor(1.0)) - - # Enhanced gather models + self.solute_gather = GatherModel( - node_input_dim, edge_input_dim, + node_input_dim, edge_input_dim, node_hidden_dim, edge_hidden_dim, num_step_message_passing ) @@ -80,40 +104,33 @@ def __init__(self, node_input_dim=42, edge_input_dim=10, node_hidden_dim, edge_hidden_dim, num_step_message_passing ) - - # Residual Set2Set pooling - self.set2set_solute = Set2Set(2 * node_hidden_dim, num_step_set2_set, num_layer_set2set) - self.set2set_solvent = Set2Set(2 * node_hidden_dim, num_step_set2_set, num_layer_set2set) - - # Multi-task prediction heads + + self.set2set_solute = ResidualSet2Set(2 * node_hidden_dim, num_step_set2_set, num_layer_set2set) + self.set2set_solvent = ResidualSet2Set(2 * node_hidden_dim, num_step_set2_set, num_layer_set2set) + self.fc1 = nn.Linear(8 * node_hidden_dim, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 1) + self.imap = nn.Linear(80, 1) - - # Auxiliary prediction head + self.aux_head = nn.Sequential( nn.Linear(8 * node_hidden_dim, 64), nn.ReLU(), - nn.Linear(64, 3) # Predicts [logP, TPSA, QED] + nn.Linear(64, 3) ) def forward(self, data): solute, solvent, solute_len, solvent_len = data - 
- # Node embeddings - solute_features = self.solute_gather( - solute, solute.ndata['x'].float(), solute.edata['w'].float()) + + solute_features = self.solute_gather(solute, solute.ndata['x'].float(), solute.edata['w'].float()) try: - solvent_features = self.solvent_gather( - solvent, solvent.ndata['x'].float(), solvent.edata['w'].float()) + solvent_features = self.solvent_gather(solvent, solvent.ndata['x'].float(), solvent.edata['w'].float()) except: - solvent_features = self.solvent_gather( - solvent, solvent.ndata['x'].float(), None) + solvent_features = self.solvent_gather(solvent, solvent.ndata['x'].float(), None) - # Enhanced interaction phase len_map = torch.mm(solute_len.t(), solvent_len) - + if 'dot' not in self.interaction: X1 = solute_features.unsqueeze(0) Y1 = solvent_features.unsqueeze(1) @@ -121,38 +138,34 @@ def forward(self, data): X1.repeat(solvent_features.shape[0], 1, 1), Y1.repeat(1, solute_features.shape[0], 1) ], -1) - + interaction_map = self.imap(Z).squeeze(2) if self.interaction == 'tanh-general': interaction_map = torch.tanh(interaction_map) interaction_map = torch.mul(len_map.float(), interaction_map.t()) - + else: interaction_map = torch.mm(solute_features, solvent_features.t()) if 'scaled' in self.interaction: interaction_map = interaction_map / (self.temperature.abs() + 1e-8) interaction_map = torch.tanh(interaction_map) interaction_map = torch.mul(len_map.float(), interaction_map) - - ret_interaction_map = torch.clone(interaction_map) + solvent_prime = torch.mm(interaction_map.t(), solute_features) solute_prime = torch.mm(interaction_map, solvent_features) - # Prediction phase with residual connections solute_features = torch.cat((solute_features, solute_prime), dim=1) solvent_features = torch.cat((solvent_features, solvent_prime), dim=1) - + solute_features = self.set2set_solute(solute, solute_features) solvent_features = self.set2set_solvent(solvent, solvent_features) - + final_features = torch.cat((solute_features, solvent_features), 1) - - # Main prediction + main_pred = F.relu(self.fc1(final_features)) main_pred = F.relu(self.fc2(main_pred)) main_pred = self.fc3(main_pred) - - # Auxiliary predictions + aux_pred = self.aux_head(final_features.detach()) - - return main_pred, aux_pred, ret_interaction_map + + return main_pred, aux_pred, interaction_map From 2ee917b1efc049f8314b525cf9329c4a91c3f55b Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 12:25:13 +0530 Subject: [PATCH 17/39] Update train.py --- CIGIN_V2/train.py | 226 ++++++++++++++++------------------------------ 1 file changed, 79 insertions(+), 147 deletions(-) diff --git a/CIGIN_V2/train.py b/CIGIN_V2/train.py index ae23fdb..c047900 100644 --- a/CIGIN_V2/train.py +++ b/CIGIN_V2/train.py @@ -1,157 +1,89 @@ -import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from tqdm import tqdm -from torch.optim.lr_scheduler import ReduceLROnPlateau - -# Device configuration -use_cuda = torch.cuda.is_available() -device = torch.device("cuda" if use_cuda else "cpu") - -# Loss functions -loss_fn = torch.nn.MSELoss() -mae_loss_fn = torch.nn.L1Loss() - -def evaluate_model(model, dataloader): - """Evaluation function for compatibility with original code""" - model.eval() - preds = [] - targets = [] - with torch.no_grad(): - for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in dataloader: - # Convert inputs to tensors if they aren't already - if not isinstance(solute_lens, torch.Tensor): - 
solute_lens = torch.FloatTensor(solute_lens) - if not isinstance(solvent_lens, torch.Tensor): - solvent_lens = torch.FloatTensor(solvent_lens) - if not isinstance(labels, torch.Tensor): - labels = torch.FloatTensor(labels) - - outputs, _ , _= model([ - solute_graphs.to(device), - solvent_graphs.to(device), - solute_lens.to(device), - solvent_lens.to(device) - ]) - preds.extend(outputs.cpu().numpy().flatten()) - targets.extend(labels.cpu().numpy().flatten()) - return np.sqrt(np.mean((np.array(preds) - np.array(targets))**2)) - -def get_metrics(model, data_loader): - """Calculate MSE and MAE losses""" - model.eval() - valid_loss = [] - valid_mae_loss = [] - valid_outputs = [] - valid_labels = [] - - with torch.no_grad(): - for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in data_loader: - # Ensure all inputs are tensors - if not isinstance(solute_lens, torch.Tensor): - solute_lens = torch.FloatTensor(solute_lens) - if not isinstance(solvent_lens, torch.Tensor): - solvent_lens = torch.FloatTensor(solvent_lens) - if not isinstance(labels, torch.Tensor): - labels = torch.FloatTensor(labels) - - # Move data to device - inputs = [ - solute_graphs.to(device), - solvent_graphs.to(device), - solute_lens.to(device), - solvent_lens.to(device) - ] - labels = labels.to(device) - - # Forward pass - outputs, _, _ = model(inputs) - - - # Calculate losses - loss = loss_fn(outputs, labels) - mae_loss = mae_loss_fn(outputs, labels) - - # Store results - valid_loss.append(loss.cpu().item()) - valid_mae_loss.append(mae_loss.cpu().item()) - valid_outputs.extend(outputs.cpu().numpy().flatten()) - valid_labels.extend(labels.cpu().numpy().flatten()) - - # Calculate mean losses - loss = np.mean(valid_loss) - mae_loss = np.mean(valid_mae_loss) - return loss, mae_loss +import numpy as np +import wandb + +from utils import save_model, MAE, initialize_logger + def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name): - """Main training loop""" + logger = initialize_logger(project_name) + wandb.init(project=project_name) + wandb.watch(model) + best_val_loss = float('inf') - - for epoch in range(max_epochs): + best_epoch = -1 + + for epoch in range(1, max_epochs + 1): model.train() - running_loss = [] - - # Initialize progress bar - tq_loader = tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epochs}") - - for samples in tq_loader: + train_losses = [] + + print(f"Epoch {epoch}/{max_epochs}:") + for batch in tqdm(train_loader): + solute_graphs, solvent_graphs, labels = batch + solute_lens = solute_graphs.batch_num_nodes() + solvent_lens = solvent_graphs.batch_num_nodes() + + try: + main_pred, aux_pred, interaction_map = model([solute_graphs, solvent_graphs, solute_lens, solvent_lens]) + except Exception as e: + print(f"Model forward failed: {e}") + continue + + loss_main = F.mse_loss(main_pred, labels) + loss_aux = F.mse_loss(aux_pred, labels) + loss = loss_main + 0.3 * loss_aux # weighted auxiliary loss + optimizer.zero_grad() - - # Convert and move data to device - solute_graphs = samples[0].to(device) - solvent_graphs = samples[1].to(device) - - # Handle length matrices - solute_lens = samples[2] - solvent_lens = samples[3] - if not isinstance(solute_lens, torch.Tensor): - solute_lens = torch.FloatTensor(solute_lens) - if not isinstance(solvent_lens, torch.Tensor): - solvent_lens = torch.FloatTensor(solvent_lens) - solute_lens = solute_lens.to(device) - solvent_lens = solvent_lens.to(device) - - # Handle labels - labels = samples[4] - if not isinstance(labels, 
torch.Tensor): - labels = torch.FloatTensor(labels) - labels = labels.to(device) - - # Forward pass - outputs, _, interaction_map = model([solute_graphs, solvent_graphs, solute_lens, solvent_lens]) - - - # Calculate loss with L1 regularization - l1_norm = torch.norm(interaction_map, p=2) * 1e-4 - loss = loss_fn(outputs, labels) + l1_norm - - # Backward pass loss.backward() optimizer.step() - - # Update running loss (without regularization term) - running_loss.append((loss - l1_norm).cpu().item()) - - # Update progress bar - tq_loader.set_postfix(loss=np.mean(running_loss)) - - # Validation phase - val_loss, mae_loss = get_metrics(model, valid_loader) - scheduler.step(val_loss) - - # Print epoch summary - print(f"\nEpoch {epoch+1}:") - print(f"Train Loss: {np.mean(running_loss):.4f}") - print(f"Val Loss: {val_loss:.4f}") - print(f"Val MAE: {mae_loss:.4f}") - - # Save best model - if val_loss < best_val_loss: - best_val_loss = val_loss - torch.save(model.state_dict(), f"./runs/run-{project_name}/models/best_model.pt") - print(f"Saved new best model with Val Loss: {best_val_loss:.4f}") - - print("\nTraining completed!") - -if __name__ == '__main__': - print("This module contains training utilities and should be imported") + + train_losses.append(loss.item()) + + avg_train_loss = np.mean(train_losses) + + model.eval() + val_losses = [] + val_maes = [] + with torch.no_grad(): + for batch in valid_loader: + solute_graphs, solvent_graphs, labels = batch + solute_lens = solute_graphs.batch_num_nodes() + solvent_lens = solvent_graphs.batch_num_nodes() + + try: + main_pred, aux_pred, interaction_map = model([solute_graphs, solvent_graphs, solute_lens, solvent_lens]) + except Exception as e: + print(f"Validation forward failed: {e}") + continue + + loss_main = F.mse_loss(main_pred, labels) + loss_aux = F.mse_loss(aux_pred, labels) + loss = loss_main + 0.3 * loss_aux + + val_losses.append(loss.item()) + val_maes.append(MAE(main_pred, labels).item()) + + avg_val_loss = np.mean(val_losses) + avg_val_mae = np.mean(val_maes) + + print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val MAE: {avg_val_mae:.4f}") + + wandb.log({ + "epoch": epoch, + "train_loss": avg_train_loss, + "val_loss": avg_val_loss, + "val_mae": avg_val_mae + }) + + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + best_epoch = epoch + save_model(model, optimizer, epoch, best_val_loss, project_name) + + scheduler.step() + + print(f"Training complete. Best epoch: {best_epoch} with Val Loss: {best_val_loss:.4f}") + logger.info(f"Training complete. 
Best epoch: {best_epoch} with Val Loss: {best_val_loss:.4f}") From a72cba28614383fbbbbd85bec3f1ff424fe6eaeb Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 24 Jul 2025 12:26:57 +0530 Subject: [PATCH 18/39] Update train.py --- CIGIN_V2/train.py | 161 +++++++++++++++++++++++++++++----------------- 1 file changed, 102 insertions(+), 59 deletions(-) diff --git a/CIGIN_V2/train.py b/CIGIN_V2/train.py index c047900..6b53b29 100644 --- a/CIGIN_V2/train.py +++ b/CIGIN_V2/train.py @@ -1,89 +1,132 @@ import torch import torch.nn as nn -import torch.nn.functional as F from tqdm import tqdm -import numpy as np import wandb +import logging +import os + +# Loss functions +mae_criterion = nn.L1Loss() +mse_criterion = nn.MSELoss() + +# ---------------- Logger Setup ---------------- +def initialize_logger(log_file="training.log"): + logging.basicConfig( + filename=log_file, + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + filemode='w' + ) + logging.getLogger().addHandler(logging.StreamHandler()) + +def log_metrics_to_file(epoch, train_loss, train_mae, val_loss, val_mae): + logging.info(f"Epoch {epoch}:") + logging.info(f" Train Loss: {train_loss:.4f}, MAE: {train_mae:.4f}") + logging.info(f" Val Loss: {val_loss:.4f}, MAE: {val_mae:.4f}") + +def log_metrics_to_wandb(epoch, train_loss, train_mae, val_loss, val_mae): + wandb.log({ + "Epoch": epoch, + "Train Loss": train_loss, + "Train MAE": train_mae, + "Val Loss": val_loss, + "Val MAE": val_mae + }) + +# ---------------- Training Function ---------------- +def train(max_epochs, model, optimizer, scheduler, train_loader, val_loader, project_name): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) -from utils import save_model, MAE, initialize_logger - - -def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name): - logger = initialize_logger(project_name) wandb.init(project=project_name) wandb.watch(model) - best_val_loss = float('inf') + best_val_loss = float("inf") best_epoch = -1 for epoch in range(1, max_epochs + 1): model.train() - train_losses = [] - - print(f"Epoch {epoch}/{max_epochs}:") - for batch in tqdm(train_loader): - solute_graphs, solvent_graphs, labels = batch - solute_lens = solute_graphs.batch_num_nodes() - solvent_lens = solvent_graphs.batch_num_nodes() + epoch_loss = 0.0 + epoch_mae = 0.0 + progress = tqdm(train_loader, desc=f"Epoch {epoch}/{max_epochs}") + for batch in progress: try: - main_pred, aux_pred, interaction_map = model([solute_graphs, solvent_graphs, solute_lens, solvent_lens]) + batch = [item.to(device) if hasattr(item, 'to') else item for item in batch] + outputs, interaction_map = model(batch) + + labels = batch[-1].to(device) + + loss = mse_criterion(outputs, labels) + mae = mae_criterion(outputs, labels) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + epoch_mae += mae.item() except Exception as e: - print(f"Model forward failed: {e}") + logging.error(f"Error in training batch: {e}") continue - loss_main = F.mse_loss(main_pred, labels) - loss_aux = F.mse_loss(aux_pred, labels) - loss = loss_main + 0.3 * loss_aux # weighted auxiliary loss + avg_train_loss = epoch_loss / len(train_loader) + avg_train_mae = epoch_mae / len(train_loader) - optimizer.zero_grad() - loss.backward() - optimizer.step() + # Validation + val_loss, val_mae = evaluate_model(model, val_loader, device) - train_losses.append(loss.item()) + # 
Scheduler step + scheduler.step(val_loss) - avg_train_loss = np.mean(train_losses) + # Logging + log_metrics_to_file(epoch, avg_train_loss, avg_train_mae, val_loss, val_mae) + log_metrics_to_wandb(epoch, avg_train_loss, avg_train_mae, val_loss, val_mae) - model.eval() - val_losses = [] - val_maes = [] - with torch.no_grad(): - for batch in valid_loader: - solute_graphs, solvent_graphs, labels = batch - solute_lens = solute_graphs.batch_num_nodes() - solvent_lens = solvent_graphs.batch_num_nodes() + # Save best model + if val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + torch.save(model.state_dict(), f"{project_name}_best_model.pth") - try: - main_pred, aux_pred, interaction_map = model([solute_graphs, solvent_graphs, solute_lens, solvent_lens]) - except Exception as e: - print(f"Validation forward failed: {e}") - continue + logging.info(f"Finished Epoch {epoch}: Train Loss={avg_train_loss:.4f}, Val Loss={val_loss:.4f}") - loss_main = F.mse_loss(main_pred, labels) - loss_aux = F.mse_loss(aux_pred, labels) - loss = loss_main + 0.3 * loss_aux + logging.info(f"Training complete. Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}") - val_losses.append(loss.item()) - val_maes.append(MAE(main_pred, labels).item()) +# ---------------- Evaluation Functions ---------------- +def evaluate_model(model, data_loader, device, return_predictions=False): + model.eval() + total_loss = 0.0 + total_mae = 0.0 + predictions = [] + ground_truth = [] - avg_val_loss = np.mean(val_losses) - avg_val_mae = np.mean(val_maes) + with torch.no_grad(): + for batch in data_loader: + try: + batch = [item.to(device) if hasattr(item, 'to') else item for item in batch] + outputs, _ = model(batch) + labels = batch[-1].to(device) - print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val MAE: {avg_val_mae:.4f}") + loss = mse_criterion(outputs, labels) + mae = mae_criterion(outputs, labels) - wandb.log({ - "epoch": epoch, - "train_loss": avg_train_loss, - "val_loss": avg_val_loss, - "val_mae": avg_val_mae - }) + total_loss += loss.item() + total_mae += mae.item() - if avg_val_loss < best_val_loss: - best_val_loss = avg_val_loss - best_epoch = epoch - save_model(model, optimizer, epoch, best_val_loss, project_name) + if return_predictions: + predictions.append(outputs.cpu()) + ground_truth.append(labels.cpu()) + except Exception as e: + logging.error(f"Error in validation batch: {e}") + continue + + avg_loss = total_loss / len(data_loader) + avg_mae = total_mae / len(data_loader) - scheduler.step() + if return_predictions: + return avg_loss, avg_mae, torch.cat(predictions), torch.cat(ground_truth) + return avg_loss, avg_mae - print(f"Training complete. Best epoch: {best_epoch} with Val Loss: {best_val_loss:.4f}") - logger.info(f"Training complete. 
Best epoch: {best_epoch} with Val Loss: {best_val_loss:.4f}")
+
+def get_metrics(model, data_loader, device, return_predictions=False):
+    return evaluate_model(model, data_loader, device, return_predictions)

From 32230e5cd3817d748f0b491e3b68295880c8fc2f Mon Sep 17 00:00:00 2001
From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com>
Date: Thu, 24 Jul 2025 12:28:47 +0530
Subject: [PATCH 19/39] Update main.py

---
 CIGIN_V2/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py
index 6c68094..18d49fb 100644
--- a/CIGIN_V2/main.py
+++ b/CIGIN_V2/main.py
@@ -21,7 +21,7 @@
 # local imports
 from model import CIGINModel
-from train import train, get_metrics
+from train import train, evaluate_model, get_metrics, initialize_logger
 from molecular_graph import get_graph_from_smile
 from utils import *
 
@@ -120,7 +120,7 @@ def main():
     # Final evaluation
     model.eval()
-    loss, mae_loss = get_metrics(model, test_loader)
+    loss, mae_loss = get_metrics(model, test_loader, device)
     print(f"\nFinal test set performance:")
     print(f"MSE Loss: {loss:.4f}")
     print(f"MAE Loss: {mae_loss:.4f}")

From 6f1a596b58c8acc8c46728f80b763e8e2bd7b6d1 Mon Sep 17 00:00:00 2001
From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com>
Date: Thu, 24 Jul 2025 12:30:11 +0530
Subject: [PATCH 20/39] logger added

---
 CIGIN_V2/main.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py
index 18d49fb..1cda8f8 100644
--- a/CIGIN_V2/main.py
+++ b/CIGIN_V2/main.py
@@ -114,6 +114,8 @@ def main():
     # Training setup
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
     scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True)
+    initialize_logger()
+
 
     # Training loop
     train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name)

From f9631ef3e6cd5e86c4153a79afd4b8222ab67851 Mon Sep 17 00:00:00 2001
From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com>
Date: Thu, 31 Jul 2025 19:26:40 +0530
Subject: [PATCH 21/39] changing it to testing phase

---
 CIGIN_V2/model.py | 209 ++++++++++++++++++++-------------------------
 1 file changed, 92 insertions(+), 117 deletions(-)

diff --git a/CIGIN_V2/model.py b/CIGIN_V2/model.py
index 84c0110..72b0666 100644
--- a/CIGIN_V2/model.py
+++ b/CIGIN_V2/model.py
@@ -1,159 +1,137 @@
 import numpy as np
+import math
+
+from dgl import DGLGraph
+from dgl.nn.pytorch import Set2Set, NNConv, GATConv
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.function as fn
-from dgl.nn.pytorch import Set2Set, NNConv
-
-
-class EnhancedEdgeNetwork(nn.Module):
-    def __init__(self, edge_dim):
-        super().__init__()
-        self.edge_proj = nn.Linear(edge_dim, edge_dim)
-        self.gate = nn.Sequential(
-            nn.Linear(edge_dim, 1),
-            nn.Sigmoid()
-        )
-
-    def forward(self, e_feat):
-        projected = torch.relu(self.edge_proj(e_feat))
-        gate = self.gate(e_feat)
-        return projected * gate
-
-
-class ResidualSet2Set(Set2Set):
-    def forward(self, graph, feat):
-        pooled = super().forward(graph, feat)
-        mean_feat = feat.mean(dim=0, keepdim=True).repeat(pooled.shape[0], 1)
-        return torch.cat([pooled, mean_feat], dim=-1)
-
-
-def add_rwse(g, k=16):
-    adj = g.adjacency_matrix().to_dense()
-    rwpe = torch.zeros(g.num_nodes(), k)
-    deg = adj.sum(1)
-    deg[deg == 0] = 1  # Prevent div by zero
-    for i in range(k):
-        if i == 0:
-            rwpe[:, i] = deg
-        else:
-            rwpe[:, i] = (adj @ rwpe[:, i-1]) / deg
-    return rwpe
-
 
 class GatherModel(nn.Module):
-    def __init__(self, 
node_input_dim=42, edge_input_dim=10, - node_hidden_dim=42, edge_hidden_dim=42, - num_step_message_passing=6): - super().__init__() - - self.edge_network = EnhancedEdgeNetwork(edge_input_dim) - - self.conv = NNConv( - in_feats=node_hidden_dim, - out_feats=node_hidden_dim, - edge_func=self.edge_network, - aggregator_type='mean', - residual=True - ) - - self.lin0 = nn.Linear(node_input_dim + 16, node_hidden_dim) # +RWSE + """ + Original MPNN from CIGIN paper (unchanged) + """ + def __init__(self, + node_input_dim=42, + edge_input_dim=10, + node_hidden_dim=42, + edge_hidden_dim=42, + num_step_message_passing=6, + ): + super(GatherModel, self).__init__() + self.num_step_message_passing = num_step_message_passing + self.lin0 = nn.Linear(node_input_dim, node_hidden_dim) + self.set2set = Set2Set(node_hidden_dim, 2, 1) self.message_layer = nn.Linear(2 * node_hidden_dim, node_hidden_dim) - self.subgraph_proj = nn.Linear(2 * node_hidden_dim, node_hidden_dim) - self.num_steps = num_step_message_passing + edge_network = nn.Sequential( + nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(), + nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim)) + self.conv = NNConv(in_feats=node_hidden_dim, + out_feats=node_hidden_dim, + edge_func=edge_network, + aggregator_type='sum', + residual=True + ) def forward(self, g, n_feat, e_feat): - rwse = add_rwse(g).to(n_feat.device) - n_feat = torch.cat([n_feat, rwse], dim=-1) - init = n_feat.clone() out = F.relu(self.lin0(n_feat)) - - for _ in range(self.num_steps // 2): - m = torch.relu(self.conv(g, out, e_feat)) if e_feat is not None \ - else torch.relu(self.conv.bias + self.conv.res_fc(out)) + for i in range(self.num_step_message_passing): + if e_feat is not None: + m = torch.relu(self.conv(g, out, e_feat)) + else: + m = torch.relu(self.conv.bias + self.conv.res_fc(out)) out = self.message_layer(torch.cat([m, out], dim=1)) - - with g.local_scope(): - g.ndata['h'] = out - g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_group')) - group_feat = g.ndata['h_group'] - out = self.subgraph_proj(torch.cat([out, group_feat], dim=1)) - return out + init - class CIGINModel(nn.Module): - def __init__(self, node_input_dim=42, edge_input_dim=10, - node_hidden_dim=42, edge_hidden_dim=42, - num_step_message_passing=6, interaction='dot', - num_step_set2_set=2, num_layer_set2set=1): - super().__init__() - + """ + Original CIGIN model (unchanged) + """ + def __init__(self, + node_input_dim=42, + edge_input_dim=10, + node_hidden_dim=42, + edge_hidden_dim=42, + num_step_message_passing=6, + interaction='dot', + num_step_set2_set=2, + num_layer_set2set=1, + ): + super(CIGINModel, self).__init__() + + self.node_input_dim = node_input_dim self.node_hidden_dim = node_hidden_dim + self.edge_input_dim = edge_input_dim + self.edge_hidden_dim = edge_hidden_dim + self.num_step_message_passing = num_step_message_passing self.interaction = interaction - self.temperature = nn.Parameter(torch.tensor(1.0)) - - self.solute_gather = GatherModel( - node_input_dim, edge_input_dim, - node_hidden_dim, edge_hidden_dim, - num_step_message_passing - ) - self.solvent_gather = GatherModel( - node_input_dim, edge_input_dim, - node_hidden_dim, edge_hidden_dim, - num_step_message_passing - ) - - self.set2set_solute = ResidualSet2Set(2 * node_hidden_dim, num_step_set2_set, num_layer_set2set) - self.set2set_solvent = ResidualSet2Set(2 * node_hidden_dim, num_step_set2_set, num_layer_set2set) - - self.fc1 = nn.Linear(8 * node_hidden_dim, 256) + self.solute_gather = GatherModel(self.node_input_dim, 
self.edge_input_dim, + self.node_hidden_dim, self.edge_input_dim, + self.num_step_message_passing, + ) + self.solvent_gather = GatherModel(self.node_input_dim, self.edge_input_dim, + self.node_hidden_dim, self.edge_input_dim, + self.num_step_message_passing, + ) + # These three are the FFNN for prediction phase + self.fc1 = nn.Linear(8 * self.node_hidden_dim, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 1) - self.imap = nn.Linear(80, 1) - self.aux_head = nn.Sequential( - nn.Linear(8 * node_hidden_dim, 64), - nn.ReLU(), - nn.Linear(64, 3) - ) + self.num_step_set2set = num_step_set2_set + self.num_layer_set2set = num_layer_set2set + self.set2set_solute = Set2Set(2 * node_hidden_dim, self.num_step_set2set, self.num_layer_set2set) + self.set2set_solvent = Set2Set(2 * node_hidden_dim, self.num_step_set2set, self.num_layer_set2set) def forward(self, data): - solute, solvent, solute_len, solvent_len = data - + solute = data[0] + solvent = data[1] + solute_len = data[2] + solvent_len = data[3] + # node embeddings after interaction phase solute_features = self.solute_gather(solute, solute.ndata['x'].float(), solute.edata['w'].float()) try: + # if edge exists in a molecule solvent_features = self.solvent_gather(solvent, solvent.ndata['x'].float(), solvent.edata['w'].float()) except: + # if edge doesn't exist in a molecule, for example in case of water solvent_features = self.solvent_gather(solvent, solvent.ndata['x'].float(), None) + # Interaction phase len_map = torch.mm(solute_len.t(), solvent_len) if 'dot' not in self.interaction: X1 = solute_features.unsqueeze(0) Y1 = solvent_features.unsqueeze(1) - Z = torch.cat([ - X1.repeat(solvent_features.shape[0], 1, 1), - Y1.repeat(1, solute_features.shape[0], 1) - ], -1) + X2 = X1.repeat(solvent_features.shape[0], 1, 1) + Y2 = Y1.repeat(1, solute_features.shape[0], 1) + Z = torch.cat([X2, Y2], -1) - interaction_map = self.imap(Z).squeeze(2) + if self.interaction == 'general': + interaction_map = self.imap(Z).squeeze(2) if self.interaction == 'tanh-general': - interaction_map = torch.tanh(interaction_map) + interaction_map = torch.tanh(self.imap(Z)).squeeze(2) + interaction_map = torch.mul(len_map.float(), interaction_map.t()) + ret_interaction_map = torch.clone(interaction_map) - else: + elif 'dot' in self.interaction: interaction_map = torch.mm(solute_features, solvent_features.t()) if 'scaled' in self.interaction: - interaction_map = interaction_map / (self.temperature.abs() + 1e-8) + interaction_map = interaction_map / (np.sqrt(self.node_hidden_dim)) + + ret_interaction_map = torch.clone(interaction_map) + ret_interaction_map = torch.mul(len_map.float(), ret_interaction_map) interaction_map = torch.tanh(interaction_map) interaction_map = torch.mul(len_map.float(), interaction_map) solvent_prime = torch.mm(interaction_map.t(), solute_features) solute_prime = torch.mm(interaction_map, solvent_features) + # Prediction phase solute_features = torch.cat((solute_features, solute_prime), dim=1) solvent_features = torch.cat((solvent_features, solvent_prime), dim=1) @@ -161,11 +139,8 @@ def forward(self, data): solvent_features = self.set2set_solvent(solvent, solvent_features) final_features = torch.cat((solute_features, solvent_features), 1) + predictions = torch.relu(self.fc1(final_features)) + predictions = torch.relu(self.fc2(predictions)) + predictions = self.fc3(predictions) - main_pred = F.relu(self.fc1(final_features)) - main_pred = F.relu(self.fc2(main_pred)) - main_pred = self.fc3(main_pred) - - aux_pred = 
self.aux_head(final_features.detach()) - - return main_pred, aux_pred, interaction_map + return predictions, ret_interaction_map From 8c2f4531aa49cfb00dcb612979a0b244b7958981 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 31 Jul 2025 19:27:33 +0530 Subject: [PATCH 22/39] testing phase --- CIGIN_V2/train.py | 187 ++++++++++++++++------------------------------ 1 file changed, 64 insertions(+), 123 deletions(-) diff --git a/CIGIN_V2/train.py b/CIGIN_V2/train.py index 6b53b29..448f549 100644 --- a/CIGIN_V2/train.py +++ b/CIGIN_V2/train.py @@ -1,132 +1,73 @@ -import torch -import torch.nn as nn from tqdm import tqdm -import wandb -import logging -import os - -# Loss functions -mae_criterion = nn.L1Loss() -mse_criterion = nn.MSELoss() - -# ---------------- Logger Setup ---------------- -def initialize_logger(log_file="training.log"): - logging.basicConfig( - filename=log_file, - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - filemode='w' - ) - logging.getLogger().addHandler(logging.StreamHandler()) - -def log_metrics_to_file(epoch, train_loss, train_mae, val_loss, val_mae): - logging.info(f"Epoch {epoch}:") - logging.info(f" Train Loss: {train_loss:.4f}, MAE: {train_mae:.4f}") - logging.info(f" Val Loss: {val_loss:.4f}, MAE: {val_mae:.4f}") - -def log_metrics_to_wandb(epoch, train_loss, train_mae, val_loss, val_mae): - wandb.log({ - "Epoch": epoch, - "Train Loss": train_loss, - "Train MAE": train_mae, - "Val Loss": val_loss, - "Val MAE": val_mae - }) - -# ---------------- Training Function ---------------- -def train(max_epochs, model, optimizer, scheduler, train_loader, val_loader, project_name): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device) +import torch +import numpy as np - wandb.init(project=project_name) - wandb.watch(model) +loss_fn = torch.nn.MSELoss() +mae_loss_fn = torch.nn.L1Loss() - best_val_loss = float("inf") - best_epoch = -1 +use_cuda = torch.cuda.is_available() +device = torch.device("cuda" if use_cuda else "cpu") - for epoch in range(1, max_epochs + 1): +def evaluate_model(model, dataloader): + model.eval() + preds, targets = [], [] + with torch.no_grad(): + for samples in dataloader: + outputs, _ = model([samples[0].to(device), samples[1].to(device), + samples[2].to(device), samples[3].to(device)]) + preds.extend(outputs.cpu().numpy()) + targets.extend(samples[4].numpy()) + preds, targets = np.array(preds), np.array(targets) + rmse = np.sqrt(np.mean((preds - targets) ** 2)) + return rmse + +def get_metrics(model, data_loader): + valid_outputs = [] + valid_labels = [] + valid_loss = [] + valid_mae_loss = [] + for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in data_loader: + outputs, i_map = model( + [solute_graphs.to(device), solvent_graphs.to(device), torch.tensor(solute_lens).to(device), + torch.tensor(solvent_lens).to(device)]) + loss = loss_fn(outputs, torch.tensor(labels).to(device).float()) + mae_loss = mae_loss_fn(outputs, torch.tensor(labels).to(device).float()) + valid_outputs += outputs.cpu().detach().numpy().tolist() + valid_loss.append(loss.cpu().detach().numpy()) + valid_mae_loss.append(mae_loss.cpu().detach().numpy()) + valid_labels += labels + + loss = np.mean(np.array(valid_loss).flatten()) + mae_loss = np.mean(np.array(valid_mae_loss).flatten()) + return loss, mae_loss + + +def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name): + best_val_loss = 100 + for epoch in 
range(max_epochs): model.train() - epoch_loss = 0.0 - epoch_mae = 0.0 - - progress = tqdm(train_loader, desc=f"Epoch {epoch}/{max_epochs}") - for batch in progress: - try: - batch = [item.to(device) if hasattr(item, 'to') else item for item in batch] - outputs, interaction_map = model(batch) - - labels = batch[-1].to(device) - - loss = mse_criterion(outputs, labels) - mae = mae_criterion(outputs, labels) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - epoch_loss += loss.item() - epoch_mae += mae.item() - except Exception as e: - logging.error(f"Error in training batch: {e}") - continue - - avg_train_loss = epoch_loss / len(train_loader) - avg_train_mae = epoch_mae / len(train_loader) - - # Validation - val_loss, val_mae = evaluate_model(model, val_loader, device) - - # Scheduler step + running_loss = [] + tq_loader = tqdm(train_loader) + o = {} + for samples in tq_loader: + optimizer.zero_grad() + outputs, interaction_map = model( + [samples[0].to(device), samples[1].to(device), torch.tensor(samples[2]).to(device), + torch.tensor(samples[3]).to(device)]) + l1_norm = torch.norm(interaction_map, p=2) * 1e-4 + loss = loss_fn(outputs, torch.tensor(samples[4]).to(device).float()) + l1_norm + loss.backward() + optimizer.step() + loss = loss - l1_norm + running_loss.append(loss.cpu().detach()) + tq_loader.set_description( + "Epoch: " + str(epoch + 1) + " Training loss: " + str(np.mean(np.array(running_loss)))) + model.eval() + val_loss, mae_loss = get_metrics(model, valid_loader) scheduler.step(val_loss) - - # Logging - log_metrics_to_file(epoch, avg_train_loss, avg_train_mae, val_loss, val_mae) - log_metrics_to_wandb(epoch, avg_train_loss, avg_train_mae, val_loss, val_mae) - - # Save best model + print("Epoch: " + str(epoch + 1) + " train_loss " + str(np.mean(np.array(running_loss))) + " Val_loss " + str( + val_loss) + " MAE Val_loss " + str(mae_loss)) if val_loss < best_val_loss: best_val_loss = val_loss - best_epoch = epoch - torch.save(model.state_dict(), f"{project_name}_best_model.pth") - - logging.info(f"Finished Epoch {epoch}: Train Loss={avg_train_loss:.4f}, Val Loss={val_loss:.4f}") - - logging.info(f"Training complete. 
Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}") - -# ---------------- Evaluation Functions ---------------- -def evaluate_model(model, data_loader, device, return_predictions=False): - model.eval() - total_loss = 0.0 - total_mae = 0.0 - predictions = [] - ground_truth = [] - - with torch.no_grad(): - for batch in data_loader: - try: - batch = [item.to(device) if hasattr(item, 'to') else item for item in batch] - outputs, _ = model(batch) - labels = batch[-1].to(device) - - loss = mse_criterion(outputs, labels) - mae = mae_criterion(outputs, labels) - - total_loss += loss.item() - total_mae += mae.item() - - if return_predictions: - predictions.append(outputs.cpu()) - ground_truth.append(labels.cpu()) - except Exception as e: - logging.error(f"Error in validation batch: {e}") - continue - - avg_loss = total_loss / len(data_loader) - avg_mae = total_mae / len(data_loader) - - if return_predictions: - return avg_loss, avg_mae, torch.cat(predictions), torch.cat(ground_truth) - return avg_loss, avg_mae + torch.save(model.state_dict(), "./runs/run-" + str(project_name) + "/models/best_model.tar") -def get_metrics(model, data_loader, device, return_predictions=False): - return evaluate_model(model, data_loader, device, return_predictions) From 8f03e382750657234e66a786b20ead03ddb53147 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 31 Jul 2025 19:28:38 +0530 Subject: [PATCH 23/39] testing phase From c108b508cb42517cbbbdb4fb0e1ce1f672af14b5 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 31 Jul 2025 23:38:37 +0530 Subject: [PATCH 24/39] Create run_model.py --- scripts/run_model.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 scripts/run_model.py diff --git a/scripts/run_model.py b/scripts/run_model.py new file mode 100644 index 0000000..3114d0a --- /dev/null +++ b/scripts/run_model.py @@ -0,0 +1,17 @@ +import torch +from models import Cigin +from molecular_graph import ConstructMolecularGraph + +# Sample SMILES strings (you can change these) +solute = "CCO" # Ethanol +solvent = "O=C=O" # Carbon dioxide + +# Load model +model = Cigin().to("cuda" if torch.cuda.is_available() else "cpu") + +# Run model forward pass +with torch.no_grad(): + prediction, interaction_map = model(solute, solvent) + +print("Prediction (Solubility):", prediction.item()) +print("Interaction Map Shape:", interaction_map.shape) From 346bc001bfced9fa72711d1130b5f7e048d33df1 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 31 Jul 2025 23:58:17 +0530 Subject: [PATCH 25/39] Create train.py --- scripts/train.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 scripts/train.py diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000..abf292e --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,61 @@ +from tqdm import tqdm +import torch +import numpy as np + +loss_fn = torch.nn.MSELoss() +mae_fn = torch.nn.L1Loss() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def get_metrics(model, data_loader): + model.eval() + all_preds, all_labels = [], [] + losses, mae_losses = [], [] + + with torch.no_grad(): + for solute, solvent, labels in tqdm(data_loader): + outputs, _ = model(solute, solvent) + labels = torch.tensor(labels).to(device).float() + + loss = loss_fn(outputs, labels) + mae = mae_fn(outputs, labels) + + 
losses.append(loss.item()) + mae_losses.append(mae.item()) + + all_preds += outputs.cpu().numpy().tolist() + all_labels += labels.cpu().numpy().tolist() + + return np.mean(losses), np.mean(mae_losses) + + +def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, save_path): + best_val_loss = float('inf') + + for epoch in range(max_epochs): + model.train() + running_losses = [] + + for solute, solvent, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"): + optimizer.zero_grad() + + outputs, interaction_map = model(solute, solvent) + labels = torch.tensor(labels).to(device).float() + + l2_penalty = torch.norm(interaction_map, p=2) * 1e-4 + loss = loss_fn(outputs, labels) + l2_penalty + + loss.backward() + optimizer.step() + + running_losses.append((loss - l2_penalty).item()) + + val_loss, val_mae = get_metrics(model, valid_loader) + scheduler.step(val_loss) + + print(f"Epoch {epoch+1} | Train Loss: {np.mean(running_losses):.4f} | Val Loss: {val_loss:.4f} | MAE: {val_mae:.4f}") + + if val_loss < best_val_loss: + best_val_loss = val_loss + torch.save(model.state_dict(), save_path) + print("✅ Model saved.") From f17b14770058fcdee1e182ee28c83c75f9e3cd8c Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Thu, 31 Jul 2025 23:59:24 +0530 Subject: [PATCH 26/39] Create main.py --- scripts/main.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 scripts/main.py diff --git a/scripts/main.py b/scripts/main.py new file mode 100644 index 0000000..dadbaf0 --- /dev/null +++ b/scripts/main.py @@ -0,0 +1,59 @@ +import pandas as pd +import torch +from torch.utils.data import DataLoader, Dataset +from torch.optim.lr_scheduler import ReduceLROnPlateau +from models import Cigin +from molecular_graph import ConstructMolecularGraph +from train import train + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +class MNSolVDataset(Dataset): + def __init__(self, df): + self.data = df + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + solute = self.data.iloc[idx]['SoluteSMILES'] + solvent = self.data.iloc[idx]['SolventSMILES'] + deltaG = self.data.iloc[idx]['DeltaGsolv'] + return solute, solvent, deltaG + +def collate_fn(batch): + solute_graphs = [] + solvent_graphs = [] + labels = [] + + for solute_smiles, solvent_smiles, label in batch: + sol_graph, sol_feat = ConstructMolecularGraph(solute_smiles) + solv_graph, solv_feat = ConstructMolecularGraph(solvent_smiles) + + solute_graphs.append((sol_graph, sol_feat)) + solvent_graphs.append((solv_graph, solv_feat)) + labels.append(label) + + return solute_graphs, solvent_graphs, labels + + +def main(): + train_df = pd.read_csv("data/train.csv", sep=";") + valid_df = pd.read_csv("data/valid.csv", sep=";") + + train_set = MNSolVDataset(train_df) + valid_set = MNSolVDataset(valid_df) + + train_loader = DataLoader(train_set, batch_size=1, shuffle=True, collate_fn=collate_fn) + valid_loader = DataLoader(valid_set, batch_size=1, shuffle=False, collate_fn=collate_fn) + + model = Cigin().to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min') + + train(max_epochs=100, model=model, optimizer=optimizer, + scheduler=scheduler, train_loader=train_loader, + valid_loader=valid_loader, save_path="best_model.pt") + +if __name__ == "__main__": + main() From 14734a1710a8b78d6c401abb88dc79ada40750d8 Mon Sep 17 00:00:00 2001 From: R 
Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 00:15:40 +0530 Subject: [PATCH 27/39] Update main.py --- scripts/main.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/scripts/main.py b/scripts/main.py index dadbaf0..f1c0928 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -2,6 +2,8 @@ import torch from torch.utils.data import DataLoader, Dataset from torch.optim.lr_scheduler import ReduceLROnPlateau +from sklearn.model_selection import train_test_split + from models import Cigin from molecular_graph import ConstructMolecularGraph from train import train @@ -18,7 +20,7 @@ def __len__(self): def __getitem__(self, idx): solute = self.data.iloc[idx]['SoluteSMILES'] solvent = self.data.iloc[idx]['SolventSMILES'] - deltaG = self.data.iloc[idx]['DeltaGsolv'] + deltaG = self.data.iloc[idx]['delGsolv'] return solute, solvent, deltaG def collate_fn(batch): @@ -36,10 +38,15 @@ def collate_fn(batch): return solute_graphs, solvent_graphs, labels - def main(): - train_df = pd.read_csv("data/train.csv", sep=";") - valid_df = pd.read_csv("data/valid.csv", sep=";") + # Load and preprocess full dataset + df = pd.read_csv('https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/refs/heads/master/CIGIN_V2/data/whole_data.csv') + df.columns = df.columns.str.strip() + print("Dataset columns:", df.columns) + + # Split: 10% test, then 10% of remaining for validation + train_df, test_df = train_test_split(df, test_size=0.1, random_state=42) + train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42) train_set = MNSolVDataset(train_df) valid_set = MNSolVDataset(valid_df) From a6295a6e502487494b3847ae26c4745f4ef41122 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 10:05:08 +0530 Subject: [PATCH 28/39] Update main.py --- scripts/main.py | 133 ++++++++++++++++++++++++++---------------------- 1 file changed, 73 insertions(+), 60 deletions(-) diff --git a/scripts/main.py b/scripts/main.py index f1c0928..9a55e7b 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1,66 +1,79 @@ import pandas as pd -import torch -from torch.utils.data import DataLoader, Dataset -from torch.optim.lr_scheduler import ReduceLROnPlateau -from sklearn.model_selection import train_test_split - -from models import Cigin -from molecular_graph import ConstructMolecularGraph -from train import train - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -class MNSolVDataset(Dataset): - def __init__(self, df): - self.data = df - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - solute = self.data.iloc[idx]['SoluteSMILES'] - solvent = self.data.iloc[idx]['SolventSMILES'] - deltaG = self.data.iloc[idx]['delGsolv'] - return solute, solvent, deltaG - -def collate_fn(batch): - solute_graphs = [] - solvent_graphs = [] - labels = [] - - for solute_smiles, solvent_smiles, label in batch: - sol_graph, sol_feat = ConstructMolecularGraph(solute_smiles) - solv_graph, solv_feat = ConstructMolecularGraph(solvent_smiles) - - solute_graphs.append((sol_graph, sol_feat)) - solvent_graphs.append((solv_graph, solv_feat)) - labels.append(label) - - return solute_graphs, solvent_graphs, labels - -def main(): - # Load and preprocess full dataset - df = pd.read_csv('https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/refs/heads/master/CIGIN_V2/data/whole_data.csv') +import numpy as np +from train import run_kfold_cv +import warnings 
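# A quick sanity check of the two-stage split from PATCH 27 (dropped again just
# below in favour of k-fold CV): after holding out 10% for test, test_size=0.111
# takes roughly 10% of the original data from the remaining 90%, i.e. an
# ~80/10/10 train/valid/test split. The 1000-row frame here is a toy stand-in
# for whole_data.csv.
import pandas as pd
from sklearn.model_selection import train_test_split

toy = pd.DataFrame({"row": range(1000)})
train_df, test_df = train_test_split(toy, test_size=0.1, random_state=42)
train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42)
print(len(train_df), len(valid_df), len(test_df))  # 800 100 100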
+warnings.filterwarnings("ignore")
+
+def load_and_preprocess_data(csv_path):
+    """Load and preprocess the dataset"""
+    # Load the CSV file
+    df = pd.read_csv(csv_path)
+
+    # Strip whitespace from column names as requested
     df.columns = df.columns.str.strip()
-    print("Dataset columns:", df.columns)
-
-    # Split: 10% test, then 10% of remaining for validation
-    train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
-    train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42)
+
+    # Strip whitespace from string columns
+    for col in df.columns:
+        if df[col].dtype == 'object':
+            df[col] = df[col].str.strip()
+
+    # Remove any rows with missing values
+    df = df.dropna()
+
+    # Filter out problematic SMILES if any
+    df = df[df['SoluteSMILES'].str.len() > 0]
+    df = df[df['SolventSMILES'].str.len() > 0]
+
+    print(f"Dataset loaded: {len(df)} samples")
+    print(f"Unique solutes: {df['SoluteSMILES'].nunique()}")
+    print(f"Unique solvents: {df['SolventSMILES'].nunique()}")
+    print(f"Solvation free energy range: {df['delGsolv'].min():.2f} to {df['delGsolv'].max():.2f} kcal/mol")
+
+    return df
 
-    train_set = MNSolVDataset(train_df)
-    valid_set = MNSolVDataset(valid_df)
-
-    train_loader = DataLoader(train_set, batch_size=1, shuffle=True, collate_fn=collate_fn)
-    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=False, collate_fn=collate_fn)
-
-    model = Cigin().to(device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-    scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min')
-
-    train(max_epochs=100, model=model, optimizer=optimizer,
-          scheduler=scheduler, train_loader=train_loader,
-          valid_loader=valid_loader, save_path="best_model.pt")
+def main():
+    """Main training function following CIGIN paper methodology"""
+    print("CIGIN Model Training")
+    print("=" * 50)
+
+    # Check GPU availability
+    import torch
+    if torch.cuda.is_available():
+        print(f"GPU Available: {torch.cuda.get_device_name(0)}")
+        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+    else:
+        print("GPU Not Available - Using CPU")
+    print("=" * 50)
+
+    # Load dataset - replace with your actual CSV path
+    csv_path = "https://github.com/adithyamauryakr/CIGIN-DevaLab/raw/master/CIGIN_V2/data/whole_data.csv"
+
+    try:
+        # Try to load and preprocess from the URL first
+        df = load_and_preprocess_data(csv_path)
+    except Exception:
+        # If the URL fails, fall back to a local copy
+        print("Loading from URL failed, trying local file...")
+        df = load_and_preprocess_data("whole_data.csv")
+
+    # Run k-fold cross validation as described in the paper
+    # Paper mentions: "10-fold cross validation scheme was used to assess the model"
+    # "We made 5 such 10 cross validation splits and trained our model independently"
+    print("\nStarting 10-fold cross validation (5 independent runs)...")
+
+    mean_rmse, std_rmse = run_kfold_cv(df, k=10, n_runs=5)
+
+    print("\n" + "=" * 50)
+    print("FINAL RESULTS")
+    print("=" * 50)
+    print("CIGIN Model Performance:")
+    print(f"RMSE: {mean_rmse:.2f} ± {std_rmse:.2f} kcal/mol")
+    print("\nPaper reported RMSE: 0.57 ± 0.10 kcal/mol")
+    print("=" * 50)
 
 if __name__ == "__main__":
     main()

From 0d89ddaa09e31733c1b442d20e10007733506044 Mon Sep 17 00:00:00 2001
From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com>
Date: Fri, 1 Aug 2025 10:05:33 +0530
Subject: [PATCH 29/39] Update train.py

---
 scripts/train.py | 233 +++++++++++++++++++++++++++++++++++------------
 1 file 
changed, 177 insertions(+), 56 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index abf292e..308be35 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,61 +1,182 @@ -from tqdm import tqdm import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader +import pandas as pd import numpy as np - -loss_fn = torch.nn.MSELoss() -mae_fn = torch.nn.L1Loss() - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -def get_metrics(model, data_loader): - model.eval() - all_preds, all_labels = [], [] - losses, mae_losses = [], [] - - with torch.no_grad(): - for solute, solvent, labels in tqdm(data_loader): - outputs, _ = model(solute, solvent) - labels = torch.tensor(labels).to(device).float() - - loss = loss_fn(outputs, labels) - mae = mae_fn(outputs, labels) - - losses.append(loss.item()) - mae_losses.append(mae.item()) - - all_preds += outputs.cpu().numpy().tolist() - all_labels += labels.cpu().numpy().tolist() - - return np.mean(losses), np.mean(mae_losses) - - -def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, save_path): +from sklearn.model_selection import KFold +from models import Cigin +import warnings +warnings.filterwarnings("ignore") + +device = "cuda" if torch.cuda.is_available() else "cpu" + +class SolvationDataset(Dataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + row = self.data.iloc[idx] + return { + 'solute_smiles': row['SoluteSMILES'], + 'solvent_smiles': row['SolventSMILES'], + 'target': torch.FloatTensor([row['delGsolv']]) + } + +def train_model(model, train_loader, val_loader, num_epochs=100): + """Train CIGIN model following the paper's methodology""" + # ADAM optimizer with default parameters as mentioned in paper + optimizer = optim.Adam(model.parameters(), lr=0.01) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=10) + criterion = nn.MSELoss() + best_val_loss = float('inf') - - for epoch in range(max_epochs): + + for epoch in range(num_epochs): + # Training model.train() - running_losses = [] - - for solute, solvent, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"): + train_loss = 0.0 + train_count = 0 + + for batch in train_loader: optimizer.zero_grad() - - outputs, interaction_map = model(solute, solvent) - labels = torch.tensor(labels).to(device).float() - - l2_penalty = torch.norm(interaction_map, p=2) * 1e-4 - loss = loss_fn(outputs, labels) + l2_penalty - - loss.backward() - optimizer.step() - - running_losses.append((loss - l2_penalty).item()) - - val_loss, val_mae = get_metrics(model, valid_loader) - scheduler.step(val_loss) - - print(f"Epoch {epoch+1} | Train Loss: {np.mean(running_losses):.4f} | Val Loss: {val_loss:.4f} | MAE: {val_mae:.4f}") - - if val_loss < best_val_loss: - best_val_loss = val_loss - torch.save(model.state_dict(), save_path) - print("✅ Model saved.") + + try: + solute_smiles = batch['solute_smiles'][0] + solvent_smiles = batch['solvent_smiles'][0] + target = batch['target'].to(device) + + # Forward pass + prediction, interaction_map = model(solute_smiles, solvent_smiles) + loss = criterion(prediction, target) + + # Backward pass + loss.backward() + optimizer.step() + + train_loss += loss.item() + train_count += 1 + + except Exception as e: + # Skip problematic molecules as done in the paper + continue + + # Validation + model.eval() + val_loss = 0.0 + val_count = 0 + + with torch.no_grad(): + for 
batch in val_loader: + try: + solute_smiles = batch['solute_smiles'][0] + solvent_smiles = batch['solvent_smiles'][0] + target = batch['target'].to(device) + + prediction, _ = model(solute_smiles, solvent_smiles) + loss = criterion(prediction, target) + + val_loss += loss.item() + val_count += 1 + + except Exception as e: + continue + + if train_count > 0 and val_count > 0: + avg_train_loss = train_loss / train_count + avg_val_loss = val_loss / val_count + + print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}') + + scheduler.step(avg_val_loss) + + # Save best model + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + torch.save(model.state_dict(), 'best_model.pth') + + # Early stopping if learning rate becomes too small + if optimizer.param_groups[0]['lr'] < 1e-5: + print("Early stopping: Learning rate too small") + break + + return best_val_loss + +def evaluate_model(model, test_loader): + """Evaluate the model and return RMSE""" + model.eval() + criterion = nn.MSELoss() + total_loss = 0.0 + count = 0 + + with torch.no_grad(): + for batch in test_loader: + try: + solute_smiles = batch['solute_smiles'][0] + solvent_smiles = batch['solvent_smiles'][0] + target = batch['target'].to(device) + + prediction, _ = model(solute_smiles, solvent_smiles) + loss = criterion(prediction, target) + + total_loss += loss.item() + count += 1 + + except Exception as e: + continue + + if count > 0: + rmse = np.sqrt(total_loss / count) + return rmse + else: + return float('inf') + +def run_kfold_cv(data, k=10, n_runs=5): + """Run k-fold cross validation as described in the paper""" + all_rmses = [] + + for run in range(n_runs): + print(f"\nRun {run+1}/{n_runs}") + kf = KFold(n_splits=k, shuffle=True, random_state=run) + run_rmses = [] + + for fold, (train_idx, test_idx) in enumerate(kf.split(data)): + print(f"Fold {fold+1}/{k}") + + train_data = data.iloc[train_idx] + test_data = data.iloc[test_idx] + + # Create datasets and loaders + train_dataset = SolvationDataset(train_data) + test_dataset = SolvationDataset(test_data) + + train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) + + # Initialize model with paper's hyperparameters + model = Cigin(node_dim=40, edge_dim=10, T=3).to(device) + + # Train model + _ = train_model(model, train_loader, test_loader) + + # Load best model and evaluate + model.load_state_dict(torch.load('best_model.pth')) + rmse = evaluate_model(model, test_loader) + + print(f"Fold {fold+1} RMSE: {rmse:.4f} kcal/mol") + run_rmses.append(rmse) + + run_avg_rmse = np.mean(run_rmses) + print(f"Run {run+1} Average RMSE: {run_avg_rmse:.4f} kcal/mol") + all_rmses.extend(run_rmses) + + overall_mean = np.mean(all_rmses) + overall_std = np.std(all_rmses) + + print(f"\nOverall Results:") + print(f"Mean RMSE: {overall_mean:.4f} ± {overall_std:.4f} kcal/mol") + + return overall_mean, overall_std From 42853df9c9bd940c7af398d35fef81acc628050f Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 10:23:31 +0530 Subject: [PATCH 30/39] Update model.py --- CIGIN_V2/model.py | 43 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/CIGIN_V2/model.py b/CIGIN_V2/model.py index 72b0666..eb93686 100644 --- a/CIGIN_V2/model.py +++ b/CIGIN_V2/model.py @@ -1,5 +1,4 @@ import numpy as np -import math from dgl import DGLGraph from dgl.nn.pytorch import 
Set2Set, NNConv, GATConv @@ -8,10 +7,26 @@ import torch.nn as nn import torch.nn.functional as F + + class GatherModel(nn.Module): """ - Original MPNN from CIGIN paper (unchanged) + MPNN from + `Neural Message Passing for Quantum Chemistry ` + Parameters + ---------- + node_input_dim : int + Dimension of input node feature, default to be 42. + edge_input_dim : int + Dimension of input edge feature, default to be 10. + node_hidden_dim : int + Dimension of node feature in hidden layers, default to be 42. + edge_hidden_dim : int + Dimension of edge feature in hidden layers, default to be 128. + num_step_message_passing : int + Number of message passing steps, default to be 6. """ + def __init__(self, node_input_dim=42, edge_input_dim=10, @@ -35,20 +50,38 @@ def __init__(self, ) def forward(self, g, n_feat, e_feat): + """Returns the node embeddings after message passing phase. + Parameters + ---------- + g : DGLGraph + Input DGLGraph for molecule(s) + n_feat : tensor of dtype float32 and shape (B1, D1) + Node features. B1 for number of nodes and D1 for + the node feature size. + e_feat : tensor of dtype float32 and shape (B2, D2) + Edge features. B2 for number of edges and D2 for + the edge feature size. + Returns + ------- + res : node features + """ + init = n_feat.clone() out = F.relu(self.lin0(n_feat)) for i in range(self.num_step_message_passing): if e_feat is not None: m = torch.relu(self.conv(g, out, e_feat)) else: - m = torch.relu(self.conv.bias + self.conv.res_fc(out)) + m = torch.relu(self.conv.bias + self.conv.res_fc(out)) out = self.message_layer(torch.cat([m, out], dim=1)) return out + init + class CIGINModel(nn.Module): """ - Original CIGIN model (unchanged) + This the main class for CIGIN model """ + def __init__(self, node_input_dim=42, edge_input_dim=10, @@ -75,7 +108,7 @@ def __init__(self, self.node_hidden_dim, self.edge_input_dim, self.num_step_message_passing, ) - # These three are the FFNN for prediction phase + self.fc1 = nn.Linear(8 * self.node_hidden_dim, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 1) From 8bbe6f986e41d1987047f1fa29c1f23a54df1308 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 10:25:47 +0530 Subject: [PATCH 31/39] Update molecular_graph.py From b66416adc76dc561b9d7924e212a2546e6d3cd6a Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 14:16:51 +0530 Subject: [PATCH 32/39] dgl api changed --- CIGIN_V2/molecular_graph.py | 44 ++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/CIGIN_V2/molecular_graph.py b/CIGIN_V2/molecular_graph.py index 52e745a..da908db 100644 --- a/CIGIN_V2/molecular_graph.py +++ b/CIGIN_V2/molecular_graph.py @@ -1,5 +1,5 @@ import numpy as np -from dgl import DGLGraph +import dgl from rdkit import Chem from rdkit.Chem import rdMolDescriptors as rdDesc from utils import one_of_k_encoding_unk, one_of_k_encoding @@ -24,28 +24,21 @@ def get_atom_features(atom, stereo, features, explicit_H=False): Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D]) atom_features += [int(i) for i in list("{0:06b}".format(features))] - if not explicit_H: atom_features += one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) - try: atom_features += one_of_k_encoding_unk(stereo, ['R', 'S']) atom_features += [atom.HasProp('_ChiralityPossible')] except Exception as e: - - 
atom_features += [False, False - ] + [atom.HasProp('_ChiralityPossible')] - + atom_features += [False, False] + [atom.HasProp('_ChiralityPossible')] return np.array(atom_features) - def get_bond_features(bond): """ Method that computes bond level features from rdkit bond object :param bond: rdkit bond object :return: bond features, 1d numpy array """ - bond_type = bond.GetBondType() bond_feats = [ bond_type == Chem.rdchem.BondType.SINGLE, bond_type == Chem.rdchem.BondType.DOUBLE, @@ -54,10 +47,8 @@ def get_bond_features(bond): bond.IsInRing() ] bond_feats += one_of_k_encoding_unk(str(bond.GetStereo()), ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]) - return np.array(bond_feats) - def get_graph_from_smile(molecule_smile): """ Method that constructs a molecular graph with nodes being the atoms @@ -65,32 +56,41 @@ def get_graph_from_smile(molecule_smile): :param molecule_smile: SMILE sequence :return: DGL graph object, Node features and Edge features """ - - G = DGLGraph() molecule = Chem.MolFromSmiles(molecule_smile) features = rdDesc.GetFeatureInvariants(molecule) - stereo = Chem.FindMolChiralCenters(molecule) chiral_centers = [0] * molecule.GetNumAtoms() for i in stereo: chiral_centers[i[0]] = i[1] - - G.add_nodes(molecule.GetNumAtoms()) + + # Create graph with modern DGL API + G = dgl.graph([], num_nodes=molecule.GetNumAtoms()) node_features = [] edge_features = [] + edges_src = [] + edges_dst = [] + for i in range(molecule.GetNumAtoms()): - atom_i = molecule.GetAtomWithIdx(i) atom_i_features = get_atom_features(atom_i, chiral_centers[i], features[i]) node_features.append(atom_i_features) - for j in range(molecule.GetNumAtoms()): bond_ij = molecule.GetBondBetweenAtoms(i, j) if bond_ij is not None: - G.add_edges(i, j) + edges_src.append(i) + edges_dst.append(j) bond_features_ij = get_bond_features(bond_ij) edge_features.append(bond_features_ij) - - G.ndata['x'] = torch.tensor(node_features) - G.edata['w'] = torch.tensor(edge_features) + + # Add edges to graph + if edges_src: + G.add_edges(edges_src, edges_dst) + + # MINIMAL PERFORMANCE FIX: Convert to numpy array first, then to tensor + G.ndata['x'] = torch.tensor(np.array(node_features)) + if edge_features: # Only if edges exist + G.edata['w'] = torch.tensor(np.array(edge_features)) + else: + G.edata['w'] = torch.tensor([]) # Empty tensor for molecules with no bonds + return G From b5cd40712af7d6eddeef1869d008fd3b969fd20d Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 14:17:50 +0530 Subject: [PATCH 33/39] Update train.py --- CIGIN_V2/train.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/CIGIN_V2/train.py b/CIGIN_V2/train.py index 448f549..4179560 100644 --- a/CIGIN_V2/train.py +++ b/CIGIN_V2/train.py @@ -4,7 +4,6 @@ loss_fn = torch.nn.MSELoss() mae_loss_fn = torch.nn.L1Loss() - use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") @@ -16,7 +15,7 @@ def evaluate_model(model, dataloader): outputs, _ = model([samples[0].to(device), samples[1].to(device), samples[2].to(device), samples[3].to(device)]) preds.extend(outputs.cpu().numpy()) - targets.extend(samples[4].numpy()) + targets.extend(samples[4]) preds, targets = np.array(preds), np.array(targets) rmse = np.sqrt(np.mean((preds - targets) ** 2)) return rmse @@ -30,18 +29,20 @@ def get_metrics(model, data_loader): outputs, i_map = model( [solute_graphs.to(device), solvent_graphs.to(device), torch.tensor(solute_lens).to(device), 
From 703b4d23cc8ad866eeaccd82807c8bc4d60e8584 Mon Sep 17 00:00:00 2001
From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com>
Date: Fri, 1 Aug 2025 14:18:22 +0530
Subject: [PATCH 34/39] Update main.py

---
 CIGIN_V2/main.py | 334 +++++++++++++++++++++++++++++++----------------
 1 file changed, 220 insertions(+), 114 deletions(-)

diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py
index 1cda8f8..7569400 100644
--- a/CIGIN_V2/main.py
+++ b/CIGIN_V2/main.py
@@ -1,131 +1,237 @@
-# python imports
 import pandas as pd
-import warnings
-import os
-import argparse
-from sklearn.model_selection import train_test_split
 import numpy as np
-
-# rdkit imports
-from rdkit import RDLogger
-from rdkit import rdBase
-from rdkit import Chem
-
-# torch imports
-from torch.utils.data import DataLoader, Dataset
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 import torch
-
-# dgl imports
+import torch.optim as optim
+from torch.utils.data import DataLoader, Dataset
+from sklearn.model_selection import KFold
+import os
 import dgl
+import warnings
 
-# local imports
+# Suppress RDKit warnings
+from rdkit import RDLogger
+RDLogger.DisableLog('rdApp.*')
+
+# Import your existing modules
 from model import CIGINModel
-from train import train, evaluate_model, get_metrics, initialize_logger
+from train import train, get_metrics
 from molecular_graph import get_graph_from_smile
-from utils import *
-
-# Disable logs and warnings
-lg = RDLogger.logger()
-lg.setLevel(RDLogger.CRITICAL)
-rdBase.DisableLog('rdApp.error')
-warnings.filterwarnings("ignore")
-
-# Argument parsing
-parser = argparse.ArgumentParser()
-parser.add_argument('--name', default='cigin',
help="The name of the current project: default: CIGIN") -parser.add_argument('--interaction', help="type of interaction function to use: dot | scaled-dot | general | tanh-general", - default='dot') -parser.add_argument('--max_epochs', type=int, default=100, help="The max number of epochs for training") -parser.add_argument('--batch_size', type=int, default=32, help="The batch size for training") - -args = parser.parse_args() -project_name = args.name -interaction = args.interaction -max_epochs = args.max_epochs -batch_size = args.batch_size - -# Device configuration -use_cuda = torch.cuda.is_available() -device = torch.device("cuda" if use_cuda else "cpu") - -# Create output directory -if not os.path.isdir("runs/run-" + str(project_name)): - os.makedirs("./runs/run-" + str(project_name)) - os.makedirs("./runs/run-" + str(project_name) + "/models") - -def collate(samples): - solute_graphs, solvent_graphs, labels = map(list, zip(*samples)) - solute_graphs = dgl.batch(solute_graphs) - solvent_graphs = dgl.batch(solvent_graphs) - solute_len_matrix = torch.FloatTensor(get_len_matrix(solute_graphs.batch_num_nodes().tolist())) - solvent_len_matrix = torch.FloatTensor(get_len_matrix(solvent_graphs.batch_num_nodes().tolist())) - labels = torch.FloatTensor(labels) - return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels +from utils import get_len_matrix class SolvationDataset(Dataset): - """Custom dataset class for solvation data""" - def __init__(self, dataset): - self.dataset = dataset - + def __init__(self, dataframe): + self.data = dataframe + def __len__(self): - return len(self.dataset) - + return len(self.data) + def __getitem__(self, idx): - # Process solute - solute = self.dataset.iloc[idx]['SoluteSMILES'] - mol = Chem.MolFromSmiles(solute) - mol = Chem.AddHs(mol) - solute = Chem.MolToSmiles(mol) - solute_graph = get_graph_from_smile(solute) - - # Process solvent - solvent = self.dataset.iloc[idx]['SolventSMILES'] - mol = Chem.MolFromSmiles(solvent) - mol = Chem.AddHs(mol) - solvent = Chem.MolToSmiles(mol) - solvent_graph = get_graph_from_smile(solvent) - - delta_g = self.dataset.iloc[idx]['delGsolv'] - return solute_graph, solvent_graph, delta_g + row = self.data.iloc[idx] + + # Get molecular graphs + solute_graph = get_graph_from_smile(row['SoluteSMILES']) + solvent_graph = get_graph_from_smile(row['SolventSMILES']) + + # Get molecule lengths (number of atoms) + solute_len = solute_graph.number_of_nodes() + solvent_len = solvent_graph.number_of_nodes() + + # Get target value + target = float(row['delGsolv']) + + return solute_graph, solvent_graph, solute_len, solvent_len, target + +def collate_fn(batch): + """Custom collate function for batching molecular graphs using DGL's batch functionality""" + solute_graphs, solvent_graphs, solute_lens, solvent_lens, targets = zip(*batch) + + # Batch the graphs using DGL's batch function + batched_solute = dgl.batch(solute_graphs) + batched_solvent = dgl.batch(solvent_graphs) + + # Create length matrices as described in CIGIN paper + # The length matrix is used to mask interactions between different molecules in the batch + solute_len_matrix = get_len_matrix(solute_lens) + solvent_len_matrix = get_len_matrix(solvent_lens) + + return batched_solute, batched_solvent, solute_len_matrix, solvent_len_matrix, list(targets) def main(): - # Load and prepare data - df = pd.read_csv('https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/refs/heads/master/CIGIN_V2/data/whole_data.csv') + # Load and preprocess data + 
print("Loading dataset...") + + # Use the specific dataset URL you provided + dataset_url = "https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/master/CIGIN_V2/data/whole_data.csv" + + try: + print(f"Loading dataset from: {dataset_url}") + df = pd.read_csv(dataset_url) + print(f"Successfully loaded dataset from: {dataset_url}") + except Exception as e: + print(f"ERROR: Could not load dataset from {dataset_url}: {str(e)}") + print("Please check the dataset URL or network connection.") + return + + # Strip whitespace from column names first df.columns = df.columns.str.strip() - # Split data into train/valid/test (80/10/10) - train_df, test_df = train_test_split(df, test_size=0.1, random_state=42) - train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42) - - # Create datasets and data loaders - train_dataset = SolvationDataset(train_df) - valid_dataset = SolvationDataset(valid_df) - test_dataset = SolvationDataset(test_df) - - train_loader = DataLoader(train_dataset, collate_fn=collate, batch_size=batch_size, shuffle=True) - valid_loader = DataLoader(valid_dataset, collate_fn=collate, batch_size=128) - test_loader = DataLoader(test_dataset, collate_fn=collate, batch_size=128) - - # Initialize model - model = CIGINModel(interaction=interaction) - model.to(device) - - # Training setup - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True) - initialize_logger() - - - # Training loop - train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name) - - # Final evaluation - model.eval() - loss, mae_loss = get_metrics(model, test_loader,devices) - print(f"\nFinal test set performance:") - print(f"MSE Loss: {loss:.4f}") - print(f"MAE Loss: {mae_loss:.4f}") + # Strip whitespace from all string columns + for col in df.columns: + if df[col].dtype == 'object': + df[col] = df[col].str.strip() + + print(f"Dataset loaded with {len(df)} samples") + print(f"Columns after cleaning: {df.columns.tolist()}") + + # Verify required columns exist + required_columns = ['SoluteName', 'SoluteSMILES', 'SolventName', 'SolventSMILES', 'delGsolv'] + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + print(f"ERROR: Missing required columns: {missing_columns}") + print(f"Available columns: {df.columns.tolist()}") + return + + # Remove any rows with missing SMILES or target values + initial_len = len(df) + df = df.dropna(subset=['SoluteSMILES', 'SolventSMILES', 'delGsolv']) + print(f"Removed {initial_len - len(df)} rows with missing data") + print(f"Final dataset size: {len(df)} samples") + + # Hyperparameters as mentioned in CIGIN paper + node_input_dim = 42 # Based on atom features in paper + edge_input_dim = 10 # Based on bond features in paper + node_hidden_dim = 42 + edge_hidden_dim = 42 + num_step_message_passing = 6 # T=6 as mentioned in paper + interaction = 'dot' # dot product interaction as in paper + batch_size = 32 + learning_rate = 0.001 # ADAM with default parameters + max_epochs = 100 + + # Device setup + use_cuda = torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + print(f"Using device: {device}") + + # 10-fold cross validation as described in CIGIN paper + # Paper mentions: "10-fold cross validation scheme was used to assess the model due to the small size of the dataset" + # "We made 5 such 10 cross validation splits and trained our model independently on each of them" + + 
all_fold_results = [] + + # Run 5 independent 10-fold cross validation splits as in the paper + for run in range(5): + print(f"\n=== Cross Validation Run {run + 1}/5 ===") + + # Random split with different seed for each run + kfold = KFold(n_splits=10, shuffle=True, random_state=42 + run) + fold_results = [] + + for fold, (train_idx, val_idx) in enumerate(kfold.split(df)): + print(f"\nRun {run + 1}, Fold {fold + 1}/10") + + # Split data - 9:1 ratio as mentioned in paper + train_df = df.iloc[train_idx].reset_index(drop=True) + val_df = df.iloc[val_idx].reset_index(drop=True) + + print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}") + + # Create datasets + train_dataset = SolvationDataset(train_df) + val_dataset = SolvationDataset(val_df) + + # Create data loaders + train_loader = DataLoader(train_dataset, batch_size=batch_size, + shuffle=True, collate_fn=collate_fn) + val_loader = DataLoader(val_dataset, batch_size=batch_size, + shuffle=False, collate_fn=collate_fn) + + # Initialize model - CIGIN with set2set as it performed best in paper + model = CIGINModel( + node_input_dim=node_input_dim, + edge_input_dim=edge_input_dim, + node_hidden_dim=node_hidden_dim, + edge_hidden_dim=edge_hidden_dim, + num_step_message_passing=num_step_message_passing, + interaction=interaction, + num_step_set2_set=2, # As mentioned in paper + num_layer_set2set=1 # As mentioned in paper + ) + + # Move model to device + model.to(device) + + # Initialize optimizer and scheduler as mentioned in paper + # "ADAM optimizer with its default parameters as suggested by Kingma and Ba was used" + # "The learning rate was decreased on plateau by a factor of 10^-1 from 10^-2 to 10^-5" + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', + factor=0.1, patience=10) + + # Create directory for this run and fold - MINIMAL FIX: Fixed directory naming + run_fold_dir = f"./runs/run_{run + 1}_fold_{fold + 1}" + os.makedirs(f"{run_fold_dir}/models", exist_ok=True) + + # Train model - MINIMAL FIX: Fixed project name format + train(max_epochs, model, optimizer, scheduler, train_loader, + val_loader, f"{run + 1}_fold_{fold + 1}") + + # Get final validation metrics + model.eval() + val_loss, val_mae = get_metrics(model, val_loader) + fold_results.append({ + 'run': run + 1, + 'fold': fold + 1, + 'val_rmse': np.sqrt(val_loss), # Convert MSE to RMSE + 'val_mae': val_mae + }) + + print(f"Run {run + 1}, Fold {fold + 1} - Val RMSE: {np.sqrt(val_loss):.4f}, Val MAE: {val_mae:.4f}") + + all_fold_results.extend(fold_results) + + # Calculate average for this run + run_rmse = np.mean([r['val_rmse'] for r in fold_results]) + run_mae = np.mean([r['val_mae'] for r in fold_results]) + print(f"Run {run + 1} Average - RMSE: {run_rmse:.4f}, MAE: {run_mae:.4f}") + + # Calculate final statistics across all runs and folds + all_rmse = [r['val_rmse'] for r in all_fold_results] + all_mae = [r['val_mae'] for r in all_fold_results] + + final_rmse_mean = np.mean(all_rmse) + final_rmse_std = np.std(all_rmse) + final_mae_mean = np.mean(all_mae) + final_mae_std = np.std(all_mae) + + print(f"\n=== Final Results (5 independent 10-fold CV runs) ===") + print(f"Average RMSE: {final_rmse_mean:.4f} ± {final_rmse_std:.4f} kcal/mol") + print(f"Average MAE: {final_mae_mean:.4f} ± {final_mae_std:.4f} kcal/mol") + + # Expected result from paper: RMSE of 0.57 ± 0.10 kcal/mol + print(f"\nPaper reported RMSE: 0.57 ± 0.10 kcal/mol") + print(f"Our result RMSE: {final_rmse_mean:.2f} ± 
{final_rmse_std:.2f} kcal/mol") + + # Save detailed results + results_df = pd.DataFrame(all_fold_results) + results_df.to_csv("./cigin_5x10fold_cv_results.csv", index=False) + + # Save summary statistics + summary = { + 'final_rmse_mean': final_rmse_mean, + 'final_rmse_std': final_rmse_std, + 'final_mae_mean': final_mae_mean, + 'final_mae_std': final_mae_std, + 'paper_rmse': 0.57, + 'paper_rmse_std': 0.10 + } + + summary_df = pd.DataFrame([summary]) + summary_df.to_csv("./cigin_summary_results.csv", index=False) + + print(f"\nResults saved to:") + print(f"- Detailed results: ./cigin_5x10fold_cv_results.csv") + print(f"- Summary results: ./cigin_summary_results.csv") -if __name__ == '__main__': +if __name__ == "__main__": main() From 8b0dcadabae35d463e45fb4c02e51026d6a45a26 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 17:34:58 +0530 Subject: [PATCH 35/39] Update main.py --- CIGIN_V2/main.py | 319 +++++++++++++++-------------------------------- 1 file changed, 103 insertions(+), 216 deletions(-) diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py index 7569400..6b03138 100644 --- a/CIGIN_V2/main.py +++ b/CIGIN_V2/main.py @@ -1,237 +1,124 @@ +# python imports import pandas as pd -import numpy as np -import torch -import torch.optim as optim -from torch.utils.data import DataLoader, Dataset -from sklearn.model_selection import KFold -import os -import dgl import warnings +import os +import argparse +from sklearn.model_selection import train_test_split -# Suppress RDKit warnings +# rdkit imports from rdkit import RDLogger -RDLogger.DisableLog('rdApp.*') +from rdkit import rdBase +from rdkit import Chem + +# torch imports +from torch.utils.data import DataLoader, Dataset +from torch.optim.lr_scheduler import ReduceLROnPlateau +import torch + +# dgl imports +import dgl -# Import your existing modules +# local imports from model import CIGINModel from train import train, get_metrics from molecular_graph import get_graph_from_smile -from utils import get_len_matrix +from utils import * + +lg = RDLogger.logger() +lg.setLevel(RDLogger.CRITICAL) +rdBase.DisableLog('rdApp.error') +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser() +parser.add_argument('--name', default='cigin', help="The name of the current project: default: CIGIN") +parser.add_argument('--interaction', help="type of interaction function to use: dot | scaled-dot | general | " + "tanh-general", default='dot') +parser.add_argument('--max_epochs', required=False, default=100, help="The max number of epochs for training") +parser.add_argument('--batch_size', required=False, default=32, help="The batch size for training") + +args = parser.parse_args() +project_name = args.name +interaction = args.interaction +max_epochs = int(args.max_epochs) +batch_size = int(args.batch_size) + +use_cuda = torch.cuda.is_available() +device = torch.device("cuda" if use_cuda else "cpu") + +if not os.path.isdir("runs/run-" + str(project_name)): + os.makedirs("./runs/run-" + str(project_name)) + os.makedirs("./runs/run-" + str(project_name) + "/models") + + +def collate(samples): + solute_graphs, solvent_graphs, labels = map(list, zip(*samples)) + solute_graphs = dgl.batch(solute_graphs) + solvent_graphs = dgl.batch(solvent_graphs) + solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes().tolist()) + solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes().tolist()) + return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels 
+
+
+class Dataclass(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+        # Normalize delGsolv values
+        self.mean = dataset['delGsolv'].mean()
+        self.std = dataset['delGsolv'].std()
 
-class SolvationDataset(Dataset):
-    def __init__(self, dataframe):
-        self.data = dataframe
+
     def __len__(self):
-        return len(self.data)
-
+        return len(self.dataset)
+
     def __getitem__(self, idx):
-        row = self.data.iloc[idx]
-
-        # Get molecular graphs
-        solute_graph = get_graph_from_smile(row['SoluteSMILES'])
-        solvent_graph = get_graph_from_smile(row['SolventSMILES'])
+        solute = self.dataset.iloc[idx]['SoluteSMILES']
+        mol = Chem.MolFromSmiles(solute)
+        mol = Chem.AddHs(mol)
+        solute = Chem.MolToSmiles(mol)
+        solute_graph = get_graph_from_smile(solute)
 
-        # Get molecule lengths (number of atoms)
-        solute_len = solute_graph.number_of_nodes()
-        solvent_len = solvent_graph.number_of_nodes()
+        solvent = self.dataset.iloc[idx]['SolventSMILES']
+        mol = Chem.MolFromSmiles(solvent)
+        mol = Chem.AddHs(mol)
+        solvent = Chem.MolToSmiles(mol)
+        solvent_graph = get_graph_from_smile(solvent)
 
-        # Get target value
-        target = float(row['delGsolv'])
-
-        return solute_graph, solvent_graph, solute_len, solvent_len, target
+        delta_g = self.dataset.iloc[idx]['delGsolv']
+        # Normalize delta_g
+        delta_g = (delta_g - self.mean) / self.std
+        return [solute_graph, solvent_graph, [delta_g]]
 
-def collate_fn(batch):
-    """Custom collate function for batching molecular graphs using DGL's batch functionality"""
-    solute_graphs, solvent_graphs, solute_lens, solvent_lens, targets = zip(*batch)
-
-    # Batch the graphs using DGL's batch function
-    batched_solute = dgl.batch(solute_graphs)
-    batched_solvent = dgl.batch(solvent_graphs)
-
-    # Create length matrices as described in CIGIN paper
-    # The length matrix is used to mask interactions between different molecules in the batch
-    solute_len_matrix = get_len_matrix(solute_lens)
-    solvent_len_matrix = get_len_matrix(solvent_lens)
-
-    return batched_solute, batched_solvent, solute_len_matrix, solvent_len_matrix, list(targets)
 
 def main():
-    # Load and preprocess data
-    print("Loading dataset...")
-
-    # Use the specific dataset URL you provided
-    dataset_url = "https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/master/CIGIN_V2/data/whole_data.csv"
-
-    try:
-        print(f"Loading dataset from: {dataset_url}")
-        df = pd.read_csv(dataset_url)
-        print(f"Successfully loaded dataset from: {dataset_url}")
-    except Exception as e:
-        print(f"ERROR: Could not load dataset from {dataset_url}: {str(e)}")
-        print("Please check the dataset URL or network connection.")
-        return
-
-    # Strip whitespace from column names first
+    # Load and split data
+    df = pd.read_csv('https://raw.githubusercontent.com/adithyamauryakr/CIGIN-DevaLab/refs/heads/master/CIGIN_V2/data/whole_data.csv')
     df.columns = df.columns.str.strip()
 
-    # Strip whitespace from all string columns
-    for col in df.columns:
-        if df[col].dtype == 'object':
-            df[col] = df[col].str.strip()
-
-    print(f"Dataset loaded with {len(df)} samples")
-    print(f"Columns after cleaning: {df.columns.tolist()}")
-
-    # Verify required columns exist
-    required_columns = ['SoluteName', 'SoluteSMILES', 'SolventName', 'SolventSMILES', 'delGsolv']
-    missing_columns = [col for col in required_columns if col not in df.columns]
-    if missing_columns:
-        print(f"ERROR: Missing required columns: {missing_columns}")
-        print(f"Available columns: {df.columns.tolist()}")
-        return
-
-    # Remove any rows with missing SMILES or target values
-    initial_len = len(df)
-    df = 
df.dropna(subset=['SoluteSMILES', 'SolventSMILES', 'delGsolv']) - print(f"Removed {initial_len - len(df)} rows with missing data") - print(f"Final dataset size: {len(df)} samples") - - # Hyperparameters as mentioned in CIGIN paper - node_input_dim = 42 # Based on atom features in paper - edge_input_dim = 10 # Based on bond features in paper - node_hidden_dim = 42 - edge_hidden_dim = 42 - num_step_message_passing = 6 # T=6 as mentioned in paper - interaction = 'dot' # dot product interaction as in paper - batch_size = 32 - learning_rate = 0.001 # ADAM with default parameters - max_epochs = 100 - - # Device setup - use_cuda = torch.cuda.is_available() - device = torch.device("cuda" if use_cuda else "cpu") - print(f"Using device: {device}") - - # 10-fold cross validation as described in CIGIN paper - # Paper mentions: "10-fold cross validation scheme was used to assess the model due to the small size of the dataset" - # "We made 5 such 10 cross validation splits and trained our model independently on each of them" - - all_fold_results = [] - - # Run 5 independent 10-fold cross validation splits as in the paper - for run in range(5): - print(f"\n=== Cross Validation Run {run + 1}/5 ===") - - # Random split with different seed for each run - kfold = KFold(n_splits=10, shuffle=True, random_state=42 + run) - fold_results = [] - - for fold, (train_idx, val_idx) in enumerate(kfold.split(df)): - print(f"\nRun {run + 1}, Fold {fold + 1}/10") - - # Split data - 9:1 ratio as mentioned in paper - train_df = df.iloc[train_idx].reset_index(drop=True) - val_df = df.iloc[val_idx].reset_index(drop=True) - - print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}") - - # Create datasets - train_dataset = SolvationDataset(train_df) - val_dataset = SolvationDataset(val_df) - - # Create data loaders - train_loader = DataLoader(train_dataset, batch_size=batch_size, - shuffle=True, collate_fn=collate_fn) - val_loader = DataLoader(val_dataset, batch_size=batch_size, - shuffle=False, collate_fn=collate_fn) - - # Initialize model - CIGIN with set2set as it performed best in paper - model = CIGINModel( - node_input_dim=node_input_dim, - edge_input_dim=edge_input_dim, - node_hidden_dim=node_hidden_dim, - edge_hidden_dim=edge_hidden_dim, - num_step_message_passing=num_step_message_passing, - interaction=interaction, - num_step_set2_set=2, # As mentioned in paper - num_layer_set2set=1 # As mentioned in paper - ) - - # Move model to device - model.to(device) - - # Initialize optimizer and scheduler as mentioned in paper - # "ADAM optimizer with its default parameters as suggested by Kingma and Ba was used" - # "The learning rate was decreased on plateau by a factor of 10^-1 from 10^-2 to 10^-5" - optimizer = optim.Adam(model.parameters(), lr=learning_rate) - scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', - factor=0.1, patience=10) - - # Create directory for this run and fold - MINIMAL FIX: Fixed directory naming - run_fold_dir = f"./runs/run_{run + 1}_fold_{fold + 1}" - os.makedirs(f"{run_fold_dir}/models", exist_ok=True) - - # Train model - MINIMAL FIX: Fixed project name format - train(max_epochs, model, optimizer, scheduler, train_loader, - val_loader, f"{run + 1}_fold_{fold + 1}") - - # Get final validation metrics - model.eval() - val_loss, val_mae = get_metrics(model, val_loader) - fold_results.append({ - 'run': run + 1, - 'fold': fold + 1, - 'val_rmse': np.sqrt(val_loss), # Convert MSE to RMSE - 'val_mae': val_mae - }) - - print(f"Run {run + 1}, Fold {fold + 1} - Val RMSE: 
{np.sqrt(val_loss):.4f}, Val MAE: {val_mae:.4f}") - - all_fold_results.extend(fold_results) - - # Calculate average for this run - run_rmse = np.mean([r['val_rmse'] for r in fold_results]) - run_mae = np.mean([r['val_mae'] for r in fold_results]) - print(f"Run {run + 1} Average - RMSE: {run_rmse:.4f}, MAE: {run_mae:.4f}") - - # Calculate final statistics across all runs and folds - all_rmse = [r['val_rmse'] for r in all_fold_results] - all_mae = [r['val_mae'] for r in all_fold_results] - - final_rmse_mean = np.mean(all_rmse) - final_rmse_std = np.std(all_rmse) - final_mae_mean = np.mean(all_mae) - final_mae_std = np.std(all_mae) - - print(f"\n=== Final Results (5 independent 10-fold CV runs) ===") - print(f"Average RMSE: {final_rmse_mean:.4f} ± {final_rmse_std:.4f} kcal/mol") - print(f"Average MAE: {final_mae_mean:.4f} ± {final_mae_std:.4f} kcal/mol") - - # Expected result from paper: RMSE of 0.57 ± 0.10 kcal/mol - print(f"\nPaper reported RMSE: 0.57 ± 0.10 kcal/mol") - print(f"Our result RMSE: {final_rmse_mean:.2f} ± {final_rmse_std:.2f} kcal/mol") - - # Save detailed results - results_df = pd.DataFrame(all_fold_results) - results_df.to_csv("./cigin_5x10fold_cv_results.csv", index=False) - - # Save summary statistics - summary = { - 'final_rmse_mean': final_rmse_mean, - 'final_rmse_std': final_rmse_std, - 'final_mae_mean': final_mae_mean, - 'final_mae_std': final_mae_std, - 'paper_rmse': 0.57, - 'paper_rmse_std': 0.10 - } - - summary_df = pd.DataFrame([summary]) - summary_df.to_csv("./cigin_summary_results.csv", index=False) + train_df, test_df = train_test_split(df, test_size=0.1, random_state=42) + train_df, valid_df = train_test_split(train_df, test_size=0.111, random_state=42) + + train_dataset = Dataclass(train_df) + valid_dataset = Dataclass(valid_df) + test_dataset = Dataclass(test_df) + + train_loader = DataLoader(train_dataset, collate_fn=collate, batch_size=batch_size, shuffle=True) + valid_loader = DataLoader(valid_dataset, collate_fn=collate, batch_size=128) + test_loader = DataLoader(test_dataset, collate_fn=collate, batch_size=128) + + # Initialize model + model = CIGINModel(interaction=interaction) + model.to(device) - print(f"\nResults saved to:") - print(f"- Detailed results: ./cigin_5x10fold_cv_results.csv") - print(f"- Summary results: ./cigin_summary_results.csv") + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True) + + # Train model + train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name) + + # Evaluate on test data + model.eval() + loss, mae_loss = get_metrics(model, test_loader) + print(f"Model performance on the testing data: Loss: {loss}, MAE_Loss: {mae_loss}") + -if __name__ == "__main__": +if __name__ == '__main__': main() From 58bdd80a34c97905f72917580302c091d8d6a9bb Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 17:43:47 +0530 Subject: [PATCH 36/39] Update main.py From d8b3190e8588cf6bb2b83753410d04a1eeb61fae Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 17:53:34 +0530 Subject: [PATCH 37/39] Update molecular_graph.py From f1406af0e7e0bdce4f433b231a6230bb3d8da338 Mon Sep 17 00:00:00 2001 From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com> Date: Fri, 1 Aug 2025 17:54:38 +0530 Subject: [PATCH 38/39] Update molecular_graph.py From 27f97b50eb9851b1c6e853fb112cd1071f7709c2 Mon Sep 17 00:00:00 
2001
From: R Nishanth <71981689+Nishanth-nishu@users.noreply.github.com>
Date: Mon, 4 Aug 2025 10:43:57 +0530
Subject: [PATCH 39/39] removed normalization

---
 CIGIN_V2/main.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py
index 6b03138..aa48b7e 100644
--- a/CIGIN_V2/main.py
+++ b/CIGIN_V2/main.py
@@ -62,9 +62,6 @@ def collate(samples):
 class Dataclass(Dataset):
     def __init__(self, dataset):
         self.dataset = dataset
-        # Normalize delGsolv values
-        self.mean = dataset['delGsolv'].mean()
-        self.std = dataset['delGsolv'].std()
 
     def __len__(self):
         return len(self.dataset)
@@ -84,7 +81,5 @@ def __getitem__(self, idx):
 
         delta_g = self.dataset.iloc[idx]['delGsolv']
-        # Normalize delta_g
-        delta_g = (delta_g - self.mean) / self.std
         return [solute_graph, solvent_graph, [delta_g]]
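
Dropping the normalization in patch 39 is the safer default here: Dataclass computed the mean and standard deviation over whichever split it wrapped, so the train, validation, and test targets were scaled by different statistics and the reported losses were no longer in kcal/mol. If normalization were ever reinstated, the statistics should come from the training split alone and predictions should be mapped back before computing metrics. A minimal sketch of that pattern, with illustrative numbers rather than code from the repository:

import numpy as np

# Hypothetical solvation free energies (kcal/mol) for the training split.
train_targets = np.array([-5.2, -3.1, -7.4, -1.8])
mean, std = train_targets.mean(), train_targets.std()

def normalize(y):
    # Scale any split with the *training* statistics only.
    return (y - mean) / std

def denormalize(y_scaled):
    # Invert the scaling before reporting metrics in kcal/mol.
    return y_scaled * std + mean

# Stand-in for model predictions in normalized space.
preds_scaled = normalize(train_targets) + 0.05
rmse = np.sqrt(np.mean((denormalize(preds_scaled) - train_targets) ** 2))
print(f"RMSE: {rmse:.4f} kcal/mol")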