Conversation

@JackCaoG JackCaoG commented Aug 29, 2024

Was able to reduce the tracing time of gmm from 6 ms to 2.4 ms.

[profiler screenshots attached in the original comment]

@JackCaoG JackCaoG requested a review from alanwaketan August 29, 2024 00:27
@JackCaoG JackCaoG added the tpuci label Aug 29, 2024
@JackCaoG (Collaborator, Author):

Still need to add a test for the cache-miss case.

global trace_pallas_arg_to_payload
# implicit assumption here that everything in kwargs is hashable and not a tensor,
# which is true for the gmm and tgmm.
hash_key = (kernel, static_argnums, tuple(static_argnames), tuple(jax_args),

Collaborator:
How does this work with different objects but with the same size, dtype and device?


@JackCaoG (Collaborator, Author):
jax_args are just meta tensors; I verified that the same size always maps to the same hash. We are not hashing id(static_argnames), so as long as the values are the same, it will generate the same hash.
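As a rough illustration of the point above (a minimal sketch; `make_key` and the argument values are hypothetical, not the PR's actual code): the key is a tuple built from values, so freshly constructed but equal arguments hash identically even though they are distinct objects.

```python
def make_key(kernel_name, static_argnums, static_argnames, jax_args):
    # Build the cache key from values only (no id()-based hashing),
    # mirroring the hash_key tuple in the snippet above.
    return (kernel_name, static_argnums, tuple(static_argnames), tuple(jax_args))

# Two separately constructed, value-equal sets of arguments;
# ("f32", (128, 256)) stands in for a JAX meta tensor's (dtype, shape).
key_a = make_key("gmm", (0,), ["tiling"], [("f32", (128, 256))])
key_b = make_key("gmm", (0,), ["tiling"], [("f32", (128, 256))])

assert key_a == key_b
assert hash(key_a) == hash(key_b)  # same value -> same hash, different objects or not
```

A different shape or kernel name changes the tuple's value and therefore the hash, which is what makes the key safe to use across distinct tensor objects with the same size, dtype, and device.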


Collaborator:
That's interesting. I guess if it works, it works. Then why not just use @cache?


@JackCaoG (Collaborator, Author):
My understanding is that @cache caches on the inputs, and the inputs of this function are XLA tensors; I felt like @cache would try to access the values of those tensors. Here I only cache on the JAX meta tensors.

Also, let me re-verify this with the real MoE models.
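A rough sketch of the dict-based caching pattern described above (all names here are illustrative, not the PR's actual code): instead of wrapping the entry point with functools.cache, which would key on the XLA tensor arguments themselves, the payload is looked up in a module-level dict keyed by cheap, hashable meta descriptors.

```python
# Module-level cache, keyed by a value-based tuple of meta descriptors
# rather than by the tensor arguments themselves.
trace_pallas_cache = {}

def trace_pallas(kernel_name, meta_args, static_argnames, **kwargs):
    # Implicit assumption (as in the PR): everything in kwargs is
    # hashable and not a tensor.
    hash_key = (kernel_name, tuple(meta_args), tuple(static_argnames),
                tuple(sorted(kwargs.items())))
    payload = trace_pallas_cache.get(hash_key)
    if payload is None:
        # Stand-in for the expensive JAX lowering/tracing step.
        payload = f"lowered:{kernel_name}"
        trace_pallas_cache[hash_key] = payload
    return payload

first = trace_pallas("gmm", [("f32", (8, 8))], ["tiling"], tiling=(8, 8))
second = trace_pallas("gmm", [("f32", (8, 8))], ["tiling"], tiling=(8, 8))
assert first is second            # second call is a cache hit
assert len(trace_pallas_cache) == 1
```

The tradeoff versus @cache is that the key construction is explicit, so nothing ever touches the tensors' values; only their shape/dtype metadata participates in the lookup.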


Collaborator:
I see. That's fair.

@JackCaoG (Collaborator, Author):
Verified in the profile that trace_pallas is cached.

@JackCaoG JackCaoG marked this pull request as ready for review August 30, 2024 00:55
@JackCaoG JackCaoG merged commit 8955571 into master Aug 30, 2024
@JackCaoG JackCaoG deleted the JackCaoG/trace_pallas_cache branch August 30, 2024 00:55