A single-replica view of training procedure.
Inherits From: Task
tfm.vision.MaskRCNNTask( params, logging_dir: Optional[str] = None, name: Optional[str] = None )
The Mask R-CNN task provides artifacts for training and evaluation procedures, including loading and iterating over datasets, initializing the model, calculating the loss, post-processing, and customized metrics with reduction.
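The methods listed on this page are hooks that a trainer calls in a fixed order. The pure-Python sketch below shows one plausible order using a hypothetical `RecordingTask` stand-in (no TensorFlow, no real model) that only records each call; the Model Garden trainer may order or group these calls differently.

```python
from typing import Any, Dict, List, Optional

CALLS: List[str] = []  # order in which the driver invokes task methods


class RecordingTask:
    """Hypothetical stand-in that reuses the MaskRCNNTask method names.

    Every method only records its call; nothing is built or trained.
    """

    loss = "loss"  # mirrors the documented class variable

    def build_model(self) -> object:
        CALLS.append("build_model")
        return object()

    def initialize(self, model: object) -> None:
        CALLS.append("initialize")

    @classmethod
    def create_optimizer(cls, optimizer_config: Any = None) -> object:
        CALLS.append("create_optimizer")
        return object()

    def build_metrics(self, training: bool = True) -> List[Any]:
        CALLS.append(f"build_metrics(training={training})")
        return []

    def build_inputs(self, params: Any = None) -> List[tuple]:
        CALLS.append("build_inputs")
        return [("images", "labels")] * 2  # two fake batches

    def train_step(self, inputs, model, optimizer, metrics=None) -> Dict[str, float]:
        CALLS.append("train_step")
        return {self.loss: 1.0}

    def validation_step(self, inputs, model, metrics=None) -> Dict[str, float]:
        CALLS.append("validation_step")
        return {self.loss: 2.0}

    def aggregate_logs(self, state: Optional[list] = None, step_outputs=None) -> list:
        CALLS.append("aggregate_logs")
        return (state or []) + [step_outputs]

    def reduce_aggregated_logs(self, aggregated_logs: list, global_step=None) -> Dict[str, Any]:
        CALLS.append("reduce_aggregated_logs")
        return {"num_eval_steps": len(aggregated_logs)}


def run(task: RecordingTask) -> Dict[str, Any]:
    """One plausible train-then-evaluate sequence (not the tfm trainer's code)."""
    model = task.build_model()
    task.initialize(model)  # e.g. restore a pretrained checkpoint
    optimizer = task.create_optimizer()
    train_metrics = task.build_metrics(training=True)
    for batch in task.build_inputs():
        task.train_step(batch, model, optimizer, train_metrics)
    eval_metrics = task.build_metrics(training=False)
    state = None
    for batch in task.build_inputs():
        logs = task.validation_step(batch, model, eval_metrics)
        state = task.aggregate_logs(state, logs)
    return task.reduce_aggregated_logs(state)
```

Calling `run(RecordingTask())` returns `{'num_eval_steps': 2}` and leaves the full call sequence in `CALLS`.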
Attributes | |
---|---|
logging_dir | |
task_config | |
Methods
aggregate_logs
aggregate_logs( state: Optional[Any] = None, step_outputs: Optional[Dict[str, Any]] = None ) -> Optional[Any]
Optional aggregation over logs returned from a validation step.
build_inputs
build_inputs( params: tfm.vision.configs.maskrcnn.DataConfig, input_context: Optional[tf.distribute.InputContext] = None, dataset_fn: Optional[dataset_fn_lib.PossibleDatasetType] = None ) -> tf.data.Dataset
Builds input dataset.
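When a distribution strategy builds the dataset, for example through `tf.distribute.Strategy.distribute_datasets_from_function`, it passes a `tf.distribute.InputContext` as `input_context`. The sketch below assumes `build_inputs` uses that context in the usual way: shard input files by input pipeline and batch per replica. `FakeInputContext` and `files_for_pipeline` are hypothetical stand-ins so the example runs without TensorFlow; the file names are made up.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeInputContext:
    """Hypothetical stand-in with the same fields as tf.distribute.InputContext."""

    num_input_pipelines: int = 1
    input_pipeline_id: int = 0
    num_replicas_in_sync: int = 1

    def get_per_replica_batch_size(self, global_batch_size: int) -> int:
        # The real InputContext raises ValueError when the global batch size is
        # not divisible by the number of replicas; this stand-in does the same.
        if global_batch_size % self.num_replicas_in_sync != 0:
            raise ValueError("global_batch_size must be divisible by num_replicas_in_sync")
        return global_batch_size // self.num_replicas_in_sync


def files_for_pipeline(files: List[str], ctx: FakeInputContext) -> List[str]:
    # Same selection rule as tf.data.Dataset.shard(num_shards, index):
    # keep every num_shards-th element, starting at position index.
    return files[ctx.input_pipeline_id::ctx.num_input_pipelines]


files = [f"train-{i:05d}.tfrecord" for i in range(8)]
ctx = FakeInputContext(num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=8)
print(files_for_pipeline(files, ctx))      # files 1, 3, 5 and 7
print(ctx.get_per_replica_batch_size(64))  # 8
```

Sharding by pipeline keeps hosts from reading the same records, and the per-replica batch size keeps the global batch constant as the replica count changes.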
build_losses
build_losses( outputs: Mapping[str, Any], labels: Mapping[str, Any], aux_losses: Optional[Any] = None ) -> Dict[str, tf.Tensor]
Builds Mask R-CNN losses.
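Mask R-CNN trains five loss terms jointly: RPN objectness and box regression, second-stage (Fast R-CNN) classification and box regression, and the mask head. A minimal sketch of how such terms are commonly combined, assuming a weighted sum of per-head losses plus any auxiliary losses passed as `aux_losses` (for example, weight regularization). `combine_losses`, the loss keys, and the weights are illustrative, not this task's config fields.

```python
from typing import Dict, Optional, Sequence


def combine_losses(
    head_losses: Dict[str, float],
    weights: Dict[str, float],
    aux_losses: Optional[Sequence[float]] = None,
) -> Dict[str, float]:
    """Weighted sum of per-head losses plus any auxiliary losses.

    Heads missing from `weights` get weight 1.0. Key names are illustrative.
    """
    model_loss = sum(weights.get(name, 1.0) * value for name, value in head_losses.items())
    total_loss = model_loss + sum(aux_losses or [])
    return {**head_losses, "model_loss": model_loss, "total_loss": total_loss}


losses = combine_losses(
    head_losses={
        "rpn_score_loss": 0.2,  # RPN objectness
        "rpn_box_loss": 0.1,    # RPN box regression
        "frcnn_cls_loss": 0.5,  # second-stage classification
        "frcnn_box_loss": 0.3,  # second-stage box regression
        "mask_loss": 0.4,       # per-pixel mask head
    },
    weights={"rpn_box_loss": 2.0},
    aux_losses=[0.05],
)
```

Here `model_loss` is about 1.6 and `total_loss` about 1.65; returning the individual terms alongside the total makes each head's contribution visible in the logs.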
build_metrics
build_metrics( training: bool = True )
Builds detection metrics.
build_model
build_model()
Builds Mask R-CNN model.
create_optimizer
@classmethod
create_optimizer( optimizer_config: tfm.optimization.OptimizationConfig, runtime_config: Optional[tfm.core.base_task.RuntimeConfig] = None, dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig] = None )
Creates a TF optimizer from configurations.
Args | |
---|---|
optimizer_config | the optimization settings. |
runtime_config | the runtime settings. |
dp_config | the differential privacy settings. |
Returns | |
---|---|
A tf.optimizers.Optimizer object. |
inference_step
inference_step( inputs, model: tf.keras.Model )
Performs the forward step.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | a dictionary of input tensors. |
model | the keras.Model. |
Returns | |
---|---|
Model outputs. |
initialize
initialize( model: tf.keras.Model )
Loads pretrained checkpoint.
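Pretrained checkpoints often cover only part of the network, for example an ImageNet-trained backbone. The sketch below assumes the task config names which submodules to restore, with `'all'` meaning the whole model; `select_modules_to_restore` and the module names are hypothetical.

```python
from typing import Dict, Iterable, Union


def select_modules_to_restore(
    model_parts: Dict[str, object],
    init_checkpoint_modules: Union[str, Iterable[str]],
) -> Dict[str, object]:
    """Pick the submodules whose weights should be read from a checkpoint.

    'all' keeps every part; otherwise only the named parts are kept.
    """
    if init_checkpoint_modules == "all":
        return dict(model_parts)
    if isinstance(init_checkpoint_modules, str):
        init_checkpoint_modules = [init_checkpoint_modules]
    wanted = set(init_checkpoint_modules)
    unknown = wanted - set(model_parts)
    if unknown:
        raise ValueError(f"unknown modules: {sorted(unknown)}")
    return {name: part for name, part in model_parts.items() if name in wanted}


# Placeholder strings stand in for Keras sub-models.
parts = {"backbone": "ResNet", "decoder": "FPN", "heads": "RPN+RCNN+mask"}
print(select_modules_to_restore(parts, "backbone"))  # {'backbone': 'ResNet'}
```

With TensorFlow, the selected dict could be passed as keyword arguments to `tf.train.Checkpoint` and loaded with `.read(path).expect_partial()`, which silences warnings about checkpoint values that are not restored.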
process_compiled_metrics
process_compiled_metrics( compiled_metrics, labels, model_outputs )
Process and update compiled_metrics.
Called when using the compile/fit API.
Args | |
---|---|
compiled_metrics | the compiled metrics (model.compiled_metrics). |
labels | a tensor or a nested structure of tensors. |
model_outputs | a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model. |
process_metrics
process_metrics( metrics, labels, model_outputs, **kwargs )
Process and update metrics.
Called when using a custom training loop.
Args | |
---|---|
metrics | a nested structure of metrics objects. The return of function self.build_metrics. |
labels | a tensor or a nested structure of tensors. |
model_outputs | a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model. |
**kwargs | other args. |
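Per the Args table, `metrics` is whatever `build_metrics` returned. For training-time loss tracking, a common setup is one running-mean metric per loss name, each updated from the step's value stored under its own name. The sketch below assumes that setup; `MeanMetric` is a pure-Python stand-in borrowing the Keras method names `update_state`, `result`, and `reset_state`, and `update_metrics_by_name` is hypothetical.

```python
from typing import Dict, List


class MeanMetric:
    """Minimal running mean with Keras-style method names."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.reset_state()

    def reset_state(self) -> None:
        self._total = 0.0
        self._count = 0

    def update_state(self, value: float) -> None:
        self._total += value
        self._count += 1

    def result(self) -> float:
        # Return 0.0 before any update instead of dividing by zero.
        return self._total / self._count if self._count else 0.0


def update_metrics_by_name(metrics: List[MeanMetric], step_values: Dict[str, float]) -> Dict[str, float]:
    # Each metric reads the step value stored under its own name.
    for metric in metrics:
        metric.update_state(step_values[metric.name])
    return {m.name: m.result() for m in metrics}


metrics = [MeanMetric("total_loss"), MeanMetric("mask_loss")]
update_metrics_by_name(metrics, {"total_loss": 2.0, "mask_loss": 0.5})
print(update_metrics_by_name(metrics, {"total_loss": 1.0, "mask_loss": 0.25}))
# {'total_loss': 1.5, 'mask_loss': 0.375}
```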
reduce_aggregated_logs
reduce_aggregated_logs( aggregated_logs: Dict[str, Any], global_step: Optional[tf.Tensor] = None ) -> Dict[str, tf.Tensor]
Optional reduction of aggregated logs over validation steps.
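`aggregate_logs` and `reduce_aggregated_logs` form a two-phase protocol: the first folds each validation step's outputs into a running state (starting from `None`), the second turns the final state into scalar metrics. For Mask R-CNN the state would more plausibly be a COCO-style evaluator accumulating detections; the sketch below substitutes a per-key list and a mean so it runs standalone, and the `num_detections` key is made up.

```python
from typing import Any, Dict, List, Optional


def aggregate_logs(state: Optional[Dict[str, List[Any]]] = None,
                   step_outputs: Optional[Dict[str, Any]] = None) -> Dict[str, List[Any]]:
    # Start a fresh state on the first validation step, then append per-key outputs.
    if state is None:
        state = {}
    for key, value in (step_outputs or {}).items():
        state.setdefault(key, []).append(value)
    return state


def reduce_aggregated_logs(aggregated_logs: Dict[str, List[Any]],
                           global_step: Optional[int] = None) -> Dict[str, float]:
    # Collapse the accumulated per-step values into one scalar per key.
    return {key: sum(values) / len(values) for key, values in aggregated_logs.items()}


state = None
for outputs in [{"num_detections": 4}, {"num_detections": 6}]:
    state = aggregate_logs(state, outputs)
print(reduce_aggregated_logs(state))  # {'num_detections': 5.0}
```

Splitting the work this way lets per-step outputs stay cheap while the expensive reduction (such as computing AP) runs once per evaluation.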
train_step
train_step( inputs: Tuple[Any, Any], model: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer, metrics: Optional[List[Any]] = None )
Runs the forward and backward passes.
Args | |
---|---|
inputs | the input batch, a tuple of (images, labels). |
model | the model, forward pass definition. |
optimizer | the optimizer for this training step. |
metrics | a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
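Under a distribution strategy each replica runs `train_step` on its own shard, and gradients are summed across replicas. A common pattern, assumed here rather than confirmed for this task, is to divide each replica's loss by the number of replicas so that the summed gradient equals the gradient of the mean loss over the global batch. The pure-Python check below uses a hypothetical one-parameter model `y ≈ w * x` with squared error.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def mse_grad(w: float, batch: Batch) -> float:
    """d/dw of mean((w * x - y) ** 2) over the batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def summed_replica_grad(w: float, global_batch: Batch, num_replicas: int) -> float:
    """Gradient after each replica divides its loss by num_replicas and results are summed."""
    if len(global_batch) % num_replicas != 0:
        raise ValueError("global batch must split evenly across replicas")
    shard = len(global_batch) // num_replicas
    shards = [global_batch[i * shard:(i + 1) * shard] for i in range(num_replicas)]
    # Dividing the loss by num_replicas divides its gradient by the same factor.
    return sum(mse_grad(w, s) / num_replicas for s in shards)


batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 9.0)]
print(mse_grad(0.5, batch))                # -22.0
print(summed_replica_grad(0.5, batch, 2))  # -22.0
```

Without the division, the summed gradient would be `num_replicas` times too large, which behaves like silently scaling the learning rate with the replica count.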
validation_step
validation_step( inputs: Tuple[Any, Any], model: tf.keras.Model, metrics: Optional[List[Any]] = None )
Validation step.
Args | |
---|---|
inputs | the input batch, a tuple of (images, labels). |
model | the keras.Model. |
metrics | a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
Class Variables | |
---|---|
loss | 'loss' |