@@ -45,25 +45,53 @@ def __init__(self,
         An object for tracking when to stop the network training.
         It handles epoch based criteria as well as training based criteria.
 
-        It also allows to define a 'epoch_or_time' budget type, which means,
-        the first of them both which is exhausted, is honored
+        It also allows defining an 'epoch_or_time' budget type, which means that whichever
+        of the two budgets is exhausted first is honored.
+
+        Args:
+            budget_type (str):
+                Type of budget to be used when fitting the pipeline.
+                Possible values are 'epochs', 'runtime', or 'epoch_or_time'
+            max_epochs (Optional[int], default=None):
+                Maximum number of epochs to train the pipeline for
+            max_runtime (Optional[int], default=None):
+                Maximum number of seconds to train the pipeline for
         """
         self.start_time = time.time()
         self.budget_type = budget_type
         self.max_epochs = max_epochs
         self.max_runtime = max_runtime
 
     def is_max_epoch_reached(self, epoch: int) -> bool:
+        """
+        For budget type 'epochs' or 'epoch_or_time', return True if the maximum number of epochs is reached.
+
+        Args:
+            epoch (int):
+                the current epoch
 
-        # Make None a method to run without this constrain
+        Returns:
+            bool:
+                True if the current epoch is larger than the maximum epochs, False otherwise.
+                Additionally, returns False if the run has no such constraint.
+        """
+        # Make None a method to run without this constraint
         if self.max_epochs is None:
             return False
         if self.budget_type in ['epochs', 'epoch_or_time'] and epoch > self.max_epochs:
             return True
         return False
 
     def is_max_time_reached(self) -> bool:
-        # Make None a method to run without this constrain
+        """
+        For budget type 'runtime' or 'epoch_or_time', return True if the maximum runtime is reached.
+
+        Returns:
+            bool:
+                True if the maximum runtime is reached, False otherwise.
+                Additionally, returns False if the run has no such constraint.
+        """
+        # Make None a method to run without this constraint
         if self.max_runtime is None:
             return False
         elapsed_time = time.time() - self.start_time
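
# A minimal usage sketch of the BudgetTracker documented above, driven from a plain training
# loop. With the 'epoch_or_time' budget, whichever of the two limits is exhausted first stops
# the loop; the loop body is a placeholder, not part of this file.
budget_tracker = BudgetTracker(budget_type='epoch_or_time', max_epochs=50, max_runtime=600)
epoch = 1
while not budget_tracker.is_max_epoch_reached(epoch) and not budget_tracker.is_max_time_reached():
    # ... train and evaluate one epoch here ...
    epoch += 1
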
@@ -78,14 +106,22 @@ def __init__(
         total_parameter_count: float,
         trainable_parameter_count: float,
         optimize_metric: Optional[str] = None,
-    ):
+    ) -> None:
         """
         A useful object to track performance per epoch.
 
-        It allows to track train, validation and test information not only for
-        debug, but for research purposes (Like understanding overfit).
+        It allows tracking train, validation and test information not only for debugging, but also for
+        research purposes (like understanding overfitting).
 
         It does so by tracking a metric/loss at the end of each epoch.
+
+        Args:
+            total_parameter_count (float):
+                the total number of parameters of the model
+            trainable_parameter_count (float):
+                only the parameters being optimized
+            optimize_metric (Optional[str], default=None):
+                name of the metric that is used to evaluate a pipeline.
         """
         self.performance_tracker: Dict[str, Dict] = {
             'start_time': {},
@@ -121,8 +157,30 @@ def add_performance(self,
                             test_loss: Optional[float] = None,
                             ) -> None:
         """
-        Tracks performance information about the run, useful for
-        plotting individual runs
+        Tracks performance information about the run, useful for plotting individual runs.
+
+        Args:
+            epoch (int):
+                the current epoch
+            start_time (float):
+                timestamp at the beginning of the current epoch
+            end_time (float):
+                timestamp at which the information is gathered, after the current epoch
+            train_loss (float):
+                the training loss
+            train_metrics (Dict[str, float]):
+                training scores for each desired metric
+            val_metrics (Dict[str, float]):
+                validation scores for each desired metric
+            test_metrics (Dict[str, float]):
+                test scores for each desired metric
+            val_loss (Optional[float], default=None):
+                the validation loss
+            test_loss (Optional[float], default=None):
+                the test loss
+
+        Returns:
+            None
         """
         self.performance_tracker['train_loss'][epoch] = train_loss
         self.performance_tracker['val_loss'][epoch] = val_loss
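
# A small sketch of the RunSummary workflow described above: record the results of every epoch
# with add_performance() and query the best one afterwards. All metric values and parameter
# counts below are made-up placeholders.
import time

summary = RunSummary(total_parameter_count=1.0e6, trainable_parameter_count=1.0e6,
                     optimize_metric='accuracy')
for epoch in range(1, 4):
    epoch_start = time.time()
    # ... one epoch of training and validation happens here ...
    summary.add_performance(
        epoch=epoch,
        start_time=epoch_start,
        end_time=time.time(),
        train_loss=1.0 / epoch,                          # dummy, decreasing loss
        train_metrics={'accuracy': 0.70 + 0.05 * epoch},
        val_metrics={'accuracy': 0.65 + 0.05 * epoch},
        test_metrics={},
        val_loss=1.2 / epoch,
    )
best = summary.get_best_epoch(split_type='val')  # expected to be epoch 3 for these made-up numbers
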
@@ -134,6 +192,18 @@ def add_performance(self,
         self.performance_tracker['test_metrics'][epoch] = test_metrics
 
     def get_best_epoch(self, split_type: str = 'val') -> int:
+        """
+        Get the epoch with the best metric.
+
+        Args:
+            split_type (str, default='val'):
+                Which split's metric to consider.
+                Possible values are 'train' or 'val'.
+
+        Returns:
+            int:
+                the epoch with the best metric
+        """
         # If we compute for optimization, prefer the performance
         # metric to the loss
         if self.optimize_metric is not None:
@@ -159,6 +229,13 @@ def get_best_epoch(self, split_type: str = 'val') -> int:
         )) + 1  # Epochs start at 1
 
     def get_last_epoch(self) -> int:
+        """
+        Get the last epoch.
+
+        Returns:
+            int:
+                the last epoch
+        """
         if 'train_loss' not in self.performance_tracker:
             return 0
         else:
@@ -170,7 +247,8 @@ def repr_last_epoch(self) -> str:
         performance
 
         Returns:
-            str: A nice representation of the last epoch
+            str:
+                A nice representation of the last epoch
         """
         last_epoch = len(self.performance_tracker['train_loss'])
         string = "\n"
@@ -202,30 +280,43 @@ def is_empty(self) -> bool:
         Checks if the object is empty or not
 
         Returns:
-            bool
+            bool:
+                True if the object is empty, False otherwise
         """
         # if train_loss is empty, we can be sure that RunSummary is empty.
         return not bool(self.performance_tracker['train_loss'])
 
 
 class BaseTrainerComponent(autoPyTorchTrainingComponent):
     """
-    Base class for training
+    Base class for training.
+
     Args:
-        weighted_loss (int, default=0): In case for classification, whether to weight
-            the loss function according to the distribution of classes in the target
-        use_stochastic_weight_averaging (bool, default=True): whether to use stochastic
-            weight averaging. Stochastic weight averaging is a simple average of
-            multiple points(model parameters) along the trajectory of SGD. SWA
-            has been proposed in
+        weighted_loss (int, default=0):
+            In the case of classification, whether to weight the loss function according to the distribution
+            of classes in the target
+        use_stochastic_weight_averaging (bool, default=True):
+            whether to use stochastic weight averaging. Stochastic weight averaging is a simple average of
+            multiple points (model parameters) along the trajectory of SGD. SWA has been proposed in
            [Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)
-        use_snapshot_ensemble (bool, default=True): whether to use snapshot
-            ensemble
-        se_lastk (int, default=3): Number of snapshots of the network to maintain
-        use_lookahead_optimizer (bool, default=True): whether to use lookahead
-            optimizer
-        random_state:
-        **lookahead_config:
+        use_snapshot_ensemble (bool, default=True):
+            whether to use snapshot ensemble
+        se_lastk (int, default=3):
+            Number of snapshots of the network to maintain
+        use_lookahead_optimizer (bool, default=True):
+            whether to use lookahead optimizer
+        random_state (Optional[np.random.RandomState]):
+            Object that contains a seed and allows for reproducible results
+        swa_model (Optional[torch.nn.Module], default=None):
+            Averaged model used for Stochastic Weight Averaging
+        model_snapshots (Optional[List[torch.nn.Module]], default=None):
+            List of model snapshots in case snapshot ensemble is used
+        **lookahead_config (Any):
+            keyword arguments for the lookahead optimizer including:
+            la_steps (int):
+                number of lookahead steps
+            la_alpha (float):
+                linear interpolation factor. 1.0 recovers the inner optimizer.
     """
     def __init__(self, weighted_loss: int = 0,
                  use_stochastic_weight_averaging: bool = True,
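
# The la_steps / la_alpha entries of lookahead_config follow the standard Lookahead rule
# (Zhang et al., 2019): an inner optimizer updates the "fast" weights for la_steps steps, after
# which the "slow" weights move towards them by linear interpolation. A scalar sketch of that
# slow-weight update (not this library's implementation):
def lookahead_slow_weight_update(slow: float, fast: float, la_alpha: float) -> float:
    # la_alpha = 1.0 recovers the inner optimizer, because the slow weights jump to the fast ones.
    return slow + la_alpha * (fast - slow)
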
@@ -336,15 +427,21 @@ def prepare(
 
     def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
         """
-        Optional place holder for AutoPytorch Extensions.
+        Optional placeholder for AutoPytorch Extensions.
+        A user can define what happens on every epoch start or every epoch end.
 
-        An user can define what happens on every epoch start or every epoch end.
+        Args:
+            X (Dict[str, Any]):
+                Dictionary with fitted parameters. It is a message passing mechanism in which, during a
+                transform, a component adds relevant information so that further stages can be properly fitted
+            epoch (int):
+                the current epoch
         """
         pass
 
     def _swa_update(self) -> None:
         """
-        perform swa model update
+        Perform a Stochastic Weight Averaging model update
         """
         if self.swa_model is None:
             raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
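
# Sketch of what an SWA update amounts to, using PyTorch's torch.optim.swa_utils. The swa_model
# held by this component is assumed here to behave like AveragedModel: each call to
# update_parameters() folds the current weights into a running average of the visited weights.
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 1)
swa_model = AveragedModel(model)
for _ in range(3):
    # ... an optimizer step on `model` would normally happen here ...
    swa_model.update_parameters(model)
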
@@ -354,6 +451,7 @@ def _swa_update(self) -> None:
     def _se_update(self, epoch: int) -> None:
         """
         Add latest model or swa_model to model snapshot ensemble
+
         Args:
             epoch (int):
                 current epoch
@@ -373,9 +471,16 @@ def _se_update(self, epoch: int) -> None:
 
     def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
         """
-        Optional place holder for AutoPytorch Extensions.
-        An user can define what happens on every epoch start or every epoch end.
-        If returns True, the training is stopped
+        Optional placeholder for AutoPytorch Extensions.
+        A user can define what happens on every epoch start or every epoch end.
+        If it returns True, the training is stopped.
+
+        Args:
+            X (Dict[str, Any]):
+                Dictionary with fitted parameters. It is a message passing mechanism in which, during a
+                transform, a component adds relevant information so that further stages can be properly fitted
+            epoch (int):
+                the current epoch
 
         """
         if X['is_cyclic_scheduler']:
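
# A hypothetical extension illustrating the contract above: when on_epoch_end returns True, the
# surrounding training loop stops. Neither this subclass nor its stopping rule exist in
# autoPyTorch; they are illustrative placeholders.
class StopAfterTenEpochsTrainer(BaseTrainerComponent):
    def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
        stop = super().on_epoch_end(X, epoch)  # keep the base bookkeeping (e.g. schedulers, snapshots)
        return stop or epoch >= 10             # additionally request a stop after ten epochs
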
@@ -421,12 +526,18 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
         Train the model for a single epoch.
 
         Args:
-            train_loader (torch.utils.data.DataLoader): generator of features/label
-            epoch (int): The current epoch used solely for tracking purposes
+            train_loader (torch.utils.data.DataLoader):
+                generator of features/labels
+            epoch (int):
+                The current epoch, used solely for tracking purposes
+            writer (Optional[SummaryWriter]):
+                Object to keep track of the training loss in an event file
 
         Returns:
-            float: training loss
-            Dict[str, float]: scores for each desired metric
+            float:
+                training loss
+            Dict[str, float]:
+                scores for each desired metric
         """
 
         loss_sum = 0.0
@@ -482,12 +593,16 @@ def train_step(self, data: torch.Tensor, targets: torch.Tensor) -> Tuple[float,
         Allows to train 1 step of gradient descent, given a batch of train/labels
 
         Args:
-            data (torch.Tensor): input features to the network
-            targets (torch.Tensor): ground truth to calculate loss
+            data (torch.Tensor):
+                input features to the network
+            targets (torch.Tensor):
+                ground truth to calculate loss
 
         Returns:
-            torch.Tensor: The predictions of the network
-            float: the loss incurred in the prediction
+            torch.Tensor:
+                The predictions of the network
+            float:
+                the loss incurred in the prediction
         """
         # prepare
         data = data.float().to(self.device)
@@ -513,12 +628,18 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int,
         Evaluate the model in both metrics and criterion
 
         Args:
-            test_loader (torch.utils.data.DataLoader): generator of features/label
-            epoch (int): the current epoch for tracking purposes
+            test_loader (torch.utils.data.DataLoader):
+                generator of features/labels
+            epoch (int):
+                the current epoch, for tracking purposes
+            writer (Optional[SummaryWriter]):
+                Object to keep track of the test loss in an event file
 
         Returns:
-            float: test loss
-            Dict[str, float]: scores for each desired metric
+            float:
+                test loss
+            Dict[str, float]:
+                scores for each desired metric
         """
         self.model.eval()
 
@@ -576,14 +697,15 @@ def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.n
     def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
                          ) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
         """
-        Depending on the trainer choice, data fed to the network might be pre-processed
-        on a different way. That is, in standard training we provide the data to the
-        network as we receive it to the loader. Some regularization techniques, like mixup
-        alter the data.
+        Depending on the trainer choice, data fed to the network might be pre-processed in a different way.
+        That is, in standard training we provide the data to the network as we receive it from the loader.
+        Some regularization techniques, like mixup, alter the data.
 
         Args:
-            X (torch.Tensor): The batch training features
-            y (torch.Tensor): The batch training labels
+            X (torch.Tensor):
+                The batch training features
+            y (torch.Tensor):
+                The batch training labels
 
         Returns:
             torch.Tensor: that processes data
@@ -595,16 +717,21 @@ def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
     def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0
                               ) -> Callable:  # type: ignore
         """
-        Depending on the trainer choice, the criterion is not directly applied to the
-        traditional y_pred/y_ground_truth pairs, but rather it might have a slight transformation.
+        Depending on the trainer choice, the criterion is not directly applied to the traditional
+        y_pred/y_ground_truth pairs, but rather it might have a slight transformation.
         For example, in the case of mixup training, we need to account for the lambda mixup
 
         Args:
-            kwargs (Dict): an expanded dictionary with modifiers to the
-                criterion calculation
+            y_a (torch.Tensor):
+                the batch label of the first training example used in the trainer
+            y_b (torch.Tensor, default=None):
+                if applicable, the batch label of the second training example used in the trainer
+            lam (float):
+                trainer coefficient (e.g. the mixup interpolation factor)
 
         Returns:
-            Callable: a lambda function that contains the new criterion calculation recipe
+            Callable:
+                a lambda function that contains the new criterion calculation recipe
         """
         raise NotImplementedError()
 
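
# A sketch of how a mixup-style trainer could fill in the two hooks above, following the standard
# mixup recipe (Zhang et al., 2018). Function names are illustrative and autoPyTorch's own mixup
# trainer may differ in details such as how lam is sampled.
import torch

def mixup_data_preparation(X: torch.Tensor, y: torch.Tensor, lam: float):
    # Mix every example with a randomly chosen partner from the same batch.
    permutation = torch.randperm(X.size(0))
    mixed_X = lam * X + (1 - lam) * X[permutation]
    return mixed_X, {'y_a': y, 'y_b': y[permutation], 'lam': lam}

def mixup_criterion_preparation(y_a: torch.Tensor, y_b: torch.Tensor, lam: float):
    # Matches the documented contract: returns a callable taking the criterion and the predictions
    # and mixing the two losses with the same lam that was used to mix the inputs.
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
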