Skip to content

Conversation

ha405
Copy link
Contributor

@ha405 ha405 commented Aug 22, 2025

This PR extends the validation metrics functionality (precision, recall, F1-score) to the train.py script.

Changes:

  • The validate function within train.py now supports the --metrics-avg flag.
  • Implemented torch.distributed.all_gather to correctly collect predictions and targets from all GPUs before calculating metrics on the primary process.
  • The feature remains a soft dependency on scikit-learn and is disabled by default.

This ensures that users can get these more detailed metrics during training, even in a multi-GPU environment.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
1 participant