Update torchrec_intro_tutorial.py
isururanawaka authored Jul 10, 2025
commit 09db9b962db0bdbb67bc22a5a1d2a4012836a2bb
12 changes: 6 additions & 6 deletions intermediate_source/torchrec_intro_tutorial.py
@@ -200,7 +200,7 @@
# In order to train models with massive embedding tables, sharding these
# tables across GPUs is required, which then introduces a whole new set of
# problems and opportunities in parallelism and optimization. Luckily, we have
-# the TorchRec library that has encountered, consolidated, and addressed
+# the TorchRec library <https://docs.pytorch.org/torchrec/overview.html>`__ that has encountered, consolidated, and addressed
# many of these concerns. TorchRec serves as a **library that provides
# primitives for large scale distributed embeddings**.
#
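#
# As a rough sketch (the table name, sizes, and feature names below are
# illustrative only), the kind of unsharded ``EmbeddingBagCollection``
# that later gets sharded can be built like this:

import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig

# A single small table keyed by one feature; production models hold many
# large tables, which is what makes sharding across GPUs necessary.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        )
    ],
    device=torch.device("meta"),  # defer real allocation until sharding
)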
@@ -496,11 +496,11 @@
#
# * **The module sharder**: This class exposes a ``shard`` API
# that handles sharding a TorchRec Module, producing a sharded module.
-# * For ``EmbeddingBagCollection``, the sharder is `EmbeddingBagCollectionSharder <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__
+# * For ``EmbeddingBagCollection``, the sharder is `EmbeddingBagCollectionSharder `
# * **Sharded module**: This class is a sharded variant of a TorchRec module.
# It has the same input/output as the regular TorchRec module, but is much
# more optimized and works in a distributed environment.
-# * For ``EmbeddingBagCollection``, the sharded variant is `ShardedEmbeddingBagCollection <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__
+# * For ``EmbeddingBagCollection``, the sharded variant is `ShardedEmbeddingBagCollection`
#
# Every TorchRec module has an unsharded and sharded variant.
#
@@ -619,7 +619,7 @@
# Remember that TorchRec is a highly optimized library for distributed
# embeddings. A concept that TorchRec introduces to enable higher
# performance for training on GPU is a
-# `LazyAwaitable <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable>`__.
+# `LazyAwaitable `.
# You will see ``LazyAwaitable`` types as outputs of various sharded
# TorchRec modules. All a ``LazyAwaitable`` type does is delay calculating some
# result as long as possible, and it does it by acting like an async type.
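#
# A toy subclass, along the lines of what the tutorial builds below, makes the
# behavior concrete: the tensor is not produced until ``wait()`` is called,
# which mirrors how sharded modules defer collective results.

from typing import List

import torch
from torchrec.distributed.types import LazyAwaitable


class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        # Only runs once the result is actually needed.
        return torch.ones(self._size)


awaitable = ExampleAwaitable([3, 2])
result = awaitable.wait()  # triggers _wait_impl and returns the tensor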
@@ -693,7 +693,7 @@ def _wait_impl(self) -> torch.Tensor:
# order for distribution of gradients. ``input_dist``, ``lookup``, and
# ``output_dist`` all depend on the sharding scheme. Since we sharded in a
# table-wise fashion, these APIs are modules that are constructed by
-# `TwPooledEmbeddingSharding <https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding>`__.
+# `TwPooledEmbeddingSharding`.
#

sharded_ebc
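#
# As a sketch of a forward pass through the sharded module (the feature values
# below are illustrative, and ``sharded_ebc`` comes from the sharding step
# above): the call returns a ``LazyAwaitable``, and ``wait()`` yields the
# pooled result once ``input_dist``, ``lookup``, and ``output_dist`` have run.

import torch
from torchrec import KeyedJaggedTensor

# Three "product" ids spread over a batch of two samples.
kjt = KeyedJaggedTensor(
    keys=["product"],
    values=torch.tensor([101, 202, 303]),
    lengths=torch.tensor([2, 1]),
).to(torch.device("cuda"))

awaitable = sharded_ebc(kjt)  # kicks off input_dist -> lookup -> output_dist
pooled = awaitable.wait()     # KeyedTensor of pooled embeddings per feature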
@@ -742,7 +742,7 @@ def _wait_impl(self) -> torch.Tensor:
# ``EmbeddingBagCollection`` to generate a
# ``ShardedEmbeddingBagCollection`` module. This workflow is fine, but
# typically when implementing model parallel,
-# `DistributedModelParallel <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__
+# `DistributedModelParallel`
# (DMP) is used as the standard interface. When wrapping your model (in
# our case ``ebc``) with DMP, the following will occur (see the sketch below):
#
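#
# A minimal usage sketch of that DMP wrapping (assuming the process group from
# the earlier sketch is initialized and ``ebc`` is still unsharded; passing no
# plan or sharders leaves DMP to pick defaults):

import torch
from torchrec.distributed.model_parallel import DistributedModelParallel

model = DistributedModelParallel(
    module=ebc,
    device=torch.device("cuda"),
)
# DMP generated a plan and sharded the embedding tables on wrap.
print(model.plan)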