Skip to content

Conversation

@lsy323
Copy link
Collaborator

@lsy323 lsy323 commented Mar 20, 2024

Support dynamism on aten.embedding and aten.split_with_sizes

  • The indices tensors to aten.embedding can have more than 1 dimension. During tracing, it will be flattened. To avoid the view op with symbolic output shape, we flatten the indices tensor at FX level and the view with symbolic output shape will be handled by aten.view -> xla.dynamic_view pass.

  • Add an FX pass to decompose aten.split_with_sizes into slice ops to reuse the dynamism added to slice.

Test:

  • New tests added to test the FX passes and the lowered StableHLO
@lsy323 lsy323 marked this pull request as ready for review March 20, 2024 17:15
@lsy323 lsy323 requested a review from qihqi March 20, 2024 17:17
@lsy323 lsy323 added the dynamism Dynamic Shape Features label Mar 20, 2024
@lsy323 lsy323 merged commit 6b76c7f into pytorch:master Mar 20, 2024
@lsy323 lsy323 deleted the lsiyuan/ebd-split branch March 20, 2024 23:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

dynamism Dynamic Shape Features

2 participants