
Commit c56c803

fixed earlystopping callback
1 parent fa480be commit c56c803

10 files changed: 60 additions & 11 deletions
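
The change is the same across the train.py files for weeks 0 through 9: import EarlyStopping, instantiate it to monitor val_loss with patience=3 in min mode, and append it to the Trainer's callbacks list. week_0_project_setup/train.py additionally raises max_epochs from 1 to 5, giving the new callback room to act.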

week_0_project_setup/train.py

Lines changed: 6 additions & 2 deletions
@@ -1,6 +1,7 @@
 import torch
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
 from data import DataModule
 from model import ColaModel
@@ -13,14 +14,17 @@ def main():
     checkpoint_callback = ModelCheckpoint(
         dirpath="./models", monitor="val_loss", mode="min"
     )
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
 
     trainer = pl.Trainer(
         default_root_dir="logs",
         gpus=(1 if torch.cuda.is_available() else 0),
-        max_epochs=1,
+        max_epochs=5,
         fast_dev_run=False,
         logger=pl.loggers.TensorBoardLogger("logs/", name="cola", version=1),
-        callbacks=[checkpoint_callback],
+        callbacks=[checkpoint_callback, early_stopping_callback],
     )
     trainer.fit(cola_model, cola_data)
 
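
For context, a minimal, self-contained sketch of the pattern this commit applies in every file (assuming only that pytorch_lightning is installed; the model and datamodule are stand-ins for the repo's ColaModel and DataModule, so trainer.fit is left commented out):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    # Same configuration the commit adds to each train.py:
    # stop once "val_loss" has gone 3 validation checks without improving.
    early_stopping_callback = EarlyStopping(
        monitor="val_loss",  # must match a key logged via self.log(...)
        patience=3,          # checks with no improvement before stopping
        verbose=True,        # report the early-stopping decision
        mode="min",          # lower val_loss is better
    )

    trainer = pl.Trainer(
        max_epochs=5,
        callbacks=[early_stopping_callback],
    )
    # trainer.fit(model, datamodule)  # model must log "val_loss" in validation

Note that EarlyStopping only sees metrics the model actually logs, so validation_step (or validation_epoch_end) must call self.log("val_loss", ...), otherwise the callback has nothing to monitor.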

week_1_wandb_logging/train.py

Lines changed: 6 additions & 1 deletion
@@ -3,6 +3,7 @@
 import pandas as pd
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -47,11 +48,15 @@ def main():
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=1,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=10,
         deterministic=True,
         # limit_train_batches=0.25,

week_2_hydra_config/train.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -59,11 +60,15 @@ def main(cfg):
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         limit_train_batches=cfg.training.limit_train_batches,
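
One thing the commit does not do: in the Hydra-based weeks the EarlyStopping arguments stay hardcoded even though max_epochs and the other Trainer settings come from cfg. A hypothetical extension (the cfg.training.patience key below is invented for illustration, not part of the repo's config) would read them from the config as well:

    # Hypothetical: pull early-stopping settings from the Hydra config
    # instead of hardcoding them (cfg.training.patience is an invented key).
    early_stopping_callback = EarlyStopping(
        monitor="val_loss",
        patience=cfg.training.patience,
        verbose=True,
        mode="min",
    )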

week_3_dvc/train.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,

week_4_onnx/train.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,

week_5_docker/train.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,

week_6_github_actions/train.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,

week_7_ecr/train.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,

week_8_serverless/train.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,

week_9_monitoring/train.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 
 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )
 
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,
