Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ py_strict_binary(
srcs = ["mnist_e2e_sparsity2x4.py"],
deps = [
# absl:app dep1,
# absl/flags dep1,
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:compat",
"//tensorflow_model_optimization/python/core/keras:test_utils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
"""Train a simple convnet on the MNIST dataset."""
from __future__ import print_function

import datetime
import os
import tempfile

from absl import app as absl_app
from absl import flags
import tensorflow as tf
Expand All @@ -35,9 +39,8 @@
num_classes = 10
epochs = 12

flags.DEFINE_string('output_dir', '/tmp/mnist_train/',
'Output directory to hold tensorboard events')

flags.DEFINE_string('output_dir', None,
'Output directory to hold tensorboard events and models. If None, a temporary directory is used.')

def build_sequential_model(input_shape):
return keras.Sequential([
Expand Down Expand Up @@ -94,7 +97,7 @@ def build_layerwise_model(input_shape, **pruning_params):
])


def train_and_save(models, x_train, y_train, x_test, y_test):
def train_and_save(models, x_train, y_train, x_test, y_test, output_dir):
for model in models:
model.compile(
loss=keras.losses.categorical_crossentropy,
Expand All @@ -109,7 +112,7 @@ def train_and_save(models, x_train, y_train, x_test, y_test):
# step. Also add a callback to add pruning summaries to tensorboard
callbacks = [
pruning_callbacks.UpdatePruningStep(),
pruning_callbacks.PruningSummaries(log_dir=FLAGS.output_dir)
pruning_callbacks.PruningSummaries(log_dir=output_dir)
]

model.fit(
Expand All @@ -125,7 +128,7 @@ def train_and_save(models, x_train, y_train, x_test, y_test):
print('Test accuracy:', score[1])

# Export and import the model. Check that accuracy persists.
saved_model_dir = '/tmp/saved_model'
saved_model_dir = os.path.join(output_dir, 'saved_model')
print('Saving model to: ', saved_model_dir)
keras.models.save_model(model, saved_model_dir, save_format='tf')
print('Loading model from: ', saved_model_dir)
Expand Down Expand Up @@ -182,8 +185,13 @@ def main(unused_argv):
functional_model = prune.prune_low_magnitude(
functional_model, **pruning_params)

if FLAGS.output_dir and not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
output_dir = tempfile.mkdtemp(dir=FLAGS.output_dir, prefix=datetime.datetime.now().strftime("tmp_%Y%m%d%H%M_"))
print('All models and logs will be saved to: {}'.format(output_dir))

models = [layerwise_model, sequential_model, functional_model]
train_and_save(models, x_train, y_train, x_test, y_test)
train_and_save(models, x_train, y_train, x_test, y_test, output_dir=output_dir)


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
"""Train a simple convnet on the MNIST dataset."""
from __future__ import print_function

import datetime
import os
import tempfile

from absl import app as absl_app
from absl import flags
import tensorflow as tf
Expand All @@ -37,7 +41,7 @@
epochs = 1

flags.DEFINE_float('sparsity', '0.0', 'Target sparsity level.')

flags.DEFINE_string('output_dir', None, 'Output directory for models and logs. If not set, a temporary directory is used.')

def build_layerwise_model(input_shape, **pruning_params):
return keras.Sequential([
Expand All @@ -60,7 +64,7 @@ def build_layerwise_model(input_shape, **pruning_params):
])


def train(model, x_train, y_train, x_test, y_test):
def train(model, x_train, y_train, x_test, y_test, output_dir):
model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer='adam',
Expand All @@ -74,7 +78,7 @@ def train(model, x_train, y_train, x_test, y_test):
# step. Also add a callback to add pruning summaries to tensorboard
callbacks = [
pruning_callbacks.UpdatePruningStep(),
pruning_callbacks.PruningSummaries(log_dir='/tmp/logs')
pruning_callbacks.PruningSummaries(log_dir=output_dir)
]

model.fit(
Expand All @@ -101,6 +105,11 @@ def main(unused_argv):
x_test,
y_test), input_shape = keras_test_utils.get_preprocessed_mnist_data()

if FLAGS.output_dir and not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
temp_dir = tempfile.mkdtemp(dir=FLAGS.output_dir, prefix=datetime.datetime.now().strftime("tmp_%Y%m%d%H%M_"))
print('All models and logs will be saved to: {}'.format(temp_dir))

##############################################################################
# Train and convert a model with 2x2 block config. There's no kernel in tflite
# supporting this block configuration, so the sparse tensor is densified and
Expand All @@ -113,13 +122,13 @@ def main(unused_argv):
}

model = build_layerwise_model(input_shape, **pruning_params)
model = train(model, x_train, y_train, x_test, y_test)
model = train(model, x_train, y_train, x_test, y_test, output_dir=temp_dir)

converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Get a dense model as baseline
tflite_model_dense = converter.convert()
tflite_model_path = '/tmp/dense_mnist.tflite'
tflite_model_path = os.path.join(temp_dir, 'dense_mnist.tflite')
with open(tflite_model_path, 'wb') as f:
f.write(tflite_model_dense)

Expand All @@ -131,7 +140,7 @@ def main(unused_argv):
# Check the model is compressed
print('Compression ratio: ', len(tflite_model) / len(tflite_model_dense))

tflite_model_path = '/tmp/sparse_mnist_%s_2x2.tflite' % FLAGS.sparsity
tflite_model_path = os.path.join(temp_dir, 'sparse_mnist_%s_2x2.tflite' % FLAGS.sparsity)
with open(tflite_model_path, 'wb') as f:
f.write(tflite_model)

Expand All @@ -152,7 +161,7 @@ def main(unused_argv):
}

model = build_layerwise_model(input_shape, **pruning_params)
model = train(model, x_train, y_train, x_test, y_test)
model = train(model, x_train, y_train, x_test, y_test, output_dir=temp_dir)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = {tf.lite.Optimize.EXPERIMENTAL_SPARSITY}
Expand All @@ -161,7 +170,7 @@ def main(unused_argv):
# Check the model is compressed
print('Compression ratio: ', len(tflite_model) / len(tflite_model_dense))

tflite_model_path = '/tmp/sparse_mnist_%s_1x4.tflite' % FLAGS.sparsity
tflite_model_path = os.path.join(temp_dir, 'sparse_mnist_%s_1x4.tflite' % FLAGS.sparsity)
with open(tflite_model_path, 'wb') as f:
f.write(tflite_model)

Expand All @@ -181,7 +190,7 @@ def main(unused_argv):
}

model = build_layerwise_model(input_shape, **pruning_params)
model = train(model, x_train, y_train, x_test, y_test)
model = train(model, x_train, y_train, x_test, y_test, output_dir=temp_dir)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = {
Expand All @@ -192,7 +201,7 @@ def main(unused_argv):
# Check the model is compressed
print('Compression ratio: ', len(tflite_model) / len(tflite_model_dense))

tflite_model_path = '/tmp/sparse_mnist_%s_1x16.tflite' % FLAGS.sparsity
tflite_model_path = os.path.join(temp_dir, 'sparse_mnist_%s_1x16.tflite' % FLAGS.sparsity)
with open(tflite_model_path, 'wb') as f:
f.write(tflite_model)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
"""
from __future__ import print_function

import datetime
import os
import tempfile

from absl import app as absl_app
from absl import flags
import tensorflow as tf

from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
Expand All @@ -40,6 +45,9 @@
num_classes = 10
epochs = 1

FLAGS = flags.FLAGS
flags.DEFINE_string('output_dir', None, 'Output directory for models and logs. If not set, a temporary directory is used.')

PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense)


Expand Down Expand Up @@ -77,7 +85,7 @@ def build_layerwise_model(input_shape, **pruning_params):
])


def train(model, x_train, y_train, x_test, y_test):
def train(model, x_train, y_train, x_test, y_test, output_dir):
model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer='adam',
Expand All @@ -92,7 +100,7 @@ def train(model, x_train, y_train, x_test, y_test):
# step. Also add a callback to add pruning summaries to tensorboard
callbacks = [
pruning_callbacks.UpdatePruningStep(),
pruning_callbacks.PruningSummaries(log_dir='/tmp/logs')
pruning_callbacks.PruningSummaries(log_dir=output_dir)
]

model.fit(
Expand Down Expand Up @@ -131,14 +139,20 @@ def main(unused_argv):
'sparsity_m_by_n': (2, 4),
}

if FLAGS.output_dir and not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
temp_dir = tempfile.mkdtemp(dir=FLAGS.output_dir, prefix=datetime.datetime.now().strftime("tmp_%Y%m%d%H%M_"))
print('All models and logs will be saved to: {}'.format(temp_dir))

model = build_layerwise_model(input_shape, **pruning_params)
pruned_model = train(model, x_train, y_train, x_test, y_test)
pruned_model = train(model, x_train, y_train, x_test, y_test, output_dir=temp_dir)

# Write a model that has been pruned with 2x4 sparsity.
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
tflite_model = converter.convert()

tflite_model_path = '/tmp/mnist_2x4.tflite'
tflite_model_dir = temp_dir
tflite_model_path = os.path.join(tflite_model_dir, 'mnist_2x4.tflite')
print('model is saved to {}'.format(tflite_model_path))
with open(tflite_model_path, 'wb') as f:
f.write(tflite_model)
Expand Down
Loading
Loading