diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD index b6d0aa96..af2de1f9 100644 --- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD +++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD @@ -52,6 +52,7 @@ py_strict_binary( srcs = ["mnist_e2e_sparsity2x4.py"], deps = [ # absl:app dep1, + # absl/flags dep1, # tensorflow dep1, "//tensorflow_model_optimization/python/core/keras:compat", "//tensorflow_model_optimization/python/core/keras:test_utils", diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py index eb476863..a5a69c48 100644 --- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py +++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py @@ -16,6 +16,10 @@ """Train a simple convnet on the MNIST dataset.""" from __future__ import print_function +import datetime +import os +import tempfile + from absl import app as absl_app from absl import flags import tensorflow as tf @@ -35,8 +39,12 @@ num_classes = 10 epochs = 12 -flags.DEFINE_string('output_dir', '/tmp/mnist_train/', - 'Output directory to hold tensorboard events') +flags.DEFINE_string( + 'output_dir', + None, + 'Output directory to hold tensorboard events and models. If None, a' + ' temporary directory is used.', +) def build_sequential_model(input_shape): @@ -94,7 +102,7 @@ def build_layerwise_model(input_shape, **pruning_params): ]) -def train_and_save(models, x_train, y_train, x_test, y_test): +def train_and_save(models, x_train, y_train, x_test, y_test, output_dir): for model in models: model.compile( loss=keras.losses.categorical_crossentropy, @@ -109,7 +117,7 @@ def train_and_save(models, x_train, y_train, x_test, y_test): # step. Also add a callback to add pruning summaries to tensorboard callbacks = [ pruning_callbacks.UpdatePruningStep(), - pruning_callbacks.PruningSummaries(log_dir=FLAGS.output_dir) + pruning_callbacks.PruningSummaries(log_dir=output_dir) ] model.fit( @@ -125,7 +133,7 @@ def train_and_save(models, x_train, y_train, x_test, y_test): print('Test accuracy:', score[1]) # Export and import the model. Check that accuracy persists. - saved_model_dir = '/tmp/saved_model' + saved_model_dir = os.path.join(output_dir, 'saved_model') print('Saving model to: ', saved_model_dir) keras.models.save_model(model, saved_model_dir, save_format='tf') print('Loading model from: ', saved_model_dir) @@ -182,8 +190,18 @@ def main(unused_argv): functional_model = prune.prune_low_magnitude( functional_model, **pruning_params) + if FLAGS.output_dir and not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + output_dir = tempfile.mkdtemp( + dir=FLAGS.output_dir, + prefix=datetime.datetime.now().strftime('tmp_%Y%m%d%H%M_'), + ) + print('All models and logs will be saved to: {}'.format(output_dir)) + models = [layerwise_model, sequential_model, functional_model] - train_and_save(models, x_train, y_train, x_test, y_test) + train_and_save( + models, x_train, y_train, x_test, y_test, output_dir=output_dir + ) if __name__ == '__main__': diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py index adfb275d..7baea71b 100644 --- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py +++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py @@ -16,6 +16,10 @@ """Train a simple convnet on the MNIST dataset.""" from __future__ import print_function +import datetime +import os +import tempfile + from absl import app as absl_app from absl import flags import tensorflow as tf @@ -37,6 +41,12 @@ epochs = 1 flags.DEFINE_float('sparsity', '0.0', 'Target sparsity level.') +flags.DEFINE_string( + 'output_dir', + None, + 'Output directory for models and logs. If not set, a temporary directory' + ' is used.', +) def build_layerwise_model(input_shape, **pruning_params): @@ -60,7 +70,7 @@ def build_layerwise_model(input_shape, **pruning_params): ]) -def train(model, x_train, y_train, x_test, y_test): +def train(model, x_train, y_train, x_test, y_test, output_dir): model.compile( loss=keras.losses.categorical_crossentropy, optimizer='adam', @@ -74,7 +84,7 @@ def train(model, x_train, y_train, x_test, y_test): # step. Also add a callback to add pruning summaries to tensorboard callbacks = [ pruning_callbacks.UpdatePruningStep(), - pruning_callbacks.PruningSummaries(log_dir='/tmp/logs') + pruning_callbacks.PruningSummaries(log_dir=output_dir) ] model.fit( @@ -101,6 +111,14 @@ def main(unused_argv): x_test, y_test), input_shape = keras_test_utils.get_preprocessed_mnist_data() + if FLAGS.output_dir and not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + temp_dir = tempfile.mkdtemp( + dir=FLAGS.output_dir, + prefix=datetime.datetime.now().strftime('tmp_%Y%m%d%H%M_'), + ) + print('All models and logs will be saved to: {}'.format(temp_dir)) + ############################################################################## # Train and convert a model with 2x2 block config. There's no kernel in tflite # supporting this block configuration, so the sparse tensor is densified and @@ -113,13 +131,13 @@ def main(unused_argv): } model = build_layerwise_model(input_shape, **pruning_params) - model = train(model, x_train, y_train, x_test, y_test) + model = train(model, x_train, y_train, x_test, y_test, output_dir=temp_dir) converter = tf.lite.TFLiteConverter.from_keras_model(model) # Get a dense model as baseline tflite_model_dense = converter.convert() - tflite_model_path = '/tmp/dense_mnist.tflite' + tflite_model_path = os.path.join(temp_dir, 'dense_mnist.tflite') with open(tflite_model_path, 'wb') as f: f.write(tflite_model_dense) @@ -131,7 +149,9 @@ def main(unused_argv): # Check the model is compressed print('Compression ratio: ', len(tflite_model) / len(tflite_model_dense)) - tflite_model_path = '/tmp/sparse_mnist_%s_2x2.tflite' % FLAGS.sparsity + tflite_model_path = os.path.join( + temp_dir, 'sparse_mnist_%s_2x2.tflite' % FLAGS.sparsity + ) with open(tflite_model_path, 'wb') as f: f.write(tflite_model) @@ -152,7 +172,7 @@ def main(unused_argv): } model = build_layerwise_model(input_shape, **pruning_params) - model = train(model, x_train, y_train, x_test, y_test) + model = train(model, x_train, y_train, x_test, y_test, output_dir=temp_dir) converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = {tf.lite.Optimize.EXPERIMENTAL_SPARSITY} @@ -161,7 +181,9 @@ def main(unused_argv): # Check the model is compressed print('Compression ratio: ', len(tflite_model) / len(tflite_model_dense)) - tflite_model_path = '/tmp/sparse_mnist_%s_1x4.tflite' % FLAGS.sparsity + tflite_model_path = os.path.join( + temp_dir, 'sparse_mnist_%s_1x4.tflite' % FLAGS.sparsity + ) with open(tflite_model_path, 'wb') as f: f.write(tflite_model) @@ -181,7 +203,7 @@ def main(unused_argv): } model = build_layerwise_model(input_shape, **pruning_params) - model = train(model, x_train, y_train, x_test, y_test) + model = train(model, x_train, y_train, x_test, y_test, output_dir=temp_dir) converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = { @@ -192,7 +214,9 @@ def main(unused_argv): # Check the model is compressed print('Compression ratio: ', len(tflite_model) / len(tflite_model_dense)) - tflite_model_path = '/tmp/sparse_mnist_%s_1x16.tflite' % FLAGS.sparsity + tflite_model_path = os.path.join( + temp_dir, 'sparse_mnist_%s_1x16.tflite' % FLAGS.sparsity + ) with open(tflite_model_path, 'wb') as f: f.write(tflite_model) diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e_sparsity2x4.py b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e_sparsity2x4.py index 0520978f..5b232a32 100644 --- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e_sparsity2x4.py +++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e_sparsity2x4.py @@ -19,7 +19,12 @@ """ from __future__ import print_function +import datetime +import os +import tempfile + from absl import app as absl_app +from absl import flags import tensorflow as tf from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils @@ -40,6 +45,14 @@ num_classes = 10 epochs = 1 +FLAGS = flags.FLAGS +flags.DEFINE_string( + 'output_dir', + None, + 'Output directory for models and logs. If not set, a temporary directory' + ' is used.', +) + PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense) @@ -77,7 +90,7 @@ def build_layerwise_model(input_shape, **pruning_params): ]) -def train(model, x_train, y_train, x_test, y_test): +def train(model, x_train, y_train, x_test, y_test, output_dir): model.compile( loss=keras.losses.categorical_crossentropy, optimizer='adam', @@ -92,7 +105,7 @@ def train(model, x_train, y_train, x_test, y_test): # step. Also add a callback to add pruning summaries to tensorboard callbacks = [ pruning_callbacks.UpdatePruningStep(), - pruning_callbacks.PruningSummaries(log_dir='/tmp/logs') + pruning_callbacks.PruningSummaries(log_dir=output_dir) ] model.fit( @@ -131,14 +144,25 @@ def main(unused_argv): 'sparsity_m_by_n': (2, 4), } + if FLAGS.output_dir and not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + temp_dir = tempfile.mkdtemp( + dir=FLAGS.output_dir, + prefix=datetime.datetime.now().strftime('tmp_%Y%m%d%H%M_'), + ) + print('All models and logs will be saved to: {}'.format(temp_dir)) + model = build_layerwise_model(input_shape, **pruning_params) - pruned_model = train(model, x_train, y_train, x_test, y_test) + pruned_model = train( + model, x_train, y_train, x_test, y_test, output_dir=temp_dir + ) # Write a model that has been pruned with 2x4 sparsity. converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model) tflite_model = converter.convert() - tflite_model_path = '/tmp/mnist_2x4.tflite' + tflite_model_dir = temp_dir + tflite_model_path = os.path.join(tflite_model_dir, 'mnist_2x4.tflite') print('model is saved to {}'.format(tflite_model_path)) with open(tflite_model_path, 'wb') as f: f.write(tflite_model) diff --git a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py index fefd4a8e..c8174875 100644 --- a/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py +++ b/tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py @@ -16,6 +16,12 @@ """Train a simple model with MultiHeadAttention layer on MNIST dataset and prune it. """ +import datetime +import os +import tempfile + +from absl import app as absl_app +from absl import flags import tensorflow as tf from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils @@ -26,76 +32,97 @@ from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper - -tf.random.set_seed(42) - -ConstantSparsity = pruning_schedule.ConstantSparsity - -# Load MNIST dataset -mnist = keras.datasets.mnist -(train_images, train_labels), (test_images, test_labels) = mnist.load_data() - -# Normalize the input image so that each pixel value is between 0 to 1. -train_images = train_images / 255.0 -test_images = test_images / 255.0 - -# define model -input = keras.layers.Input(shape=(28, 28)) -x = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name='mha')( - query=input, value=input -) -x = keras.layers.Flatten()(x) -out = keras.layers.Dense(10)(x) -model = keras.Model(inputs=input, outputs=out) - -# Train the digit classification model -model.compile( - optimizer='adam', - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - metrics=['accuracy'], -) - -model.fit( - train_images, train_labels, epochs=10, validation_split=0.1, +FLAGS = flags.FLAGS +flags.DEFINE_string( + 'output_dir', + None, + 'Output directory for models and logs. If not set, a temporary directory' + ' is used.', ) -score = model.evaluate(test_images, test_labels, verbose=0) -print('Model test loss:', score[0]) -print('Model test accuracy:', score[1]) - -# Define parameters for pruning - -batch_size = 128 -epochs = 3 -validation_split = 0.1 # 10% of training set will be used for validation set. - -callbacks = [ - pruning_callbacks.UpdatePruningStep(), - pruning_callbacks.PruningSummaries(log_dir='/tmp/logs') -] - -pruning_params = { - 'pruning_schedule': ConstantSparsity(0.75, begin_step=2000, frequency=100) -} - -model_for_pruning = prune.prune_low_magnitude(model, **pruning_params) - -# `prune_low_magnitude` requires a recompile. -model_for_pruning.compile( - optimizer='adam', - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - metrics=['accuracy'], -) - -model_for_pruning.fit( - train_images, - train_labels, - batch_size=batch_size, - epochs=epochs, - callbacks=callbacks, - validation_split=validation_split, -) -score = model_for_pruning.evaluate(test_images, test_labels, verbose=0) -print('Pruned model test loss:', score[0]) -print('Pruned model test accuracy:', score[1]) +def main(unused_argv): + tf.random.set_seed(42) + + constant_sparsity = pruning_schedule.ConstantSparsity + + # Load MNIST dataset + mnist = keras.datasets.mnist + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + + # Normalize the input image so that each pixel value is between 0 to 1. + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + # define model + input = keras.layers.Input(shape=(28, 28)) + x = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name='mha')( + query=input, value=input + ) + x = keras.layers.Flatten()(x) + out = keras.layers.Dense(10)(x) + model = keras.Model(inputs=input, outputs=out) + + # Train the digit classification model + model.compile( + optimizer='adam', + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy'], + ) + + model.fit( + train_images, train_labels, epochs=10, validation_split=0.1, + ) + + score = model.evaluate(test_images, test_labels, verbose=0) + print('Model test loss:', score[0]) + print('Model test accuracy:', score[1]) + + # Define parameters for pruning + + batch_size = 128 + epochs = 3 + validation_split = 0.1 # 10% of training set will be used for validation set. + + if FLAGS.output_dir and not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + output_dir = tempfile.mkdtemp( + dir=FLAGS.output_dir, + prefix=datetime.datetime.now().strftime('tmp_%Y%m%d%H%M_'), + ) + print('All models and logs will be saved to: {}'.format(output_dir)) + callbacks = [ + pruning_callbacks.UpdatePruningStep(), + pruning_callbacks.PruningSummaries(log_dir=output_dir) + ] + + pruning_params = { + 'pruning_schedule': constant_sparsity( + 0.75, begin_step=2000, frequency=100 + ) + } + + model_for_pruning = prune.prune_low_magnitude(model, **pruning_params) + + # `prune_low_magnitude` requires a recompile. + model_for_pruning.compile( + optimizer='adam', + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy'], + ) + + model_for_pruning.fit( + train_images, + train_labels, + batch_size=batch_size, + epochs=epochs, + callbacks=callbacks, + validation_split=validation_split, + ) + + score = model_for_pruning.evaluate(test_images, test_labels, verbose=0) + print('Pruned model test loss:', score[0]) + print('Pruned model test accuracy:', score[1]) + +if __name__ == '__main__': + absl_app.run(main)