Labels: feature request, priority:low, technique:pruning
Description
The recommended path for pruning with a custom training loop is not as simple as it could be.
```python
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# setup_pruned_model(), x_train and y_train come from the surrounding test code.
pruned_model = setup_pruned_model()
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
BATCH_SIZE = 20
unused_arg = -1  # placeholder for hook arguments the pruning callbacks do not use

# Everything below is pruning-specific, not standard boilerplate.
pruned_model.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(pruned_model)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)  # optional TensorBoard logging.
log_callback.set_model(pruned_model)

step_callback.on_train_begin()
for epoch in range(3):
    log_callback.on_epoch_begin(epoch=epoch)  # log sparsity summaries for TensorBoard
    # Only one batch per epoch, given BATCH_SIZE = 20 and the input shape.
    step_callback.on_train_batch_begin(batch=unused_arg)
    inp = np.reshape(x_train, [BATCH_SIZE, 10])  # original input shape: [10]
    with tf.GradientTape() as tape:
        logits = pruned_model(inp, training=True)
        loss_value = loss(y_train, logits)
    grads = tape.gradient(loss_value, pruned_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))
    step_callback.on_epoch_end(batch=unused_arg)
...
```

The `set_model` calls and the `pruned_model.optimizer` assignment are unusual and easy to miss.
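A minimal sketch of the kind of helper this request asks for, using only the callback methods that the snippet above already calls. `PruningLoopHooks` is a hypothetical name, not part of tfmot. The constructor performs the two easy-to-miss steps (assigning the optimizer and calling `set_model` on each callback), and the other methods forward to the per-epoch and per-batch hooks. It assumes, as in the TFMOT pruning guide's custom-loop example, that `PruningSummaries` writes its summaries from `on_epoch_begin`.

```python
from typing import Any, Optional


class PruningLoopHooks:
    """Hypothetical wrapper (not part of tfmot) for pruning in a custom loop.

    The constructor does the wiring that is easy to forget: attaching the
    optimizer to the model and calling set_model() on every callback. The
    remaining methods forward to the hooks a custom loop would call by hand.
    """

    def __init__(self, model: Any, optimizer: Any, step_callback: Any,
                 log_callback: Optional[Any] = None) -> None:
        model.optimizer = optimizer  # the recommended path sets this explicitly
        self.step_callback = step_callback
        self.log_callback = log_callback
        step_callback.set_model(model)
        if log_callback is not None:
            log_callback.set_model(model)

    def train_begin(self) -> None:
        self.step_callback.on_train_begin()

    def epoch_begin(self, epoch: int) -> None:
        # Assumption: PruningSummaries logs its summaries in on_epoch_begin.
        if self.log_callback is not None:
            self.log_callback.on_epoch_begin(epoch)

    def batch_begin(self, batch: int) -> None:
        self.step_callback.on_train_batch_begin(batch)

    def epoch_end(self, epoch: int) -> None:
        # Passed positionally because UpdatePruningStep names this argument `batch`.
        self.step_callback.on_epoch_end(epoch)
```

With this wrapper, setup reduces to `hooks = PruningLoopHooks(pruned_model, optimizer, step_callback, log_callback)`, then `hooks.train_begin()` before the loop, `hooks.epoch_begin(epoch)` and `hooks.batch_begin(step)` at the start of each epoch and batch, and `hooks.epoch_end(epoch)` after the last batch of an epoch. Moving the wiring into the constructor means it cannot be skipped; the sketch has not been verified against a specific tfmot release.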
BogdanDidenko