feat: add resume_from command line argument to resume from checkpoint, fix #31
ydcjeff committed Mar 31, 2021
commit 2b82c65be7eaca5286da652cde7ebec044b10c3d
5 changes: 5 additions & 0 deletions templates/_base/_argparse.pyi
@@ -8,6 +8,11 @@ DEFAULTS = {
"action": "store_true",
"help": "use torch.cuda.amp for automatic mixed precision"
},
"resume_from": {
"default": None,
"type": str,
"help": "path to the checkpoint file to resume, can also url starting with https (None)"
},
"seed": {
"default": 666,
"type": int,
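For context, the entry above is only a defaults mapping; the sketch below shows one way such a mapping can become an actual --resume_from flag. It assumes DEFAULTS feeds add_argument directly, which may differ from the template's real get_default_parser, and the checkpoint path is made up.

from argparse import ArgumentParser

DEFAULTS = {
    "resume_from": {
        "default": None,
        "type": str,
        "help": "path to the checkpoint file to resume from, can also be a URL starting with https (None)",
    },
}

def get_default_parser() -> ArgumentParser:
    # build a --<name> flag straight from each DEFAULTS entry
    parser = ArgumentParser()
    for name, kwargs in DEFAULTS.items():
        parser.add_argument(f"--{name}", **kwargs)
    return parser

config = get_default_parser().parse_args(["--resume_from", "/tmp/checkpoints/best_model.pt"])
print(config.resume_from)  # /tmp/checkpoints/best_model.pt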
15 changes: 11 additions & 4 deletions templates/single/single_cg/main.pyi
@@ -13,7 +13,7 @@ from ignite.utils import manual_seed
from single_cg.engines import create_engines
from single_cg.events import TrainEvents
from single_cg.handlers import get_handlers, get_logger
from single_cg.utils import get_default_parser, setup_logging, log_metrics, log_basic_info, initialize
from single_cg.utils import get_default_parser, setup_logging, log_metrics, log_basic_info, initialize, resume_from


def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
@@ -72,7 +72,7 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
# ignite handlers and ignite loggers
# -------------------------------------

to_save = {"model": model, "optimizer": optimizer, "train_engine": train_engine}
to_save = {"model": model, "optimizer": optimizer, "train_engine": train_engine, "lr_scheduler": lr_scheduler}
best_model_handler, es_handler, timer_handler = get_handlers(
config=config,
model=model,
@@ -85,11 +85,18 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
# TODO : replace with the name you have input in the Code-Generator
# if you check `Early stop the training by evaluation score` otherwise leave it None
to_save=to_save,
lr_scheduler=None,
lr_scheduler=lr_scheduler,
output_names=None,
)
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# -----------------------------------
# resume from the saved checkpoints
# -----------------------------------

if config.resume_from:
resume_from(to_load=to_save, checkpoint_fp=config.resume_from, logger=logger)

# --------------------------------------------
# let's trigger custom events we registered
# we will use a `event_filter` to trigger that
@@ -157,7 +164,7 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
# where is my best and last checkpoint ?
# -----------------------------------------

logger.info(best_model_handler.last_checkpoint)
logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)


def main():
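The change above reuses the same to_save mapping for both checkpointing and resuming. A standalone sketch of that mechanism follows; the object names (net, opt, scheduler, trainer) and the checkpoint path are illustrative only, not part of the template.

import torch
from ignite.engine import Engine
from ignite.handlers import Checkpoint

net = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
trainer = Engine(lambda engine, batch: None)

# the same mapping is handed to the checkpoint handler when saving ...
to_save = {"model": net, "optimizer": opt, "train_engine": trainer, "lr_scheduler": scheduler}

# ... and reused as to_load when resuming from a previously saved checkpoint
checkpoint = torch.load("/tmp/checkpoints/checkpoint_10.pt", map_location="cpu")
Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)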
77 changes: 59 additions & 18 deletions templates/single/single_cg/utils.pyi
@@ -3,22 +3,21 @@ utility functions which can be used in training
"""
import hashlib
import logging
import os
import shutil
from datetime import datetime
from logging import Logger
from pathlib import Path
from pprint import pformat
from typing import Any, Optional, Tuple, Union
from typing import Any, Mapping, Optional, Tuple, Union

import ignite.distributed as idist
import torch
from ignite.engine import Engine
from ignite.handlers.checkpoint import Checkpoint
from ignite.utils import setup_logger
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from torch.optim.optimizer import Optimizer

{% include "_argparse.pyi" %}

@@ -121,40 +120,82 @@ def setup_logging(config: Any) -> Logger:


def hash_checkpoint(
checkpoint: str,
checkpoint_fp: Union[str, Path],
jitted: bool,
output_path: Union[str, Path],
) -> Tuple[str, str]:
) -> Tuple[Path, str]:
"""Hash the checkpoint file to be used with `check_hash` of
`torch.hub.load_state_dict_from_url`.

Parameters
----------
checkpoint
checkpoint file.
checkpoint_fp
path to the checkpoint file.
jitted
indicate the checkpoint is already applied torch.jit or not.
output_path
path to store the hashed checkpoint file.

Returns
-------
filename and sha_hash
the hashed filename and SHA hash
hashed_fp and sha_hash
path to the hashed file and SHA hash
"""
with open(checkpoint, "rb") as file:
sha_hash = hashlib.sha256(file.read()).hexdigest()
if isinstance(checkpoint_fp, str):
checkpoint_fp = Path(checkpoint_fp)

sha_hash = hashlib.sha256(checkpoint_fp.read_bytes()).hexdigest()
ckpt_file_name = checkpoint_fp.stem

ckpt_file_name = os.path.splitext(checkpoint.split(os.sep)[-1])[0]
if jitted:
filename = "-".join((ckpt_file_name, sha_hash[:8])) + ".ptc"
hashed_fp = "-".join((ckpt_file_name, sha_hash[:8])) + ".ptc"
else:
filename = "-".join((ckpt_file_name, sha_hash[:8])) + ".pt"
hashed_fp = "-".join((ckpt_file_name, sha_hash[:8])) + ".pt"

if isinstance(output_path, str):
output_path = Path(output_path)

shutil.move(checkpoint, output_path / filename)
print("Saved state dict into %s | SHA256: %s", filename, sha_hash)
hashed_fp = output_path / hashed_fp
shutil.move(checkpoint_fp, hashed_fp)
print(f"Saved state dict into {hashed_fp} | SHA256: {sha_hash}")

return filename, sha_hash
return hashed_fp, sha_hash


def resume_from(
to_load: Mapping,
checkpoint_fp: Union[str, Path],
logger: Logger,
strict: bool = True,
model_dir: Optional[str] = None,
) -> None:
"""Loads state dict from a checkpoint file to resume the training.

Parameters
----------
to_load
a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}
checkpoint_fp
path to the checkpoint file
logger
to log info about resuming from a checkpoint
strict
whether to strictly enforce that the keys in `state_dict` match the keys
returned by this module’s `state_dict()` function. Default: True
model_dir
directory in which to save the object
"""
if isinstance(checkpoint_fp, str) and checkpoint_fp.startswith("https"):
checkpoint = torch.hub.load_state_dict_from_url(
checkpoint_fp, model_dir=model_dir, map_location="cpu", check_hash=True
)
else:
if isinstance(checkpoint_fp, str):
checkpoint_fp = Path(checkpoint_fp)

if not checkpoint_fp.exists():
raise FileNotFoundError(f"Given {str(checkpoint_fp)} does not exist.")
checkpoint = torch.load(checkpoint_fp, map_location="cpu")

Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint, strict=strict)
logger.info("Successfully resumed from a checkpoint: %s", checkpoint_fp)
71 changes: 59 additions & 12 deletions templates/single/tests/test_utils.py
@@ -1,14 +1,19 @@
import logging
from pathlib import Path
from tempfile import TemporaryDirectory
import unittest
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from ignite.engine import Engine
from ignite.utils import setup_logger

from single_cg.utils import get_default_parser, log_metrics, setup_logging, hash_checkpoint
from single_cg.utils import (
get_default_parser,
hash_checkpoint,
log_metrics,
resume_from,
setup_logging,
)


class TestUtils(unittest.TestCase):
@@ -46,20 +51,62 @@ def test_hash_checkpoint(self):
torch.jit.save(scripted_model, f"{tmp}/squeezenet1_0.ckptc")
# download un-jitted model
torch.hub.download_url_to_file(
"https://download.pytorch.org/models/squeezenet1_0-a815701f.pth",
"https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
f"{tmp}/squeezenet1_0.ckpt",
)

checkpoint = f"{tmp}/squeezenet1_0.ckpt"
filename, sha_hash = hash_checkpoint(checkpoint, False, tmp)
model.load_state_dict(torch.load(f"{tmp}/{filename}"), True)
self.assertEqual(sha_hash[:8], "a815701f")
self.assertEqual(filename, f"squeezenet1_0-{sha_hash[:8]}.pt")
hashed_fp, sha_hash = hash_checkpoint(checkpoint, False, tmp)
model.load_state_dict(torch.load(hashed_fp), True)
self.assertEqual(sha_hash[:8], "b66bff10")
self.assertEqual(hashed_fp.name, f"squeezenet1_0-{sha_hash[:8]}.pt")

checkpoint = f"{tmp}/squeezenet1_0.ckptc"
filename, sha_hash = hash_checkpoint(checkpoint, True, tmp)
scripted_model = torch.jit.load(f"{tmp}/{filename}")
self.assertEqual(filename, f"squeezenet1_0-{sha_hash[:8]}.ptc")
hashed_fp, sha_hash = hash_checkpoint(checkpoint, True, tmp)
scripted_model = torch.jit.load(hashed_fp)
self.assertEqual(hashed_fp.name, f"squeezenet1_0-{sha_hash[:8]}.ptc")

def test_resume_from_url(self):
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
with TemporaryDirectory() as tmp:
checkpoint_fp = "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth"
model = torch.hub.load("pytorch/vision", "squeezenet1_0")
to_load = {"model": model}
with self.assertLogs() as log:
resume_from(to_load, checkpoint_fp, logger, model_dir=tmp)
self.assertRegex(log.output[0], r"Successfully resumed from a checkpoint", "checkpoint failed to load")

def test_resume_from_fp(self):
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
with TemporaryDirectory() as tmp:
torch.hub.download_url_to_file(
"https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
f"{tmp}/squeezenet1_0.pt",
)
checkpoint_fp = f"{tmp}/squeezenet1_0.pt"
model = torch.hub.load("pytorch/vision", "squeezenet1_0")
to_load = {"model": model}
with self.assertLogs() as log:
resume_from(to_load, checkpoint_fp, logger)
self.assertRegex(log.output[0], r"Successfully resumed from a checkpoint", "checkpoint failed to load")

with TemporaryDirectory() as tmp:
torch.hub.download_url_to_file(
"https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
f"{tmp}/squeezenet1_0.pt",
)
checkpoint_fp = Path(f"{tmp}/squeezenet1_0.pt")
model = torch.hub.load("pytorch/vision", "squeezenet1_0")
to_load = {"model": model}
with self.assertLogs() as log:
resume_from(to_load, checkpoint_fp, logger)
self.assertRegex(log.output[0], r"Successfully resumed from a checkpoint", "checkpoint failed to load")

def test_resume_from_error(self):
with self.assertRaisesRegex(FileNotFoundError, r"Given \w+ does not exist"):
resume_from({}, "abcdef/", None)


if __name__ == "__main__":