
Conversation

Contributor

@znado znado commented May 8, 2022

  • refactoring dataset iterators to return Dict[str, Any] instead of a tuple/triplet of inputs/labels/masks
  • cleaning up some of the WMT code
  • standardizing param shape/type utilities
  • small cleanups/fixes to get all reference submissions to work with random inputs for 1 train and 2 eval steps (takes ~272s total to test)
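The first bullet's dict-style batches might look like this minimal sketch (the key names `inputs`/`targets`/`weights` are illustrative assumptions, not necessarily the ones used in the PR):

```python
from typing import Any, Dict, Iterator


def as_dict_batches(inputs, labels, masks) -> Iterator[Dict[str, Any]]:
    # Wrap the old (inputs, labels, masks) triplet convention into
    # dict batches; keys are hypothetical placeholders.
    for x, y, m in zip(inputs, labels, masks):
        yield {'inputs': x, 'targets': y, 'weights': m}


batches = list(as_dict_batches([1, 2], [3, 4], [1.0, 0.0]))
```

A dict batch lets workloads add or drop fields (e.g. masks) without every call site having to unpack a fixed-arity tuple.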
@znado znado requested a review from fsschneider May 8, 2022 21:06
@github-actions

github-actions bot commented May 8, 2022

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

Contributor

@fsschneider fsschneider left a comment


Looks really awesome!

I added a few comments and nits. I think that after merging #66 and addressing these comments, this is ready to merge.

# flax.linen.Embed names the embedding parameter "embedding"
# https://github.com/google/flax/blob/main/flax/linen/linear.py#L604.
elif name == 'embedding':
  param_types_dict[name] = spec.ParameterType.EMBEDDING
Contributor


Is it problematic that JAX gets more information than PyTorch?

Contributor Author


Yes. Is the way I added it the right way to check the name of an nn.Embedding layer in PyTorch? We should add a test to make sure these are properly set for both JAX and PyTorch (at least for some selection of params). I added a TODO to my list of things to follow up on.


-if is_train:
-  dataloader = cycle(dataloader)
+dataloader = cycle(dataloader)
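One common implementation of such a `cycle` helper (a sketch, not necessarily the PR's version) restarts the iterable on exhaustion rather than caching elements, which is why it also makes sense for finite eval loaders that get drawn from across multiple evals:

```python
import itertools


def cycle(iterable):
    # Unlike itertools.cycle, this restarts the underlying iterable on
    # each pass instead of replaying cached elements, so a shuffling
    # dataloader yields a fresh order every epoch.
    while True:
        yielded = False
        for batch in iterable:
            yielded = True
            yield batch
        if not yielded:
            return  # avoid spinning forever on an empty iterable


first_five = list(itertools.islice(cycle([1, 2, 3]), 5))  # [1, 2, 3, 1, 2]
```

Note that `itertools.cycle` itself would also work for deterministic eval data, but it stores every element it has seen, which can be costly for large batches.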
Contributor


Just curious why you switched to also cycle the evaluation datasets?

Contributor Author


I added it so that we can iterate through them across multiple evals. I'm not sure if this is the best way to do it.

@znado znado merged commit cf14d26 into mlcommons:main May 17, 2022
@github-actions github-actions bot locked and limited conversation to collaborators May 17, 2022
