diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index bc99be4cc..84deafefe 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -94,9 +94,9 @@ class GPT2PipelineForwards: if hidden_states is None: raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") input_shape = hidden_states.size()[:-1] - batch_size = input_shape[0] device = hidden_states.device hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) + batch_size = hidden_states.shape[0] # GPT2Attention mask. if attention_mask is not None: