Continuing cleanup. #76
Conversation
znado commented May 8, 2022
- refactoring dataset iterators to return Dict[str, Any] instead of a tuple/triplet of inputs/labels/masks
- cleaning up some of the WMT code
- standardizing param shape/type utilities
- small cleanups/fixes to get all reference submissions to work with random inputs for 1 train and 2 eval steps (takes ~272s total to test)
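The dict-style batch format described in the first bullet can be sketched as follows. This is a hypothetical illustration, not the PR's actual code; the key names (`inputs`, `targets`, `weights`) are assumptions and may differ from what the workloads use.

```python
from typing import Any, Dict, Iterator, Tuple


def as_dict_batches(
    raw_iter: Iterator[Tuple[Any, Any, Any]]) -> Iterator[Dict[str, Any]]:
  """Wrap an iterator of (inputs, labels, masks) tuples into dict batches."""
  for inputs, labels, masks in raw_iter:
    yield {'inputs': inputs, 'targets': labels, 'weights': masks}


batches = as_dict_batches(iter([([1, 2], [0, 1], [1, 1])]))
print(sorted(next(batches)))  # ['inputs', 'targets', 'weights']
```

Returning a `Dict[str, Any]` instead of a positional tuple lets workloads add or omit fields (e.g. masks) without breaking every call site.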
MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅
Looks really awesome!
I added a few comments and nits. I think that after merging #66 and addressing these comments, this is ready to merge.
# flax.linen.Embed names the embedding parameter "embedding"
# https://github.com/google/flax/blob/main/flax/linen/linear.py#L604.
elif name == 'embedding':
  param_types_dict[name] = spec.ParameterType.EMBEDDING
Is it problematic that JAX gets more information than PyTorch?
Yes. Is the way I added it the right way to check the name of an nn.Embedding layer in PyTorch? We should add a test to make sure these are properly set for both JAX and PyTorch (at least for some selection of params). I added a TODO to my list of things to follow up on.
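For reference, a pure name check like the JAX snippet above would not carry over directly: flax.linen.Embed names its table `embedding`, but torch.nn.Embedding names its table `weight`. A minimal sketch of the name-based classification (with an illustrative `ParameterType` enum, not the repo's `spec` module):

```python
import enum


class ParameterType(enum.Enum):
  WEIGHT = 0
  EMBEDDING = 1


def classify_param(name: str) -> ParameterType:
  # flax.linen.Embed names its table "embedding", so matching on the name
  # works on the JAX side; torch.nn.Embedding names its table "weight", so
  # the same check would silently miss it on the PyTorch side, where
  # matching on the module type (isinstance(m, nn.Embedding)) is safer.
  if name == 'embedding':
    return ParameterType.EMBEDDING
  return ParameterType.WEIGHT


print(classify_param('embedding').name)  # EMBEDDING
print(classify_param('weight').name)     # WEIGHT
```

This is exactly the gap a cross-framework test would catch.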
if is_train:
  dataloader = cycle(dataloader)
dataloader = cycle(dataloader)
Just curious: why did you switch to also cycling the evaluation datasets?
I added it so that we can iterate through them across multiple evals. I'm not sure if this is the best way to do it.
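A minimal sketch of a `cycle()` helper that supports this, assuming the repo's version behaves roughly like this (the actual implementation may differ, e.g. by reshuffling between passes):

```python
def cycle(make_iterable):
  """Endlessly re-create and drain an iterable.

  Unlike itertools.cycle, this does not cache elements after the first
  pass, which matters for large datasets; the trade-off is that the
  argument must be a factory that can be called repeatedly.
  """
  while True:
    for batch in make_iterable():
      yield batch


it = cycle(lambda: [1, 2, 3])
print([next(it) for _ in range(5)])  # [1, 2, 3, 1, 2]
```

Cycling the eval iterator means a fixed number of `next()` calls per eval step keeps working across multiple evaluations, instead of raising StopIteration once the underlying dataset is exhausted.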