This commit is contained in:
flybird11111 2025-05-01 09:04:24 +08:00
parent 840b9f3266
commit 519c2d0eab
3 changed files with 6 additions and 7 deletions

View File

@@ -85,8 +85,6 @@ def _get_attention_mask(
attention_mask,
is_causal=True,
)
elif self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")

View File

@@ -47,7 +47,10 @@ class GPT2Policy(Policy):
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
@@ -220,7 +223,6 @@ class GPT2Policy(Policy):
if embedding_cls is not None:
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
print("embedding_cls", embedding_cls)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="wte",
@@ -391,7 +393,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
module_policy = super().module_policy()
module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
if self.shard_config.enable_tensor_parallelism:
print("self.shard_config.enable_tensor_parallelism", self.shard_config.enable_tensor_parallelism)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
@@ -428,7 +429,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
self.set_pipeline_forward(
model_cls=GPT2LMHeadModel,
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
# shard_config=self.shard_config,
shard_config=self.shard_config,
policy=module_policy,
)
return module_policy

View File

@@ -165,11 +165,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": True,
"use_lazy_init": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},