diff --git a/model_zoo/gpt/gpt.py b/model_zoo/gpt/gpt.py index dadbc152b..d7b5750fb 100644 --- a/model_zoo/gpt/gpt.py +++ b/model_zoo/gpt/gpt.py @@ -292,7 +292,7 @@ class GPT(nn.Module): # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # Adapted from huggingface if attention_mask is not None: - batch_size = x.shape[0] + batch_size = input_ids.shape[0] attention_mask = attention_mask.view(batch_size, -1) attention_mask = col_nn.partition_batch(attention_mask) attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)