[upgrade] upgrade gpt2 (#6291)

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111
2025-05-08 14:10:21 +08:00
committed by GitHub
parent 8497ecc3e5
commit a4c6e189fa
3 changed files with 15 additions and 16 deletions


@@ -38,14 +38,8 @@ class GPT2Policy(Policy):
     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
-        ATTN_IMPLEMENTATION = {
-            "eager": GPT2Attention,
-        }
-
         policy = {}
 
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
-
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -280,7 +274,7 @@ class GPT2Policy(Policy):
                 "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
             },
             policy=policy,
-            target_key=attn_cls,
+            target_key=GPT2Attention,
         )
 
         if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
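The ATTN_IMPLEMENTATION dict and the attn_cls lookup removed above were a holdover from transformers releases that shipped a separate attention class per backend; newer releases route every backend ("eager", "sdpa", "flash_attention_2") through the single GPT2Attention class via config._attn_implementation, which is why target_key=GPT2Attention is now sufficient. A minimal sketch of that invariant, assuming a transformers version new enough for this PR (the check and the tiny config values are illustrative, not part of the PR):

# Illustrative check, not part of this PR: every GPT2 block holds a plain
# GPT2Attention module regardless of which attention backend the config
# selects, so patching by class (target_key=GPT2Attention) reaches all
# attention layers.
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model

config = GPT2Config(n_layer=2, n_embd=64, n_head=2)  # tiny, hypothetical sizes
model = GPT2Model(config)

# Each transformer block exposes its attention module as `block.attn`.
assert all(isinstance(block.attn, GPT2Attention) for block in model.h)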