commit f8caea7762 (parent fdcc3691fa)
Author: flybird11111
Date:   2025-05-01 09:00:13 +08:00

4 changed files with 61 additions and 77 deletions


@@ -39,7 +39,6 @@ def _get_attention_mask(
     attention_mask: Optional[torch.FloatTensor],
     encoder_hidden_states: Optional[torch.Tensor],
     encoder_attention_mask: Optional[torch.FloatTensor],
-    head_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
     # Received input is already split for non-first pipeline stages,
     # but attn mask isn't
@@ -49,9 +48,7 @@ def _get_attention_mask(
     sp_mode = shard_config.sequence_parallelism_mode
     # If a 2D or 3D attention mask is provided for the cross-attention
     # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-    _use_sdpa = self._attn_implementation == "sdpa"
-    print("_use_sdpa", _use_sdpa)
-    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
     if self.config.add_cross_attention and encoder_hidden_states is not None:
         assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
@@ -59,7 +56,7 @@ def _get_attention_mask(
             encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
                 (encoder_batch_size, 1, seq_len, encoder_sequence_length),
                 dtype=hidden_states.dtype,
-                dtype2=encoder_hidden_states.dtype,
+                device=encoder_hidden_states.device,
                 q_padding_mask=attention_mask,
                 kv_padding_mask=encoder_attention_mask,
             )
@@ -67,11 +64,6 @@ def _get_attention_mask(
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
-            if _use_sdpa:
-                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    mask=encoder_attention_mask, dtype=hidden_states.dtype, tgt_len=encoder_hidden_shape[-1]
-                )
-            elif not self._attn_implementation == "flash_attention_2":
             encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
     else:
         if shard_config.enable_flash_attention:
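
This hunk drops the SDPA-specific branch, so the non-flash fallback now always routes the encoder padding mask through self.invert_attention_mask. For orientation, a torch-only sketch of what that transformers helper produces, turning a [batch, kv_len] 0/1 padding mask into an additive bias broadcastable over heads and query positions (an illustration, not the library code):

import torch

def invert_attention_mask_sketch(mask: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    # mask: [batch, kv_len] with 1 = attend, 0 = padding.
    # Returns an additive bias of shape [batch, 1, 1, kv_len]:
    # 0 where attention is allowed, a large negative value where it is masked.
    extended = mask[:, None, None, :].to(dtype)
    return (1.0 - extended) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])
print(invert_attention_mask_sketch(mask).shape)  # torch.Size([1, 1, 1, 4])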
@@ -95,13 +87,6 @@ def _get_attention_mask(
             )
         elif self._attn_implementation == "flash_attention_2":
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif _use_sdpa:
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask=attention_mask,
-                input_shape=(batch_size, hidden_states[-1]),
-                inputs_embeds=hidden_states,
-                past_key_values_length=past_key_values_length,
-            )
         elif attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
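
The removed causal-mask branch above also looks buggy in passing: input_shape=(batch_size, hidden_states[-1]) indexes the hidden-states tensor rather than supplying a sequence length, so deleting it sidesteps that as well. For reference, a self-contained sketch of the causal additive mask such a branch is meant to build (plain torch, not the transformers helper; past_len stands in for past_key_values_length):

import torch

def causal_bias(seq_len: int, past_len: int = 0, dtype=torch.float16) -> torch.Tensor:
    # Additive mask of shape [1, 1, seq_len, past_len + seq_len]:
    # 0 where a query may attend (its own and earlier positions), a large negative value elsewhere.
    q = torch.arange(seq_len)[:, None]
    kv = torch.arange(past_len + seq_len)[None, :]
    allowed = kv <= (q + past_len)
    bias = torch.zeros(seq_len, past_len + seq_len, dtype=dtype)
    return bias.masked_fill(~allowed, torch.finfo(dtype).min)[None, None]

print(causal_bias(4, past_len=2).shape)  # torch.Size([1, 1, 4, 6])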
@@ -224,7 +209,6 @@ class GPT2PipelineForwards:
             attention_mask,
             encoder_hidden_states,
             encoder_attention_mask,
-            head_mask
         )


@@ -400,7 +400,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs={
-                        "gather_output": True,
+                        "gather_output": False,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
                     },
                 ),
@@ -418,20 +418,20 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
             target_key=GPT2LMHeadModel,
         )
-        # if self.shard_config.parallel_output:
-        #     self.append_or_create_method_replacement(
-        #         description={
-        #             "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
-        #         },
-        #         policy=module_policy,
-        #         target_key=GPT2LMHeadModel,
-        #     )
+        if self.shard_config.parallel_output:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
+                },
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
-                shard_config=self.shard_config,
+                # shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
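
With gather_output flipped to False, VocabParallelLMHead1D keeps the vocabulary dimension sharded across tensor-parallel ranks, which is why the parallel_output branch now swaps in GPT2PipelineForwards.gpt2_lmhead_model_forward: the loss must be computed against sharded logits instead of a gathered [batch, seq, vocab] tensor. A single-process sketch of the arithmetic a vocab-parallel cross-entropy has to reproduce (an illustration of the idea, not ColossalAI's implementation):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, V, tp = 2, 4, 8, 2
full_logits = torch.randn(B, S, V)
labels = torch.randint(0, V, (B, S))

shards = full_logits.chunk(tp, dim=-1)  # the logits slice each TP rank would hold locally
ref = F.cross_entropy(full_logits.view(-1, V), labels.view(-1))  # loss on gathered logits

# Pieces a vocab-parallel cross-entropy combines across ranks (all-reduced in practice):
global_max = torch.stack([s.amax(dim=-1) for s in shards]).amax(dim=0)           # [B, S]
global_sum = sum((s - global_max[..., None]).exp().sum(dim=-1) for s in shards)  # [B, S]
target_logit = full_logits.gather(-1, labels[..., None]).squeeze(-1)             # held by exactly one shard
parallel_ce = (global_max + global_sum.log() - target_logit).mean()

print(torch.allclose(ref, parallel_ce))  # True, up to floating-point tolerance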


@@ -127,43 +127,43 @@ model_zoo.register(
     loss_fn=loss_fn_for_gpt2_model,
     model_attribute=ModelAttribute(has_control_flow=True),
 )
-# model_zoo.register(
-#     name="transformers_gpt_lm",
-#     model_fn=lambda: transformers.GPT2LMHeadModel(config),
-#     data_gen_fn=data_gen_for_lm,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=loss_fn,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-#     name="transformers_gpt_double_heads",
-#     model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-#     data_gen_fn=date_gen_for_double_heads,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=lambda x: x.loss + x.mc_loss,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-#     name="transformers_gpt_for_question_answering",
-#     model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
-#     data_gen_fn=data_gen_for_question_answering,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=loss_fn,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-#     name="transformers_gpt_for_token_classification",
-#     model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
-#     data_gen_fn=data_gen_for_token_classification,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=loss_fn,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-#     name="transformers_gpt_for_sequence_classification",
-#     model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
-#     data_gen_fn=data_gen_for_sequence_classification,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=loss_fn,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
+model_zoo.register(
+    name="transformers_gpt_lm",
+    model_fn=lambda: transformers.GPT2LMHeadModel(config),
+    data_gen_fn=data_gen_for_lm,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_double_heads",
+    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
+    data_gen_fn=date_gen_for_double_heads,
+    output_transform_fn=output_transform_fn,
+    loss_fn=lambda x: x.loss + x.mc_loss,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_question_answering",
+    model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
+    data_gen_fn=data_gen_for_question_answering,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_token_classification",
+    model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
+    data_gen_fn=data_gen_for_token_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_sequence_classification",
+    model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
+    data_gen_fn=data_gen_for_sequence_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
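
Re-enabling these registrations puts the GPT-2 head variants back into the test model zoo. Each entry bundles a constructor, a data generator, an output transform, and a loss. A hedged sketch of how a consumer might exercise one entry end to end (the real harness lives in the shared test utilities; config, data_gen_for_lm, output_transform_fn, and loss_fn are names defined elsewhere in this file):

import transformers

model = transformers.GPT2LMHeadModel(config)   # model_fn
batch = data_gen_for_lm()                      # data_gen_fn: a dict of input tensors
outputs = model(**batch)
loss = loss_fn(outputs)                        # loss_fn as registered above (presumably reads the .loss field)
loss.backward()
output_transform_fn(outputs)                   # used when comparing outputs across parallel configurations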


@@ -157,18 +157,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
             "num_microbatches": 1,
-            # "enable_all_optimization": True,
+            "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
         },
         {
-            "tp_size": 2,
+            "tp_size": 4,
             "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": False,
-            # "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
@@ -238,7 +238,7 @@ def run_gpt2_test(test_config):
         "tp_size": 2,
         "pp_size": 2,
         "num_microbatches": 4,
-        "enable_all_optimization": False,
+        "enable_all_optimization": True,
         "use_lazy_init": False,
         "precision": "fp32",
         "initial_scale": 1,
@@ -247,7 +247,7 @@ def run_gpt2_test(test_config):
         "tp_size": 2,
         "pp_size": 2,
         "num_microbatches": 4,
-        "enable_all_optimization": False,
+        "enable_all_optimization": True,
        "use_lazy_init": False,
         "precision": "fp16",
         "zero_stage": 1,