10 files changed: 60 additions, 11 deletions
File 1:

@@ -1,6 +1,7 @@
 import torch
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping

 from data import DataModule
 from model import ColaModel
@@ -13,14 +14,17 @@ def main():
     checkpoint_callback = ModelCheckpoint(
         dirpath="./models", monitor="val_loss", mode="min"
     )
+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )

     trainer = pl.Trainer(
         default_root_dir="logs",
         gpus=(1 if torch.cuda.is_available() else 0),
-        max_epochs=1,
+        max_epochs=5,
         fast_dev_run=False,
         logger=pl.loggers.TensorBoardLogger("logs/", name="cola", version=1),
-        callbacks=[checkpoint_callback],
+        callbacks=[checkpoint_callback, early_stopping_callback],
     )
     trainer.fit(cola_model, cola_data)
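Note: EarlyStopping can only track a metric that the LightningModule actually logs under the name it monitors. A minimal sketch of the logging side (assumed for illustration; the real ColaModel's validation_step is not shown in this diff):

import torch.nn.functional as F
import pytorch_lightning as pl


class ColaModel(pl.LightningModule):
    # ... model definition elided ...

    def validation_step(self, batch, batch_idx):
        logits = self(batch["input_ids"], batch["attention_mask"])  # assumed signature
        loss = F.cross_entropy(logits, batch["label"])
        # The key "val_loss" must match EarlyStopping(monitor="val_loss") and
        # ModelCheckpoint(monitor="val_loss"); with the default strict=True,
        # EarlyStopping errors out if the monitored key is never logged.
        self.log("val_loss", loss, prog_bar=True)
        return loss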
File 2:

@@ -3,6 +3,7 @@
 import pandas as pd
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -47,11 +48,15 @@ def main():
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=1,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=10,
         deterministic=True,
         # limit_train_batches=0.25,
File 3:

@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -59,11 +60,15 @@ def main(cfg):
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         limit_train_batches=cfg.training.limit_train_batches,
File 4:

@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,
File 5:

@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,
File 6:

@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,
File 7:

@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,
File 8:

@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,
File 9:

@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,
File 10:

@@ -7,6 +7,7 @@
 import pytorch_lightning as pl
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger

 from data import DataModule
@@ -60,11 +61,15 @@ def main(cfg):
         mode="min",
     )

+    early_stopping_callback = EarlyStopping(
+        monitor="val_loss", patience=3, verbose=True, mode="min"
+    )
+
     wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
     trainer = pl.Trainer(
         max_epochs=cfg.training.max_epochs,
         logger=wandb_logger,
-        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data)],
+        callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
         log_every_n_steps=cfg.training.log_every_n_steps,
         deterministic=cfg.training.deterministic,
         # limit_train_batches=cfg.training.limit_train_batches,
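Across all ten scripts the callback is configured identically: patience=3 stops training after three consecutive validation checks with no improvement in val_loss. Early stopping only has room to act when max_epochs allows several validation passes, which is presumably why the first script also raises max_epochs from 1 to 5. A hedged sketch of other knobs the same callback accepts in recent PyTorch Lightning releases (values illustrative, not from this commit):

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,         # stop after 3 validation checks without improvement
    mode="min",         # lower val_loss counts as improvement
    verbose=True,       # log a message when stopping triggers
    min_delta=0.001,    # require at least this much change to count as improvement
    check_finite=True,  # also stop if val_loss becomes NaN or inf
)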