Skip to content

Commit e493d99

Browse files
authored
[ADD] Missing Batchnorm (#317)
* Add batch norm to mlp * Increase max runtime in test * Try drop last for running batch norm * Change order to BN-ReLU
1 parent 3659407 commit e493d99

File tree

4 files changed

+6
-4
lines changed

4 files changed

+6
-4
lines changed

autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _add_layer(self, layers: List[nn.Module], in_features: int, out_features: in
5555
5656
"""
5757
layers.append(nn.Linear(in_features, out_features))
58+
layers.append(nn.BatchNorm1d(out_features))
5859
layers.append(_activations[self.config["activation"]]())
5960
if self.config['use_dropout']:
6061
layers.append(nn.Dropout(self.config["dropout_%d" % layer_id]))

autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def _add_layer(self, layers: List[nn.Module],
6262
in_features: int, out_features: int, dropout: float
6363
) -> None:
6464
layers.append(nn.Linear(in_features, out_features))
65+
layers.append(nn.BatchNorm1d(out_features))
6566
layers.append(_activations[self.config["activation"]]())
6667
if self.config["use_dropout"] and self.config["max_dropout"] > 0.05:
6768
layers.append(nn.Dropout(dropout))

autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> torch.utils.data.DataLoader:
114114
shuffle=True,
115115
num_workers=X.get('num_workers', 0),
116116
pin_memory=X.get('pin_memory', True),
117-
drop_last=X.get('drop_last', False),
117+
drop_last=X.get('drop_last', True),
118118
collate_fn=custom_collate_fn,
119119
)
120120

test/test_pipeline/test_tabular_classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def test_train_pipeline_with_runtime(fit_dictionary_tabular_dummy):
460460
# Convert the training to runtime
461461
fit_dictionary_tabular_dummy.pop('epochs', None)
462462
fit_dictionary_tabular_dummy['budget_type'] = 'runtime'
463-
fit_dictionary_tabular_dummy['runtime'] = 3
463+
fit_dictionary_tabular_dummy['runtime'] = 5
464464
fit_dictionary_tabular_dummy['early_stopping'] = -1
465465

466466
pipeline = TabularClassificationPipeline(
@@ -474,11 +474,11 @@ def test_train_pipeline_with_runtime(fit_dictionary_tabular_dummy):
474474
run_summary = pipeline.named_steps['trainer'].run_summary
475475
budget_tracker = pipeline.named_steps['trainer'].budget_tracker
476476
assert budget_tracker.budget_type == 'runtime'
477-
assert budget_tracker.max_runtime == 3
477+
assert budget_tracker.max_runtime == 5
478478
assert budget_tracker.is_max_time_reached()
479479

480480
# There is no epoch limitation
481481
assert not budget_tracker.is_max_epoch_reached(epoch=np.inf)
482482

483-
# More than 200 epochs would have pass in 3 seconds for this dataset
483+
# More than 200 epochs would have pass in 5 seconds for this dataset
484484
assert len(run_summary.performance_tracker['start_time']) > 100

0 commit comments

Comments
 (0)