PyTorch T5 does not run on GPU #2472

@nreimers

Description

🐛 Bug

When I try to run T5 on the GPU with the latest transformers release (and also with the most recent git version), I get the following error:

    Traceback (most recent call last):
      File "T5_example.py", line 32, in <module>
        outputs = model(input_ids=input_ids)
      File "/home/reimers/anaconda3/envs/sbert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/reimers/sbert/transformers/src/transformers/modeling_t5.py", line 780, in forward
      File "/home/reimers/anaconda3/envs/sbert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/reimers/sbert/transformers/src/transformers/modeling_t5.py", line 616, in forward
        encoder_decoder_position_bias=encoder_decoder_position_bias,
      File "/home/reimers/anaconda3/envs/sbert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/reimers/sbert/transformers/src/transformers/modeling_t5.py", line 422, in forward
        self_attention_outputs = self.layer[0](
      File "/home/reimers/anaconda3/envs/sbert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/reimers/sbert/transformers/src/transformers/modeling_t5.py", line 373, in forward
        attention_output = self.SelfAttention(
      File "/home/reimers/anaconda3/envs/sbert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/reimers/sbert/transformers/src/transformers/modeling_t5.py", line 338, in forward
        raise ValueError("No position_bias provided and no weights to compute position_bias")
      File "/home/reimers/sbert/transformers/src/transformers/modeling_t5.py", line 289, in compute_bias
        values = self.relative_attention_bias(rp_bucket)
      File "/home/reimers/anaconda3/envs/sbert/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/reimers/anaconda3/envs/sbert/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 114, in forward
        self.norm_type, self.scale_grad_by_freq, self.sparse)
      File "/home/reimers/anaconda3/envs/sbert/lib/python3.7/site-packages/torch/nn/functional.py", line 1484, in embedding
        return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    RuntimeError: Expected object of device type cuda but got device type cpu for argument #3 'index' in call to _th_index_select

This is the example code to reproduce the problem:

    from transformers import T5Model, T5Tokenizer
    import torch

    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    model = T5Model.from_pretrained('t5-small')
    model = model.to('cuda')

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute"), device='cuda').unsqueeze(0)
    outputs = model(input_ids=input_ids)
    last_hidden_states = outputs[0]

The error originates in modeling_t5.py at lines 284-289:

    rp_bucket = self._relative_position_bucket(
        relative_position,  # shape (qlen, klen)
        bidirectional=not self.is_decoder,
        num_buckets=self.relative_attention_num_buckets,
    )
    values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)

rp_bucket is a tensor on the CPU, which causes the above error.

If I move rp_bucket to the GPU, the code runs correctly:

    rp_bucket = self._relative_position_bucket(
        relative_position,  # shape (qlen, klen)
        bidirectional=not self.is_decoder,
        num_buckets=self.relative_attention_num_buckets,
    )
    rp_bucket = rp_bucket.to('cuda')  # dirty quick fix
    values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)

I'm not sure why rp_bucket is on the CPU.
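
For what it's worth, a plausible explanation (my assumption from the shape comments above, not a verified trace through the source) is that compute_bias builds relative_position from torch.arange calls that omit the device argument, so everything derived from it, including rp_bucket, is allocated on the CPU, while model.to('cuda') has already moved the relative_attention_bias weights to the GPU. A minimal standalone sketch of that pattern:

    import torch

    qlen, klen = 7, 7  # hypothetical sequence lengths, just for illustration

    # torch.arange without a device argument allocates on the CPU by default
    context_position = torch.arange(qlen, dtype=torch.long)[:, None]  # shape (qlen, 1)
    memory_position = torch.arange(klen, dtype=torch.long)[None, :]   # shape (1, klen)
    relative_position = memory_position - context_position            # shape (qlen, klen)

    # Any bucket tensor computed from relative_position inherits the CPU device,
    # while an embedding moved to the GPU expects CUDA indices, hence the mismatch.
    print(relative_position.device)  # -> cpu

If that is indeed the cause, a less hardcoded variant of the quick fix would be to move rp_bucket to whatever device the bias embedding lives on, e.g. rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device), which would also keep CPU-only runs working (again, a sketch, not a tested patch).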
