diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index bb493be3e..c83142deb 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -211,7 +211,6 @@ class GPT2PipelineForwards:
                     encoder_attention_mask,
                 )
-
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -877,7 +876,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
-
+
         attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
         attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
         attn_output = self.c_proj(attn_output)
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index d5d97fd2d..b6370d632 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -40,7 +40,6 @@ class GPT2Policy(Policy):
         policy = {}
-
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
            embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -48,8 +47,6 @@ class GPT2Policy(Policy):
             if self.tie_weight:
                 embedding_cls = col_nn.PaddingEmbedding
-
-        print("embedding_cls", embedding_cls)
         if self.shard_config.enable_fused_normalization:
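
For context on the attention path the second hunk touches: the lines kept around the whitespace change fold the per-head attention output back into the hidden dimension before the c_proj projection. Below is a minimal sketch of that permute/reshape step, using torch.nn.functional.scaled_dot_product_attention as a stand-in for ColoAttention.attention and assumed tensor shapes; it is not ColossalAI's actual kernel or configuration.

    # Sketch only: the shapes and the scaled_dot_product_attention stand-in are
    # assumptions, not the ColoAttention.attention call from the diff above.
    import torch
    import torch.nn.functional as F

    batch, num_heads, seq_len, head_dim = 2, 12, 16, 64
    query = torch.randn(batch, num_heads, seq_len, head_dim)
    key = torch.randn(batch, num_heads, seq_len, head_dim)
    value = torch.randn(batch, num_heads, seq_len, head_dim)

    # Stand-in attention kernel: output stays head-major, (B, H, S, D).
    attn_output = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)

    # (B, H, S, D) -> (B, S, H, D) -> (B, S, H * D), mirroring the modeling code
    # kept as context in the hunk, ready for the output projection (c_proj).
    attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
    attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
    assert attn_output.shape == (batch, seq_len, num_heads * head_dim)
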