Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-23 19:49:30 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

parent f8caea7762
commit 840b9f3266
@@ -211,7 +211,6 @@ class GPT2PipelineForwards:
                    encoder_attention_mask,
                )


        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
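For context: the guard visible in this hunk follows the usual Transformers pattern, where gradient checkpointing recomputes intermediate activations in the backward pass, so keeping a key/value cache during training is wasted work and is switched off with a one-time warning. A minimal sketch of that logic, with illustrative names rather than the exact ColossalAI code (logger.warning_once is a Transformers helper; plain stdlib logging is used here):

import logging

logger = logging.getLogger(__name__)

def resolve_use_cache(gradient_checkpointing: bool, training: bool, use_cache: bool) -> bool:
    # Gradient checkpointing re-runs the forward pass during backward,
    # so cached key/value states would be recomputed anyway.
    if gradient_checkpointing and training and use_cache:
        logger.warning(
            "`use_cache=True` is incompatible with gradient checkpointing. "
            "Setting `use_cache=False`."
        )
        use_cache = False
    return use_cache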
@@ -877,7 +876,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
            )
        else:
            attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
        attn_output = self.c_proj(attn_output)
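The tail of this hunk is the standard head merge after attention. Given the permute(0, 2, 1, 3) that follows, ColoAttention.attention evidently returns (batch, num_heads, seq_len, head_dim); the two lines transpose to (batch, seq_len, num_heads, head_dim) and then flatten the last two dimensions into the hidden size before the c_proj output projection. A self-contained illustration (the concrete sizes are arbitrary):

import torch

batch, num_heads, seq_len, head_dim = 2, 12, 16, 64
attn_output = torch.randn(batch, num_heads, seq_len, head_dim)

# (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim)
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
# Merge the heads: (batch, seq_len, num_heads * head_dim)
attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
assert attn_output.shape == (batch, seq_len, num_heads * head_dim)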
@@ -40,7 +40,6 @@ class GPT2Policy(Policy):

        policy = {}


        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
            embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -48,8 +47,6 @@ class GPT2Policy(Policy):
        if self.tie_weight:
            embedding_cls = col_nn.PaddingEmbedding



        print("embedding_cls", embedding_cls)

        if self.shard_config.enable_fused_normalization:
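These two policy hunks pick which embedding replacement the shardformer will use: VocabParallelEmbedding1D when tensor parallelism shards the vocabulary across ranks, and PaddingEmbedding when weights are tied and the vocabulary only needs padding to a divisible size. The print("embedding_cls", embedding_cls) line is a leftover debug statement. A sketch of the selection, assuming the tie_weight branch is the fallback when tensor parallelism is off (the exact nesting is not visible in the diff, and the import alias is the one the policy files conventionally use):

import colossalai.shardformer.layer as col_nn

def pick_embedding_cls(enable_tensor_parallelism: bool, tie_weight: bool):
    if enable_tensor_parallelism:
        # Shard the vocabulary dimension of the embedding across TP ranks.
        return col_nn.VocabParallelEmbedding1D
    if tie_weight:
        # Tied input/output embeddings: only pad the vocab to a divisible size.
        return col_nn.PaddingEmbedding
    return None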