
Conversation

wenh06 (Collaborator) commented Sep 25, 2025

This PR changes the default behavior of the save method of the CkptMixin class: it now uses the save_file function from safetensors instead of torch.save by default. See the comparison of the model saving mechanisms. The save method now has the following signature:

    def save(
        self,
        path: Union[str, bytes, os.PathLike],
        train_config: CFG,
        extra_items: Optional[dict] = None,
        use_safetensors: bool = True,
        safetensors_single_file: bool = True,
    ) -> None:
        """Save the model to disk.

        .. note::

            `safetensors` is used by default to save the model.
            If one wants to save the models in `.pth` or `.pt` format,
            he/she must explicitly set ``use_safetensors=False``.

        Parameters
        ----------
        path : `path-like`
            Path to save the model.
        train_config : CFG
            Config for training the model,
            used when one restores the model.
        extra_items : dict, optional
            Extra items to save along with the model.
            The values should be serializable: can be saved as a json file,
            or is a dict of torch tensors.

            .. versionadded:: 0.0.32
        use_safetensors : bool, default True
            Whether to use `safetensors` to save the model.
            This will be overridden by the suffix of `path`:
            if it is `.safetensors`, then `use_safetensors` is set to True;
            if it is `.pth` or `.pt`, then if `use_safetensors` is True,
            the suffix is changed to `.safetensors`, otherwise it is unchanged.

            .. versionadded:: 0.0.32
        safetensors_single_file : bool, default True
            Whether to save the metadata along with the state dict into one file.

            .. versionadded:: 0.0.32

        Returns
        -------
        None

        """
        ...
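The suffix-override rules documented above can be read as the following plain-Python sketch (resolve_save_path is a hypothetical helper for illustration, not part of the library):

```python
from pathlib import Path


def resolve_save_path(path: str, use_safetensors: bool = True) -> tuple:
    """Illustrate the documented suffix-override rules for ``save``.

    Returns the final path and the effective value of ``use_safetensors``.
    """
    p = Path(path)
    if p.suffix == ".safetensors":
        # a .safetensors suffix forces safetensors serialization
        return p, True
    if p.suffix in (".pth", ".pt"):
        if use_safetensors:
            # suffix is rewritten when safetensors is requested
            return p.with_suffix(".safetensors"), True
        # explicit use_safetensors=False keeps the legacy suffix
        return p, False
    return p, use_safetensors
```

So a `.pth` path with the default `use_safetensors=True` silently becomes a `.safetensors` path, which matches the docstring's override rules.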

This change is backward compatible: models can still be saved in .pth/.pt format as before by explicitly setting use_safetensors=False, and the load method correctly loads models in .pth/.pt format.

Copilot AI review requested due to automatic review settings, September 25, 2025 14:09

Copilot AI left a comment


Pull Request Overview

This PR integrates the safetensors library as the default serialization mechanism for the CkptMixin class, replacing torch.save and torch.load. The change maintains backward compatibility while enabling more secure model serialization.

  • Changes default behavior of the save method to use safetensors format
  • Adds comprehensive support for both single-file and directory-based safetensors formats
  • Implements fallback mechanisms for loading legacy PyTorch checkpoint formats
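The fallback mechanism for loading legacy checkpoints can be sketched generically; load_with_fallback and the toy loaders below are illustrative stand-ins, not the library's actual API:

```python
def load_with_fallback(path, loaders):
    """Try each (name, loader) pair in order, returning the first success.

    ``loaders`` is a list of (name, callable) pairs, e.g. a safetensors
    loader followed by a legacy torch.load-style loader.
    """
    errors = {}
    for name, loader in loaders:
        try:
            return name, loader(path)
        except Exception as exc:  # record the failure and try the next format
            errors[name] = exc
    raise RuntimeError(f"no loader could read {path!r}: {errors}")


# toy loaders: pretend .safetensors files parse only with the first loader
def fake_safetensors_load(path):
    if not str(path).endswith(".safetensors"):
        raise ValueError("not a safetensors file")
    return {"weights": [1.0, 2.0]}


def fake_torch_load(path):
    return {"state_dict": {}, "train_config": {}}
```

With this shape, a `.safetensors` checkpoint is read by the first loader, while a legacy `.pth`/`.pt` file falls through to the second, mirroring the behavior this PR describes.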

Reviewed Changes

Copilot reviewed 31 out of 32 changed files in this pull request and generated 7 comments.

File | Description
torch_ecg/utils/utils_nn.py | Major update to CkptMixin with safetensors integration and type annotations
torch_ecg/utils/misc.py | Minor type annotation fix for make_serializable function
torch_ecg/utils/download.py | Bytes handling fix for path processing
torch_ecg/models/grad_cam.py | Tensor operation ordering fix (detach before cpu)
torch_ecg/models/ecg_crnn.py | Type annotation additions and import statement
torch_ecg/models/_nets.py | Type annotation improvements across various classes
torch_ecg/components/trainer.py | Updated save method calls and type annotations
pyproject.toml | Addition of safetensors dependency
Comments suppressed due to low confidence (1)

torch_ecg/utils/utils_nn.py:1

  • This else clause sets output_shape to [None], but it appears to be a fallback case that may not be reachable given the preceding conditions. Consider adding a comment explaining when this case would occur, or removing it if it is truly unreachable.


"transposeconvolution",
]:
out_channels = num_filters
else:

Copilot AI Sep 25, 2025


The else clause raises an exception after all layer type checks. However, this code appears to be unreachable since all valid layer types should be handled by the preceding conditions. Consider removing this else clause or adding a comment explaining when this condition could be reached.

Suggested change
else:
else:
# This branch should be unreachable if all valid layer types are handled above.
# Retained as a safeguard in case an unknown or misspelled layer type is passed.
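The pattern the suggestion describes, an exhaustive dispatch with a defensive final branch, can be sketched generically (illustrative code, not torch_ecg's actual implementation; the layer-type names are examples):

```python
def resolve_out_channels(layer_type: str, num_filters: int) -> int:
    """Dispatch on layer type, failing loudly on unknown input."""
    if layer_type in ("convolution", "conv"):
        return num_filters
    if layer_type in ("transposeconvolution", "deconv"):
        return num_filters
    # should be unreachable for valid layer types; kept as a safeguard
    # against unknown or misspelled layer types
    raise ValueError(f"unknown layer type: {layer_type!r}")
```

Raising in the "unreachable" branch turns a silent misconfiguration (e.g. a typo in a config file) into an immediate, descriptive error.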


save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
save_path = self.train_config.checkpoints / save_filename
if self.train_config.keep_checkpoint_max != 0:
# save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"

Copilot AI Sep 25, 2025


The commented-out line should be removed, since the code has been updated to use folder-based saving instead of file-based saving. Keeping commented-out code can cause confusion about the intended behavior.

Suggested change
# save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"


Comment on lines +282 to 294
# save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar"
save_folder = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}"
save_path = self.train_config.model_dir / save_folder # type: ignore
# self.save_checkpoint(path=str(save_path))
self._model.save(path=str(save_path), train_config=self.train_config)
self.log_manager.log_message(f"best model is saved at {save_path}")
elif self.train_config.monitor is None:
self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model")
self.log_manager.log_message(f"best model is saved at {save_path}") # type: ignore
elif self.train_config.monitor is None: # type: ignore
self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model") # type: ignore
self.best_state_dict = self._model.state_dict()
save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
save_path = self.train_config.model_dir / save_filename
# save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
save_folder = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}"
save_path = self.train_config.model_dir / save_folder # type: ignore
# self.save_checkpoint(path=str(save_path))

Copilot AI Sep 25, 2025


Similar to the previous comment, this commented-out line should be removed to avoid confusion about the current implementation approach.


Comment on lines +282 to 294 (same lines as in the previous comment)

Copilot AI Sep 25, 2025


Another commented-out line that should be removed to keep the code clean, without obsolete references to the old file-based approach.


Comment on lines +781 to +782
if not str(path).endswith(".pth.tar"):
path = Path(path).with_suffix(".pth.tar") # type: ignore

Copilot AI Sep 25, 2025


The type ignore comment suggests potential type mismatch issues. Consider using proper type handling instead of ignoring type checking, especially since Path.with_suffix() should work correctly with string inputs.
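One way to avoid the type-ignore, along the lines the comment suggests, is to normalize the input to a Path once up front (normalize_ckpt_path is an illustrative helper, not part of the codebase):

```python
import os
from pathlib import Path
from typing import Union


def normalize_ckpt_path(path: Union[str, os.PathLike]) -> Path:
    """Return ``path`` as a Path ending in ``.pth.tar``, without a type-ignore."""
    p = Path(os.fspath(path))
    if not p.name.endswith(".pth.tar"):
        # with_suffix replaces only the last suffix, so check the full name first
        p = p.with_suffix(".pth.tar")
    return p
```

Converting to Path at the boundary keeps the rest of the function statically well-typed, instead of reassigning a Path into a variable annotated as str.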


wenh06 and others added 2 commits September 25, 2025 22:15

  • Add a check to ensure the model has at least one module before the loop (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
  • Update the error message for failing to load a safetensors file (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
@wenh06 wenh06 added the enhancement New feature or request label Sep 25, 2025
@wenh06 wenh06 self-assigned this Sep 25, 2025

codecov bot commented Sep 27, 2025

Codecov Report

❌ Patch coverage is 91.24088% with 72 lines in your changes missing coverage. Please review.
✅ Project coverage is 93.21%. Comparing base (3d8008e) to head (6575e98).
⚠️ Report is 9 commits behind head on master.
✅ All tests successful. No failed tests found.

Files with missing lines | Patch % | Lines
torch_ecg/components/trainer.py | 79.74% | 16 Missing ⚠️
torch_ecg/utils/misc.py | 88.49% | 13 Missing ⚠️
torch_ecg/utils/download.py | 87.95% | 10 Missing ⚠️
torch_ecg/utils/utils_nn.py | 92.06% | 10 Missing ⚠️
torch_ecg/models/ecg_crnn.py | 87.87% | 8 Missing ⚠️
torch_ecg/databases/base.py | 84.78% | 7 Missing ⚠️
torch_ecg/models/grad_cam.py | 0.00% | 3 Missing ⚠️
torch_ecg/databases/physionet_databases/ludb.py | 93.93% | 2 Missing ⚠️
torch_ecg/models/_nets.py | 96.07% | 2 Missing ⚠️
torch_ecg/utils/utils_data.py | 93.75% | 1 Missing ⚠️
Additional details and impacted files
    @@            Coverage Diff             @@
    ##           master      #27      +/-   ##
    ==========================================
    + Coverage   92.86%   93.21%   +0.34%
    ==========================================
      Files         138      137       -1
      Lines       19020    18325     -695
    ==========================================
    - Hits        17663    17081     -582
    + Misses       1357     1244     -113

☔ View full report in Codecov by Sentry.

@wenh06 wenh06 enabled auto-merge October 7, 2025 08:39