From 564df0caff30657b23d7c35f1a6d120ca1ab41a9 Mon Sep 17 00:00:00 2001 From: Alexandra Lauren Day Date: Fri, 1 Sep 2023 15:35:01 -0700 Subject: [PATCH] Files from summer project. --- .../data/spectra_probies_test_data_info.py | 23 ++ .../data/spectra_probies_train_data_info.py | 23 ++ .../data/spectra_probies_val_data_info.py | 23 ++ .../HRRL/images_to_scalers_grid_search.py | 176 +++++++++++++ .../HRRL/images_to_spectra_grid_search.py | 247 ++++++++++++++++++ .../HRRL/make_image_spectra_numpy_file.py | 99 +++++++ .../models/probiesNetLBANN_grid_search.py | 51 ++++ .../HRRL/models/probiesNet_HRRL_arch.py | 33 +++ .../HRRL/read_in_pytorch_sample_model.py | 147 +++++++++++ python/lbann/contrib/hyperparameter.py | 9 + 10 files changed, 831 insertions(+) create mode 100644 applications/physics/HRRL/data/spectra_probies_test_data_info.py create mode 100644 applications/physics/HRRL/data/spectra_probies_train_data_info.py create mode 100644 applications/physics/HRRL/data/spectra_probies_val_data_info.py create mode 100644 applications/physics/HRRL/images_to_scalers_grid_search.py create mode 100644 applications/physics/HRRL/images_to_spectra_grid_search.py create mode 100644 applications/physics/HRRL/make_image_spectra_numpy_file.py create mode 100644 applications/physics/HRRL/models/probiesNetLBANN_grid_search.py create mode 100644 applications/physics/HRRL/models/probiesNet_HRRL_arch.py create mode 100644 applications/physics/HRRL/read_in_pytorch_sample_model.py diff --git a/applications/physics/HRRL/data/spectra_probies_test_data_info.py b/applications/physics/HRRL/data/spectra_probies_test_data_info.py new file mode 100644 index 00000000000..356ca91e9ec --- /dev/null +++ b/applications/physics/HRRL/data/spectra_probies_test_data_info.py @@ -0,0 +1,23 @@ +import numpy as np +import os + +data_file = '/p/vast1/lbann/datasets/HRRL/images_spectra_test_set.npy' + +sample_size = 90201 #size of one sample, flattened (For PROBIES, 300x300 image + 201 spectra array) +nsamples = 701 + +samples = None + +# Sample access functions +def get_sample(index): + global samples + if samples is None: + samples_raw = np.load(data_file, mmap_mode='r', allow_pickle=True) + samples = np.reshape(samples_raw,(nsamples,sample_size)) + return samples[index] + +def num_samples(): + return nsamples + +def sample_dims(): + return [sample_size] \ No newline at end of file diff --git a/applications/physics/HRRL/data/spectra_probies_train_data_info.py b/applications/physics/HRRL/data/spectra_probies_train_data_info.py new file mode 100644 index 00000000000..cf283d52d41 --- /dev/null +++ b/applications/physics/HRRL/data/spectra_probies_train_data_info.py @@ -0,0 +1,23 @@ +import numpy as np +import os + +data_file = '/p/vast1/lbann/datasets/HRRL/images_spectra_train_set.npy' + +sample_size = 90201 #size of one sample, flattened (For PROBIES, 300x300 image + 201 spectra array) +nsamples = 5670 + +samples = None + +# Sample access functions +def get_sample(index): + global samples + if samples is None: + samples_raw = np.load(data_file, mmap_mode='r', allow_pickle=True) + samples = np.reshape(samples_raw,(nsamples,sample_size)) + return samples[index] + +def num_samples(): + return nsamples + +def sample_dims(): + return [sample_size] \ No newline at end of file diff --git a/applications/physics/HRRL/data/spectra_probies_val_data_info.py b/applications/physics/HRRL/data/spectra_probies_val_data_info.py new file mode 100644 index 00000000000..1f027b3b7c5 --- /dev/null +++ 
b/applications/physics/HRRL/data/spectra_probies_val_data_info.py
@@ -0,0 +1,23 @@
+import numpy as np
+import os
+
+data_file = '/p/vast1/lbann/datasets/HRRL/images_spectra_val_set.npy'
+
+sample_size = 90201 #size of one sample, flattened (For PROBIES, 300x300 image + 201 spectra array)
+nsamples = 631
+
+samples = None
+
+# Sample access functions
+def get_sample(index):
+    global samples
+    if samples is None:
+        samples_raw = np.load(data_file, mmap_mode='r', allow_pickle=True)
+        samples = np.reshape(samples_raw,(nsamples,sample_size))
+    return samples[index]
+
+def num_samples():
+    return nsamples
+
+def sample_dims():
+    return [sample_size]
\ No newline at end of file
diff --git a/applications/physics/HRRL/images_to_scalers_grid_search.py b/applications/physics/HRRL/images_to_scalers_grid_search.py
new file mode 100644
index 00000000000..6424406a01d
--- /dev/null
+++ b/applications/physics/HRRL/images_to_scalers_grid_search.py
@@ -0,0 +1,176 @@
+import lbann
+import torch
+import lbann.torch
+import lbann.contrib.launcher
+import lbann.contrib.args
+import google.protobuf.text_format as txtf
+import lbann.contrib.hyperparameter as hyper
+import argparse
+import sys
+import os
+from os.path import abspath, dirname, join
+
+# ==============================================
+# Setup
+# ==============================================
+
+# Debugging
+torch._dynamo.config.verbose=True
+
+# Command-line arguments
+desc = ('Construct and run a grid search on HRRL PROBIES data.')
+parser = argparse.ArgumentParser(description=desc)
+lbann.contrib.args.add_scheduler_arguments(parser)
+parser.add_argument(
+    '--job-name', action='store', default='probiesNet', type=str,
+    help='scheduler job name (default: probiesNet)')
+parser.add_argument(
+    '--mini-batch-size', action='store', default=32, type=int,
+    help='mini-batch size (default: 32)', metavar='NUM')
+parser.add_argument(
+    '--reader-prototext', action='store', default='probies_v2.prototext', type=str,
+    help='data to use (default: probies_v2.prototext, 20K data)')
+parser.add_argument(
+    '--num-epochs', action='store', default=100, type=int,
+    help='number of epochs (default: 100)', metavar='NUM')
+
+# Add optimizer arguments
+lbann.contrib.args.add_optimizer_arguments(parser)
+args = parser.parse_args()
+
+# Default data reader
+cur_dir = dirname(abspath(__file__))
+data_reader_prototext = join(cur_dir,
+                             'data',
+                             args.reader_prototext)
+
+print("DATA READER ", data_reader_prototext)
+
+# Make script with 2 nodes
+script = lbann.launcher.make_batch_script(nodes=2, procs_per_node=4)
+
+# Run experiment
+kwargs = lbann.contrib.args.get_scheduler_kwargs(args)
+
+# Define the model for the search
+def make_search_model(
+        learning_rate,
+        beta1,
+        beta2,
+        eps,
+        intermed_fc_layers,
+        activation,
+        dropout_percent,
+        num_labels=5):
+
+    import models.probiesNetLBANN_grid_search as model
+    images = lbann.Input(data_field='samples')
+    responses = lbann.Input(data_field='responses')
+    num_labels = 5
+    images = lbann.Reshape(images, dims=[1, 300, 300])
+    pred = model.PROBIESNetLBANN(num_labels, intermed_fc_layers, activation, dropout_percent)(images)
+
+    # LBANN slice layers
+    pred_slice = lbann.Slice(pred, axis=0, slice_points=[0,1,2,3,4,5])
+    response_slice = lbann.Slice(responses, axis=0, slice_points=[0,1,2,3,4,5])
+
+    # ==============================================
+    # Metrics
+    # ==============================================
+
+    # MSE loss between responses and preds
+    mse = lbann.MeanSquaredError([responses, pred])
+
+    # Responses
+    epmax_response =
lbann.Identity(response_slice) + etot_response = lbann.Identity(response_slice) + n_response = lbann.Identity(response_slice) + t_response = lbann.Identity(response_slice) + alpha_response = lbann.Identity(response_slice) + + # Preds + epmax_pred = lbann.Identity(pred_slice) + etot_pred = lbann.Identity(pred_slice) + n_pred = lbann.Identity(pred_slice) + t_pred = lbann.Identity(pred_slice) + alpha_pred = lbann.Identity(pred_slice) + + # MSEs + mse_epmax = lbann.MeanSquaredError([epmax_response, epmax_pred]) + mse_etot = lbann.MeanSquaredError([etot_response, etot_pred]) + mse_n = lbann.MeanSquaredError([n_response, n_pred]) + mse_t = lbann.MeanSquaredError([t_response, t_pred]) + mse_alpha = lbann.MeanSquaredError([alpha_response, alpha_pred]) + + layers = list(lbann.traverse_layer_graph([images, responses])) + + # Append metrics + metrics = [lbann.Metric(mse, name='mse')] + metrics.append(lbann.Metric(mse_epmax, name='mse_epmax')) + metrics.append(lbann.Metric(mse_etot, name='mse_etot')) + metrics.append(lbann.Metric(mse_n, name='mse_n')) + metrics.append(lbann.Metric(mse_t, name='mse_t')) + metrics.append(lbann.Metric(mse_alpha, name='mse_alpha')) + + callbacks = [lbann.CallbackPrint(), + lbann.CallbackTimer()] + + layers = list(lbann.traverse_layer_graph([images, responses])) + + model = lbann.Model(args.num_epochs, + layers=layers, + objective_function=mse, + metrics=metrics, + callbacks=callbacks) + + # Setup optimizer + opt = lbann.Adam(learn_rate=learning_rate,beta1=beta1,beta2=beta2,eps=eps) + + # Load data reader from prototext + data_reader_proto = lbann.lbann_pb2.LbannPB() + with open(data_reader_prototext, 'r') as f: + txtf.Merge(f.read(), data_reader_proto) + data_reader_proto = data_reader_proto.data_reader + + # Aliases for simplicity + SGD = lbann.BatchedIterativeOptimizer + RPE = lbann.RandomPairwiseExchange + + # Construct the local training algorithm + local_sgd = SGD("local sgd", num_iterations=10) + + # Construct the metalearning strategy. + meta_learning = RPE( + metric_strategies={'mse': RPE.MetricStrategy.LOWER_IS_BETTER}) + + # Setup vanilla trainer + trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size) + return model, opt, data_reader_proto, trainer + +# run the grid search using make_search_model. Options below. 
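# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# hyper.grid_search below sweeps the Cartesian product of the hyperparameter
# lists passed to it, calling make_search_model once per combination. A minimal
# stand-alone illustration of that expansion, using only the standard library
# and the values from the commented-out "larger search" example, is shown here;
# the actual job setup is handled inside lbann.contrib.hyperparameter.
import itertools

search_space = {
    'learning_rate': [0.00001],
    'beta1': [0.9, 0.99],
    'beta2': [0.9, 0.99],
    'eps': [1e-8],
    'intermed_fc_layers': [[960, 240], [480, 240]],
    'dropout_percent': [0.3, 0.5, 0.7],
}
for run_id, combo in enumerate(itertools.product(*search_space.values())):
    print('run{}: {}'.format(run_id, dict(zip(search_space, combo))))
# With the lists above this enumerates 1*2*2*1*2*3 = 24 trial configurations.
# ------------------------------------------------------------------------------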
+hyper.grid_search( + script, + make_search_model, + use_data_store=True, # must be True for images to scalers training + procs_per_trainer=1, + learning_rate=[0.00001], + beta1=[0.9], + beta2=[0.9], + eps=[1e-8], + intermed_fc_layers = [[960,240]], + activation = [lbann.Relu], + dropout_percent = [0.7]) + +# Sample syntax for a larger search: +# hyper.grid_search( +# script, +# make_search_model, +# use_data_store=False, +# procs_per_trainer=1, +# learning_rate=[0.00001], +# beta1=[0.9,0.99], +# beta2=[0.9,0.99], +# eps=[1e-8], +# intermed_fc_layers = [[960,240],[1920,960,480,240],[480,240]], +# activation = [lbann.Relu,lbann.Softmax,lbann.LeakyRelu], +# dropout_percent = [0.3, 0.5, 0.7]) diff --git a/applications/physics/HRRL/images_to_spectra_grid_search.py b/applications/physics/HRRL/images_to_spectra_grid_search.py new file mode 100644 index 00000000000..a2dc42f77f0 --- /dev/null +++ b/applications/physics/HRRL/images_to_spectra_grid_search.py @@ -0,0 +1,247 @@ +import lbann +import torch +import lbann.torch +import lbann.contrib.launcher +import lbann.contrib.args +import google.protobuf.text_format as txtf +import lbann.contrib.hyperparameter as hyper +import argparse +import sys +import os +from os.path import abspath, dirname, join + +print() + +# ============================================== +# Setup +# ============================================== + +# Debugging +torch._dynamo.config.verbose=True + +# Command-line arguments +desc = ('Construct and run a grid search on HRRL PROBIES image and spectra data. ') +parser = argparse.ArgumentParser(description=desc) +lbann.contrib.args.add_scheduler_arguments(parser) +parser.add_argument( + '--job-name', action='store', default='probiesNet', type=str, + help='scheduler job name (default: probiesNet)') +parser.add_argument( + '--mini-batch-size', action='store', default=32, type=int, + help='mini-batch size (default: 32)', metavar='NUM') +parser.add_argument( + '--num-epochs', action='store', default=100, type=int, + help='number of epochs (default: 100)', metavar='NUM') + +lbann.contrib.args.add_optimizer_arguments(parser) +args = parser.parse_args() + +# Make script with 2 nodes +script = lbann.launcher.make_batch_script(nodes=2, procs_per_node=4) + +# Run experiment +kwargs = lbann.contrib.args.get_scheduler_kwargs(args) + +def make_search_model( + learning_rate, + beta1, + beta2, + eps, + intermed_fc_layers, + activation, + dropout_percent, + num_labels=201): + + # ============================================== + # Make TRAIN data reader + # ============================================== + + cur_dir = dirname(abspath(__file__)) + + def construct_train_data_reader(): + import os.path + module_file = os.path.abspath(__file__) + module_name = os.path.splitext(os.path.basename(module_file))[0] + module_dir = os.path.dirname(module_file) + + # Base data reader message + message = lbann.reader_pb2.Reader() + + # Training set data reader + data_reader = message + data_reader.name = 'python' + data_reader.role = 'train' + data_reader.shuffle = True + data_reader.fraction_of_data_to_use = 1.0 + data_reader.validation_fraction = 0.0 + data_reader.python.module = 'data.spectra_probies_train_data_info' + data_reader.python.module_dir = module_dir + data_reader.python.sample_function = 'get_sample' + data_reader.python.num_samples_function = 'num_samples' + data_reader.python.sample_dims_function = 'sample_dims' + + return message + + # ============================================== + # Make VAL data reader + # 
============================================== + + + def construct_val_data_reader(): + """Construct Protobuf message for Python data reader. + + The Python data reader will import this Python file to access the + sample access functions. + + """ + import os.path + module_file = os.path.abspath(__file__) + module_name = os.path.splitext(os.path.basename(module_file))[0] + module_dir = os.path.dirname(module_file) + + # Base data reader message + message = lbann.reader_pb2.Reader() + + # Training set data reader + data_reader = message + data_reader.name = 'python' + data_reader.role = 'validation' + data_reader.shuffle = True + data_reader.fraction_of_data_to_use = 1.0 + data_reader.validation_fraction = 0 + data_reader.python.module = 'data.spectra_probies_val_data_info' + data_reader.python.module_dir = module_dir + data_reader.python.sample_function = 'get_sample' + data_reader.python.num_samples_function = 'num_samples' + data_reader.python.sample_dims_function = 'sample_dims' + + return message + + # ============================================== + # Make TESTING data reader + # ============================================== + + + def construct_test_data_reader(): + """Construct Protobuf message for Python data reader. + + The Python data reader will import this Python file to access the + sample access functions. + + """ + import os.path + module_file = os.path.abspath(__file__) + module_name = os.path.splitext(os.path.basename(module_file))[0] + module_dir = os.path.dirname(module_file) + + # Base data reader message + message = lbann.reader_pb2.Reader() + + # Training set data reader + data_reader = message + data_reader.name = 'python' + data_reader.role = 'test' + data_reader.shuffle = True + data_reader.fraction_of_data_to_use = 1.0 + data_reader.validation_fraction = 0 + data_reader.python.module = 'data.spectra_probies_test_data_info' + data_reader.python.module_dir = module_dir + data_reader.python.sample_function = 'get_sample' + data_reader.python.num_samples_function = 'num_samples' + data_reader.python.sample_dims_function = 'sample_dims' + + return message + + import models.probiesNetLBANN_grid_search as model + + images_and_spectra = lbann.Input(data_field='samples') + split_results = lbann.Slice(images_and_spectra, axis=0, slice_points=[0,90000,90201]) #should be between the images and spectra + + images = lbann.Identity(split_results) + responses = lbann.Identity(split_results) + + num_labels = 201 + + images = lbann.Reshape(images, dims=[1, 300, 300]) + + pred = model.PROBIESNetLBANN(num_labels, intermed_fc_layers, activation, dropout_percent)(images) + + + # ============================================== + # Metrics + # ============================================== + + # MSE loss between responses and preds + mse = lbann.MeanSquaredError([responses, pred]) + + layers = list(lbann.traverse_layer_graph([images, responses])) + + # Append metrics + metrics = [lbann.Metric(mse, name='mse')] + + callbacks = [lbann.CallbackPrint(), + lbann.CallbackTimer()] + # for printing the results. Sample syntax below for dumping + # multiple layer outputs. Layer names must be checked for the + # particular model in question. 
+ # lbann.CallbackDumpOutputs(layers='layer4 layer25 layer46 layer67 pred_out_instance1',execution_modes='test')] + + layers = list(lbann.traverse_layer_graph([images, responses])) + + model = lbann.Model(args.num_epochs, + layers=layers, + objective_function=mse, + metrics=metrics, + callbacks=callbacks) + + # Setup optimizer + opt = lbann.Adam(learn_rate=learning_rate,beta1=beta1,beta2=beta2,eps=eps) + + train_reader = construct_train_data_reader() + val_reader = construct_val_data_reader() + test_reader = construct_test_data_reader() + python_reader = lbann.reader_pb2.DataReader(reader=[train_reader, val_reader, test_reader]) + + # Aliases for simplicity + SGD = lbann.BatchedIterativeOptimizer + RPE = lbann.RandomPairwiseExchange + + # Construct the local training algorithm + local_sgd = SGD("local sgd", num_iterations=10) + + # Construct the metalearning strategy. + meta_learning = RPE( + metric_strategies={'mse': RPE.MetricStrategy.LOWER_IS_BETTER}) + + # Setup vanilla trainer + trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size) + return model, opt, python_reader, trainer + +# Run the grid search using make_search_model. Options below. +hyper.grid_search( + script, + make_search_model, + use_data_store=False, # must be False for images to spectra training + procs_per_trainer=1, + learning_rate=[0.00001], + beta1=[0.9], + beta2=[0.9], + eps=[1e-8], + intermed_fc_layers = [[960,240]], + activation = [lbann.Relu], + dropout_percent = [0.7]) + +# Sample syntax for a larger search: +# hyper.grid_search( +# script, +# make_search_model, +# use_data_store=False, +# procs_per_trainer=1, +# learning_rate=[0.00001], +# beta1=[0.9,0.99], +# beta2=[0.9,0.99], +# eps=[1e-8], +# intermed_fc_layers = [[960,240],[1920,960,480,240],[480,240]], +# activation = [lbann.Relu,lbann.Softmax,lbann.LeakyRelu], +# dropout_percent = [0.3, 0.5, 0.7]) + diff --git a/applications/physics/HRRL/make_image_spectra_numpy_file.py b/applications/physics/HRRL/make_image_spectra_numpy_file.py new file mode 100644 index 00000000000..e924b91ed97 --- /dev/null +++ b/applications/physics/HRRL/make_image_spectra_numpy_file.py @@ -0,0 +1,99 @@ +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ IMPORT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + +import h5py +import numpy as np +import os +from datetime import datetime +from sklearn import preprocessing as pp +from sklearn.preprocessing import FunctionTransformer +from sklearn import model_selection + +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ NOTES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + +# Creates a .npy file with the images and spectra from the HRRL data set stored +# here: /usr/WS2/hrrl/Data/PROBIES/SimBased/300x300/RawFiles. Splits the data into +# testing, training, and validation sets, and saves each of these sets into a .npy file +# in the current directory. +# +# Each sample in the .npy file is a 90201x1 vector, where there are 90000 image data +# points and 201 spectra data points, in that order. 
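# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# Given the 90201-element sample layout described above (90000 flattened image
# values followed by 201 spectrum bins), one sample can be split back apart with
# plain NumPy; this mirrors the slice_points=[0, 90000, 90201] used in
# images_to_spectra_grid_search.py. The sample below is synthetic, standing in
# for one row of a saved train/val/test array.
import numpy as np

sample = np.arange(90201, dtype=np.float32)   # stand-in for one flattened sample
image = sample[:90000].reshape(300, 300)      # 300x300 detector image
spectrum = sample[90000:]                     # 201-bin spectrum
assert image.shape == (300, 300) and spectrum.shape == (201,)
# ------------------------------------------------------------------------------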
+# + +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SETUP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + +folder = '/usr/WS2/hrrl/Data/PROBIES/SimBased/300x300/RawFiles' +files = os.listdir(folder) + +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LOAD DATA ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + +# loop over files; load; append +out_array = [] + +for i in range(len(files)): + if "h5base" in files[i]: # check that the name contains h5base + f = h5py.File(folder + '/' + files[i], 'r') + rids = f['RUN_ID'] + for tid in rids: + # check that we have the spectra + if (rids[tid]['Original_Spectrum'].shape[0] == 201): + out_array.append(np.append( + np.array(rids[tid]['Image']).flatten(), + rids[tid]['Original_Spectrum'] + )) + else: + warning('no spectrum info for' + str(tid)) + +out_array = np.array(out_array, dtype=object) + +values_raw = out_array[:,:-201] +labels_raw = out_array[:,-201:] + +print("raw value array shape", values_raw.shape) +print("raw labels array shape", labels_raw.shape) + +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SCALE DATA ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + +# log1p scale on spectra only (labels) +transformer = FunctionTransformer(np.log1p, validate=True) +transformer.fit_transform(labels_raw) + +# minmax scale on both +scaler = pp.MinMaxScaler() +values = scaler.fit_transform(np.array(values_raw)).astype(np.float32) +labels = scaler.fit_transform(np.array(labels_raw)).astype(np.float32) + +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SPLIT DATA ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + +X_train_val, X_test, y_train_val, y_test = model_selection.train_test_split(values, labels, test_size = 0.1) +X_train, X_val, y_train, y_val = model_selection.train_test_split(X_train_val, y_train_val, test_size = 0.1) + +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RESHAPE AND SAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + +# create the final format +train_append = np.append(X_train, y_train, axis=1) +val_append = np.append(X_val, y_val, axis=1) +test_append = np.append(X_test, y_test, axis=1) + +train_data = train_append.flatten() +val_data = val_append.flatten() +test_data = test_append.flatten() + +# save +now = str(datetime.now()) +np.save('train_data' + str(now), train_data) +np.save('val_data' + str(now), val_data) +np.save('test_data' + str(now), test_data) + +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OPTIONAL CHECKS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + +print("X_train only:", X_train.shape) +print("X_val only: ", X_val.shape) +print("X_test only: ", X_test.shape) +print("y_train only: ", y_train.shape) +print("y_val only: ", y_val.shape) +print("y_test only: ", y_test.shape) +print("train append shape", train_append.shape) +print("val append shape", val_append.shape) +print("test append shape", test_append.shape) + +print("Numpy files written") \ No newline at end of file diff --git a/applications/physics/HRRL/models/probiesNetLBANN_grid_search.py b/applications/physics/HRRL/models/probiesNetLBANN_grid_search.py new file mode 100644 index 00000000000..c656e4679e7 --- /dev/null +++ b/applications/physics/HRRL/models/probiesNetLBANN_grid_search.py @@ -0,0 +1,51 @@ +import lbann +import lbann.modules + +class PROBIESNetLBANN(lbann.modules.Module): + + global_count = 0 # Static counter, used for default names + + def __init__(self, output_size, intermed_fc_layers, activation, dropout_percent, name=None): + """Initialize PROBIESNet. + + Args: + output_size (int): Size of output tensor. + name (str, optional): Module name + (default: 'probiesnet_module'). 
+
+        """
+        PROBIESNetLBANN.global_count += 1
+        self.instance = 0
+        self.intermed_fc_layers = intermed_fc_layers
+        self.dropout_percent = dropout_percent
+        self.name = (name if name
+                     else 'probiesNet_module{0}'.format(PROBIESNetLBANN.global_count))
+        conv = lbann.modules.Convolution2dModule
+        fc = lbann.modules.FullyConnectedModule
+        self.conv1 = conv(36, 11, stride=4, activation=activation,
+                          name=self.name+'_conv1')
+        self.conv2 = conv(64, 5, padding=2, activation=activation,
+                          name=self.name+'_conv2')
+        for idx,layer in enumerate(intermed_fc_layers):
+            setattr(self, 'fc' + str(idx), fc(layer, activation=activation, name=self.name+'_fc'+str(idx)))
+        setattr(self, 'fc' + str((len(self.intermed_fc_layers))+1), fc(output_size, name='pred_out'))
+
+    def forward(self, x):
+        self.instance += 1
+
+        x = self.conv1(x)
+        x = lbann.Pooling(x, num_dims=2, has_vectors=False,
+                          pool_dims_i=2, pool_pads_i=0, pool_strides_i=2,
+                          pool_mode='max',
+                          name='{0}_pool1_instance{1}'.format(self.name,self.instance))
+        x = self.conv2(x)
+        x = lbann.Pooling(x, num_dims=2, has_vectors=False,
+                          pool_dims_i=2, pool_pads_i=0, pool_strides_i=2,
+                          pool_mode='max',
+                          name='{0}_pool2_instance{1}'.format(self.name,self.instance))
+        for idx, layer in enumerate(self.intermed_fc_layers):
+            x = getattr(self, 'fc' + str(idx))(x)
+            x = lbann.Dropout(x, keep_prob=self.dropout_percent,
+                              name='{0}_drop_search_{1}_instance{2}'.format(self.name,idx,self.instance))
+            x = lbann.Relu(x)
+        return getattr(self, 'fc' + str(len(self.intermed_fc_layers)+1))(x)
diff --git a/applications/physics/HRRL/models/probiesNet_HRRL_arch.py b/applications/physics/HRRL/models/probiesNet_HRRL_arch.py
new file mode 100644
index 00000000000..af8f69bb56f
--- /dev/null
+++ b/applications/physics/HRRL/models/probiesNet_HRRL_arch.py
@@ -0,0 +1,33 @@
+import torch.nn as nn
+
+class PROBIESNet(nn.Module):
+
+    global_count = 0 # Static counter, used for default names
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, stride=1, kernel_size=3, padding=0, dilation=1)
+        self.maxPool1 = nn.MaxPool2d(kernel_size=2,padding=0,dilation=1)
+        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, stride=1, kernel_size=3, padding=0, dilation=1)
+        self.avgPool1 = nn.AvgPool2d(kernel_size=2,stride=2,padding=1)
+        self.flatten1 = nn.Flatten(start_dim=1, end_dim=-1)
+        self.leakyReLu1 = nn.LeakyReLU()
+        self.dense1 = nn.Linear(in_features=350464,out_features=128)
+        self.leakyReLu2 = nn.LeakyReLU()
+        self.dense2 = nn.Linear(in_features=128,out_features=2048)
+        # double-check the out_features when making changes.
+ self.dense3 = nn.Linear(in_features=2048, out_features=5) + + + def forward(self, x): + x = self.conv1(x) + x = self.maxPool1(x) + x = self.conv2(x) + x = self.avgPool1(x) + x = self.flatten1(x) + x = self.leakyReLu1(x) + x = self.dense1(x) + x = self.leakyReLu2(x) + x = self.dense2(x) + x = self.dense3(x) + return x diff --git a/applications/physics/HRRL/read_in_pytorch_sample_model.py b/applications/physics/HRRL/read_in_pytorch_sample_model.py new file mode 100644 index 00000000000..24ed236b6ba --- /dev/null +++ b/applications/physics/HRRL/read_in_pytorch_sample_model.py @@ -0,0 +1,147 @@ +import lbann +import torch +import lbann.torch +import lbann.contrib.launcher +import lbann.contrib.args +import google.protobuf.text_format as txtf +import lbann.contrib.hyperparameter as hyper +import argparse +import sys +import os +from os.path import abspath, dirname, join + +# ============================================== +# Setup +# ============================================== + +# Debugging +torch._dynamo.config.verbose=True + +import models.probiesNet_HRRL_arch as model + +# Command-line arguments +desc = ('Reads in the HRRL model and runs it on HRRL PROBIES data. ') +parser = argparse.ArgumentParser(description=desc) +lbann.contrib.args.add_scheduler_arguments(parser) +parser.add_argument( + '--job-name', action='store', default='probiesNet', type=str, + help='scheduler job name (default: probiesNet)') +parser.add_argument( + '--mini-batch-size', action='store', default=32, type=int, + help='mini-batch size (default: 32)', metavar='NUM') +parser.add_argument( + '--reader-prototext', action='store', default='probies_v2.prototext', type=str, + help='data to use (default: probies_v2.prototext, 20K data)') +parser.add_argument( + '--num-epochs', action='store', default=100, type=int, + help='number of epochs (default: 100)', metavar='NUM') + +# Add reader prototext +lbann.contrib.args.add_optimizer_arguments(parser) +args = parser.parse_args() + +# Default data reader +cur_dir = dirname(abspath(__file__)) +data_reader_prototext = join(cur_dir, + 'data', + args.reader_prototext) + +print("DATA READER ", data_reader_prototext) + +script = lbann.launcher.make_batch_script(nodes=2, procs_per_node=4) + +# Run experiment +kwargs = lbann.contrib.args.get_scheduler_kwargs(args) + +images = lbann.Input(data_field='samples') +num_labels = 5 +images = lbann.Reshape(images, dims=[1, 300, 300]) +responses = lbann.Input(data_field='responses') #labels + +# Initialize +mod = model.PROBIESNet() + +graph = lbann.torch.compile(mod, x=torch.rand(64,1,300,300)) # batch size 64 + +images = graph[0] +pred = graph[-1] + +# Lbann slice layer +pred_slice = lbann.Slice(pred, axis=0, slice_points=[0,1,2,3,4,5]) +response_slice = lbann.Slice(responses, axis=0, slice_points=[0,1,2,3,4,5]) + +# ============================================== +# Metrics +# ============================================== + +# MSE loss between responses and preds +mse = lbann.MeanSquaredError([responses, pred]) + +# Responses +epmax_response = lbann.Identity(response_slice) +etot_response = lbann.Identity(response_slice) +n_response = lbann.Identity(response_slice) +t_response = lbann.Identity(response_slice) +alpha_response = lbann.Identity(response_slice) + +# Preds +epmax_pred = lbann.Identity(pred_slice) +etot_pred = lbann.Identity(pred_slice) +n_pred = lbann.Identity(pred_slice) +t_pred = lbann.Identity(pred_slice) +alpha_pred = lbann.Identity(pred_slice) + +# MSEs +mse_epmax = lbann.MeanSquaredError([epmax_response, epmax_pred]) 
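# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# The per-quantity metrics built here and below amount to slicing the 5-element
# response and prediction vectors (Epmax, Etot, N, T, alpha) and taking the
# squared error of each entry. A plain-NumPy sketch with made-up values:
import numpy as np

response = np.array([1.2, 0.8, 0.5, 0.3, 0.1])   # [epmax, etot, n, t, alpha]
pred     = np.array([1.0, 0.9, 0.4, 0.3, 0.2])
per_quantity_se = (response - pred) ** 2         # one value per scalar quantity
overall_mse = per_quantity_se.mean()             # analogue of the 'mse' metric
# ------------------------------------------------------------------------------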
+mse_etot = lbann.MeanSquaredError([etot_response, etot_pred]) +mse_n = lbann.MeanSquaredError([n_response, n_pred]) +mse_t = lbann.MeanSquaredError([t_response, t_pred]) +mse_alpha = lbann.MeanSquaredError([alpha_response, alpha_pred]) + +layers = list(lbann.traverse_layer_graph([images, responses])) + +# Append Metrics +metrics = [lbann.Metric(mse, name='mse')] +metrics.append(lbann.Metric(mse_epmax, name='mse_epmax')) +metrics.append(lbann.Metric(mse_etot, name='mse_etot')) +metrics.append(lbann.Metric(mse_n, name='mse_n')) +metrics.append(lbann.Metric(mse_t, name='mse_t')) +metrics.append(lbann.Metric(mse_alpha, name='mse_alpha')) + +callbacks = [lbann.CallbackPrint(), + lbann.CallbackTimer()] + +layers = list(lbann.traverse_layer_graph([images, responses])) + +model = lbann.Model(args.num_epochs, + layers=layers, + objective_function=mse, + metrics=metrics, + callbacks=callbacks) + +# Setup optimizer +opt = lbann.Adam(learn_rate=0.0002,beta1=0.9,beta2=0.99,eps=1e-8) + +# Load data reader from prototext +data_reader_proto = lbann.lbann_pb2.LbannPB() +with open(data_reader_prototext, 'r') as f: + txtf.Merge(f.read(), data_reader_proto) +data_reader_proto = data_reader_proto.data_reader + +# Aliases for simplicity +SGD = lbann.BatchedIterativeOptimizer +RPE = lbann.RandomPairwiseExchange + +# Construct the local training algorithm +local_sgd = SGD("local sgd", num_iterations=10) + +# Construct the metalearning strategy +meta_learning = RPE( + metric_strategies={'mse': RPE.MetricStrategy.LOWER_IS_BETTER}) + +# Setup vanilla trainer +trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size) + +lbann.contrib.launcher.run(trainer, model, data_reader_proto, opt, procs_per_trainer=1, + lbann_args=" --use_data_store --preload_data_store", job_name=args.job_name, + binary_protobuf=True, **kwargs) diff --git a/python/lbann/contrib/hyperparameter.py b/python/lbann/contrib/hyperparameter.py index 394d172b5aa..270cc4bf0d8 100644 --- a/python/lbann/contrib/hyperparameter.py +++ b/python/lbann/contrib/hyperparameter.py @@ -11,6 +11,7 @@ def grid_search( make_experiment, procs_per_trainer=1, hyperparameters_file=None, + use_data_store=False, **kwargs, ): """Run LBANN with exhaustive grid search over hyperparameter values @@ -85,6 +86,10 @@ def grid_search( f'--procs_per_trainer={procs_per_trainer}', '--generate_multi_proto', f'--prototext={work_dir}/run{run_id}.trainer0'] + if use_data_store: + command = command + [ + f'--use_data_store', + f'--preload_data_store'] script.add_parallel_command( command, nodes=num_nodes, procs_per_node=procs_per_node) @@ -106,6 +111,10 @@ def grid_search( f'--procs_per_trainer={procs_per_trainer}', '--generate_multi_proto', f'--prototext={work_dir}/run{run_id}.trainer0'] + if use_data_store: + command = command + [ + f'--use_data_store', + f'--preload_data_store'] script.add_parallel_command( command, nodes=((trainer_id+1)*procs_per_trainer) // procs_per_node,
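# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# A quick way to confirm that dense1's in_features (350464) in
# models/probiesNet_HRRL_arch.py matches a 1x300x300 input, e.g. before changing
# kernel or pooling sizes. This stand-alone check rebuilds only the
# feature-extraction layers with the same hyperparameters as the model above.
import torch
import torch.nn as nn

features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3),                  # 300 -> 298
    nn.MaxPool2d(kernel_size=2),                      # 298 -> 149
    nn.Conv2d(32, 64, kernel_size=3),                 # 149 -> 147
    nn.AvgPool2d(kernel_size=2, stride=2, padding=1), # 147 -> 74
    nn.Flatten(),
)
with torch.no_grad():
    n_features = features(torch.zeros(1, 1, 300, 300)).shape[1]
print(n_features)  # 64 * 74 * 74 = 350464
# ------------------------------------------------------------------------------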