Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-28 13:05:26 +00:00)
[upgrade] upgrade gpt2 (#6291)
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* fix
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -38,14 +38,8 @@ class GPT2Policy(Policy):
     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
-        ATTN_IMPLEMENTATION = {
-            "eager": GPT2Attention,
-        }
-
         policy = {}
 
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
-
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
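This hunk drops the one-entry ATTN_IMPLEMENTATION dispatch table: recent transformers releases route eager/SDPA/flash attention through the single GPT2Attention class, so a per-implementation lookup in the policy is redundant. A minimal sketch of the before/after pattern; origin_attn_implement is a hypothetical stand-in for the policy attribute, and only GPT2Attention comes from the real transformers API:

    from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

    # Before the upgrade: resolve the attention class through a one-entry
    # dispatch table keyed by the model's attention implementation.
    ATTN_IMPLEMENTATION = {"eager": GPT2Attention}
    origin_attn_implement = "eager"  # hypothetical stand-in for self.origin_attn_implement
    attn_cls = ATTN_IMPLEMENTATION[origin_attn_implement]

    # After the upgrade: target the single GPT2Attention class directly,
    # with no intermediate lookup.
    assert attn_cls is GPT2Attention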
@@ -280,7 +274,7 @@ class GPT2Policy(Policy):
                     "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=GPT2Attention,
             )
 
         if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
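The second hunk is the matching consumer-side change: the method-replacement call that patches the attention forward now names GPT2Attention directly rather than the looked-up attn_cls. As a hedged sketch of what such a policy entry boils down to (a plain dict in place of ColossalAI's real descriptor objects; flash_attention_forward is a hypothetical stand-in for get_gpt2_flash_attention_forward):

    from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

    def flash_attention_forward(self, *args, **kwargs):
        raise NotImplementedError  # stand-in for the real sharded forward

    policy = {}
    # target_key picks the module class whose "forward" gets replaced;
    # after this commit it is the concrete GPT2Attention class.
    policy.setdefault(GPT2Attention, {})["forward"] = flash_attention_forward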