Feature request
Is there a tutorial for using DeepSpeed's activation checkpointing instead of PyTorch's?
I'm using Trainer with ZeRO integration to train my model. Here's my code:
    if training_args.deepspeed_gradient_checkpointing and training_args.deepspeed:
        from deepspeed.runtime.activation_checkpointing.checkpointing import configure
        configure(mpu_=None)
        from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
        model._set_gradient_checkpointing(training_args.deepspeed_gradient_checkpointing, checkpoint)

    {
        "activation_checkpointing": {
            "partition_activations": true,
            "cpu_checkpointing": true,
            "contiguous_memory_optimization": false,
            "number_checkpoints": null,
            "synchronize_checkpoint_boundary": false,
            "profile": false
        }
    }

    torchrun --nproc_per_node=8 \
        --nnodes=${NNODES} \
        --node_rank=${NODE_RANK} \
        --master_addr=${MASTER_ADDR} \
        --master_port=${MASTER_PORT} \
        train.py \
        --deepspeed ${DEEPSPEED_CONFIG_PATH} \
        --gradient_checkpointing False

However, I got this in FlashAttention2:
    class XXXFlashAttention2(XXXAttention):
        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            output_attentions = False
            bsz, q_len, _ = hidden_states.size()  # <---- this got error
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

      File "modeling_xxx.py", line 518, in forward
        bsz, q_len, _ = hidden_states.size()
    ValueError: not enough values to unpack (expected 3, got 2)

Motivation
There doesn't seem to be such a tutorial at the moment, in either the DeepSpeed tutorials or the Hugging Face documentation.
Your contribution
I can provide my results.
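To help narrow down the error above, here is a minimal sketch of a shape guard that could temporarily replace the bare unpacking in `forward`. The helper name `unpack_hidden_shape` is hypothetical and not part of Transformers or DeepSpeed. It does not explain why `hidden_states` arrives with two dimensions; it only reports the actual shape so that a run with DeepSpeed's `checkpoint` can be compared against a run with PyTorch's.

```python
from typing import Sequence, Tuple


def unpack_hidden_shape(shape: Sequence[int]) -> Tuple[int, int, int]:
    """Return (bsz, q_len, hidden_size), or raise an error that names the real shape.

    Hypothetical debugging helper. `shape` can be a plain tuple or the
    torch.Size returned by `hidden_states.size()` (torch.Size is a tuple subclass).
    """
    dims = tuple(int(d) for d in shape)
    if len(dims) != 3:
        raise ValueError(
            "hidden_states should be 3-D (batch, seq_len, hidden_size), "
            f"got {len(dims)}-D with shape {dims}"
        )
    return dims


# A 3-D shape unpacks as usual.
bsz, q_len, hidden_size = unpack_hidden_shape((2, 16, 64))

# A 2-D shape, as in the traceback above, now reports what actually arrived.
try:
    unpack_hidden_shape((32, 64))
except ValueError as err:
    message = str(err)
```

Inside the attention module this would be used as `bsz, q_len, _ = unpack_hidden_shape(hidden_states.size())`.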