Pruning: improve custom training loop API #271

@alanchiao

Description

The recommended path for pruning with a custom training loop is not as simple as it could be.

```python
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

pruned_model = setup_pruned_model()  # helper that returns a prune_low_magnitude-wrapped model.
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
batch_size = 20
unused_arg = -1  # the pruning callbacks ignore their batch argument.

# None of this is standard training boilerplate; it is all pruning-specific setup.
pruned_model.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(pruned_model)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)  # optional TensorBoard logging.
log_callback.set_model(pruned_model)

step_callback.on_train_begin()
for _ in range(3):  # only one batch per epoch, given batch_size = 20 and the input shape.
  step_callback.on_train_batch_begin(batch=unused_arg)
  inp = np.reshape(x_train, [batch_size, 10])  # original shape: from [10].
  with tf.GradientTape() as tape:
    logits = pruned_model(inp, training=True)
    loss_value = loss(y_train, logits)
  grads = tape.gradient(loss_value, pruned_model.trainable_variables)
  optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))
  step_callback.on_epoch_end(batch=unused_arg)
  log_callback.on_epoch_end(batch=unused_arg)
...
```
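For contrast, the built-in training path needs none of this wiring; the same callbacks are simply passed to `Model.fit`. A minimal sketch following the tfmot pruning guide, reusing `pruned_model`, `log_dir`, `x_train`, and `y_train` from above:

```python
import tensorflow_model_optimization as tfmot

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir),  # optional TensorBoard logging.
]

pruned_model.compile(optimizer='adam',
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
pruned_model.fit(x_train, y_train,
                 batch_size=20, epochs=3,
                 callbacks=callbacks)
```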

The `set_model` calls and the manual `pruned_model.optimizer = optimizer` assignment are unusual and easy to miss.
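One possible direction, sketched here purely for illustration: bundle the wiring into a single object so the custom loop only touches one helper. `PruningLoopHelper` is hypothetical, not an existing tfmot API; it just wraps the calls shown above.

```python
import tensorflow_model_optimization as tfmot

class PruningLoopHelper:
  """Hypothetical wrapper (not a tfmot API) that hides the pruning wiring."""

  def __init__(self, model, optimizer, log_dir=None):
    # Today's required setup, hidden from the training loop:
    model.optimizer = optimizer
    self._step_callback = tfmot.sparsity.keras.UpdatePruningStep()
    self._step_callback.set_model(model)
    self._epoch_callbacks = [self._step_callback]
    if log_dir is not None:
      log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
      log_callback.set_model(model)
      self._epoch_callbacks.append(log_callback)
    self._step_callback.on_train_begin()

  def on_batch_begin(self):
    self._step_callback.on_train_batch_begin(batch=-1)  # batch arg is unused.

  def on_epoch_end(self):
    for callback in self._epoch_callbacks:
      callback.on_epoch_end(batch=-1)  # batch arg is unused.
```

With something like this, the loop body reduces to `helper.on_batch_begin()` before each gradient step and `helper.on_epoch_end()` after each epoch.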

Metadata

Labels

feature request · priority:low (Low priority when applied. Intentionally open with no assignee or contributors welcome label.) · technique:pruning (Regarding tfmot.sparsity.keras APIs and docs)
