Safetensors #27
Conversation
Pull Request Overview
This PR integrates the `safetensors` library as the default serialization mechanism for the `CkptMixin` class, replacing `torch.save` and `torch.load`. The change maintains backward compatibility while enabling more secure model serialization.

- Changes the default behavior of the `save` method to use the `safetensors` format
- Adds comprehensive support for both single-file and directory-based safetensors formats
- Implements fallback mechanisms for loading legacy PyTorch checkpoint formats
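To make the fallback idea concrete, here is a stdlib-only sketch of how such a dispatch could be structured. It is illustrative, not the actual `torch_ecg` code: the function name `resolve_checkpoint_format` and the returned labels are hypothetical.

```python
from pathlib import Path

def resolve_checkpoint_format(path: str) -> str:
    """Classify a checkpoint path so the caller can pick the right loader.

    Hypothetical helper: legacy torch.save files keep their pth/pt suffixes,
    single-file safetensors uses .safetensors, anything else is treated as
    the directory-based safetensors layout.
    """
    p = Path(path)
    if str(p).endswith(".pth.tar") or p.suffix in {".pth", ".pt"}:
        return "legacy-pth"        # old torch.save checkpoints
    if p.suffix == ".safetensors":
        return "safetensors-file"  # single-file safetensors layout
    return "safetensors-dir"       # directory-based safetensors layout
```

A caller would use the returned label to decide between `safetensors` loading and a `torch.load` fallback.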
Reviewed Changes
Copilot reviewed 31 out of 32 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| torch_ecg/utils/utils_nn.py | Major update to `CkptMixin` with safetensors integration and type annotations |
| torch_ecg/utils/misc.py | Minor type annotation fix for the `make_serializable` function |
| torch_ecg/utils/download.py | Bytes handling fix for path processing |
| torch_ecg/models/grad_cam.py | Tensor operation ordering fix (`detach` before `cpu`) |
| torch_ecg/models/ecg_crnn.py | Type annotation additions and import statement |
| torch_ecg/models/_nets.py | Type annotation improvements across various classes |
| torch_ecg/components/trainer.py | Updated `save` method calls and type annotations |
| pyproject.toml | Addition of the `safetensors` dependency |
Comments suppressed due to low confidence (1)
torch_ecg/utils/utils_nn.py:1
- This else clause sets output_shape to [None] but this appears to be a fallback case that may not be reachable given the preceding conditions. Consider adding a comment explaining when this case would occur or remove if it's truly unreachable.
"""
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
"transposeconvolution", | ||
]: | ||
out_channels = num_filters | ||
else: |
The else clause raises an exception after all layer type checks. However, this code appears to be unreachable since all valid layer types should be handled by the preceding conditions. Consider removing this else clause or adding a comment explaining when this condition could be reached.
Suggested change:

```python
else:
    # This branch should be unreachable if all valid layer types are handled above.
    # Retained as a safeguard in case an unknown or misspelled layer type is passed.
```
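As a minimal, self-contained sketch of the pattern this suggestion describes (the function name `out_channels_for` and the layer-type strings are illustrative, not the `torch_ecg` implementation), the final `else` is kept as a loud guard rather than a silent fall-through:

```python
def out_channels_for(layer_type: str, num_filters: int) -> int:
    """Resolve out_channels for a layer type, failing loudly on unknown types."""
    normalized = layer_type.replace("_", "").lower()
    if normalized in {"conv", "convolution"}:
        return num_filters
    elif normalized in {"transposeconv", "transposeconvolution"}:
        return num_filters
    else:
        # Should be unreachable for valid configs; kept as a safeguard
        # against unknown or misspelled layer types.
        raise ValueError(f"unknown layer type: {layer_type!r}")
```

The design choice is the one the review comment hints at: an explicit `ValueError` makes a misspelled config fail at the point of the mistake instead of propagating a wrong channel count downstream.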
Copilot uses AI. Check for mistakes.
```python
save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
save_path = self.train_config.checkpoints / save_filename
if self.train_config.keep_checkpoint_max != 0:
# save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
```
The commented out line should be removed since the code has been updated to use folder-based saving instead of file-based saving. Keeping commented code can cause confusion about the intended behavior.
```python
# save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
```
```python
# save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar"
save_folder = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}"
save_path = self.train_config.model_dir / save_folder  # type: ignore
# self.save_checkpoint(path=str(save_path))
self._model.save(path=str(save_path), train_config=self.train_config)
self.log_manager.log_message(f"best model is saved at {save_path}")
elif self.train_config.monitor is None:
self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model")
self.log_manager.log_message(f"best model is saved at {save_path}")  # type: ignore
elif self.train_config.monitor is None:  # type: ignore
self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model")  # type: ignore
self.best_state_dict = self._model.state_dict()
save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
save_path = self.train_config.model_dir / save_filename
# save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
save_folder = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}"
save_path = self.train_config.model_dir / save_folder  # type: ignore
# self.save_checkpoint(path=str(save_path))
```
Similar to the previous comment, this commented out line should be removed to avoid confusion about the current implementation approach.
Another commented out line that should be removed to maintain clean code without obsolete references to the old file-based approach.
```python
if not str(path).endswith(".pth.tar"):
    path = Path(path).with_suffix(".pth.tar")  # type: ignore
```
The type ignore comment suggests potential type mismatch issues. Consider using proper type handling instead of ignoring type checking, especially since Path.with_suffix() should work correctly with string inputs.
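One way to implement the suggested "proper type handling" (a sketch only; the helper name `normalize_ckpt_path` is hypothetical, not part of `torch_ecg`) is to normalize to `Path` once, in a fresh variable, so no `str`-typed name is reassigned and the `type: ignore` becomes unnecessary:

```python
from pathlib import Path
from typing import Union

def normalize_ckpt_path(path: Union[str, Path]) -> Path:
    """Return `path` as a Path whose name ends in .pth.tar.

    Converting up front avoids reassigning a str-typed variable to a Path,
    which is what forced the `# type: ignore` in the original snippet.
    """
    p = Path(path)
    if not p.name.endswith(".pth.tar"):
        p = p.with_suffix(".pth.tar")
    return p
```

Note that `Path.with_suffix(".pth.tar")` replaces only the last suffix, so `model.pth` becomes `model.pth.tar` while `model.pth.tar` passes through unchanged.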
- Add a check to ensure the model has at least one module before the loop (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
- Update the error message for failing to load a safetensors file (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
Codecov Report

❌ Patch coverage is …

```
@@            Coverage Diff             @@
##           master      #27      +/-   ##
==========================================
+ Coverage   92.86%   93.21%   +0.34%
==========================================
  Files         138      137       -1
  Lines       19020    18325     -695
==========================================
- Hits        17663    17081     -582
+ Misses       1357     1244     -113
```

☔ View full report in Codecov by Sentry.
This PR changes the default behavior of the `save` method of the `CkptMixin` class. It now uses the `save_file` function from `safetensors` by default, instead of `torch.save` (see the comparison of the model saving mechanisms). The `save` method now has an updated signature. This change is backward compatible: one can still save models in `pth`/`pt` format as before by explicitly setting `use_safetensors=False`, and the `load` method is able to load `pth`/`pt` format models correctly.
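To illustrate why safetensors is considered more secure than `torch.save`: `torch.save` serializes via pickle, and unpickling can execute arbitrary code, whereas a safetensors file is pure data (raw tensor bytes plus a JSON header that is only parsed). The stdlib-only sketch below uses `json` as a stand-in for the real safetensors header format; the `Evil` class is a deliberately contrived payload for demonstration, not anything from this repository.

```python
import json
import pickle

class Evil:
    """A contrived pickle payload: unpickling it runs arbitrary code."""
    def __reduce__(self):
        return (print, ("arbitrary code ran during unpickling!",))

# pickle happily executes the __reduce__ payload on load:
payload = pickle.dumps({"weights": [1.0], "oops": Evil()})
restored = pickle.loads(payload)  # prints the message: loading executed code

# a safetensors-style header, by contrast, is only ever *parsed*:
header = json.dumps({"weights": {"dtype": "F32", "shape": [1]}})
meta = json.loads(header)
```

Running this prints the injected message during `pickle.loads`, while the JSON round-trip can never trigger code execution; that asymmetry is the motivation for making safetensors the default.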