commit f8caea7762
parent fdcc3691fa

    fix
@@ -39,7 +39,6 @@ def _get_attention_mask(
     attention_mask: Optional[torch.FloatTensor],
     encoder_hidden_states: Optional[torch.Tensor],
     encoder_attention_mask: Optional[torch.FloatTensor],
-    head_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
     # Received input is already split for non-first pipeline stages,
     # but attn mask isn't
@@ -49,9 +48,7 @@ def _get_attention_mask(
     sp_mode = shard_config.sequence_parallelism_mode
     # If a 2D or 3D attention mask is provided for the cross-attention
     # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-    _use_sdpa = self._attn_implementation == "sdpa"
-    print("_use_sdpa", _use_sdpa)
-    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
     if self.config.add_cross_attention and encoder_hidden_states is not None:
         assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
@@ -59,7 +56,7 @@ def _get_attention_mask(
             encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
                 (encoder_batch_size, 1, seq_len, encoder_sequence_length),
                 dtype=hidden_states.dtype,
-                dtype2=encoder_hidden_states.dtype,
+                device=encoder_hidden_states.device,
                 q_padding_mask=attention_mask,
                 kv_padding_mask=encoder_attention_mask,
             )
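Note on the hunk above: the cross-attention branch builds its flash-attention kwargs with ColoAttention.prepare_attn_kwargs, and the parent commit passed a stray dtype2= keyword; the fix replaces it with the device= argument alongside the dtype of the decoder activations. A minimal sketch of the corrected call with toy tensors (the import path is an assumption; the keyword arguments are exactly the ones shown in the hunk):

    import torch

    from colossalai.shardformer.layer import ColoAttention  # assumed import path

    # Toy stand-ins for the decoder / encoder activations and their padding masks.
    batch_size, seq_len, enc_len, hidden = 2, 8, 6, 16
    hidden_states = torch.randn(batch_size, seq_len, hidden, dtype=torch.float16)
    encoder_hidden_states = torch.randn(batch_size, enc_len, hidden, dtype=torch.float16)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)          # query-side padding mask
    encoder_attention_mask = torch.ones(batch_size, enc_len, dtype=torch.long)  # key/value-side padding mask

    # Corrected call: dtype from the decoder activations, device from the encoder states.
    encoder_attn_kwargs = ColoAttention.prepare_attn_kwargs(
        (batch_size, 1, seq_len, enc_len),
        dtype=hidden_states.dtype,
        device=encoder_hidden_states.device,
        q_padding_mask=attention_mask,
        kv_padding_mask=encoder_attention_mask,
    )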
@@ -67,12 +64,7 @@ def _get_attention_mask(
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
-            if _use_sdpa:
-                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    mask=encoder_attention_mask, dtype=hidden_states.dtype, tgt_len=encoder_hidden_shape[-1]
-                )
-            elif not self._attn_implementation == "flash_attention_2":
-                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
     else:
         if shard_config.enable_flash_attention:
             encoder_attention_mask = {"attention_mask": None}
@@ -95,13 +87,6 @@ def _get_attention_mask(
         )
     elif self._attn_implementation == "flash_attention_2":
         attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-    elif _use_sdpa:
-        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask=attention_mask,
-            input_shape=(batch_size, hidden_states[-1]),
-            inputs_embeds=hidden_states,
-            past_key_values_length=past_key_values_length,
-        )
     elif attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
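With the SDPA-specific branches removed, the non-flash path above falls back to invert_attention_mask, i.e. the usual conversion of a (batch, seq) padding mask of ones and zeros into an additive bias that is 0 where attention is allowed and a very large negative number where it is not. A rough self-contained illustration of that transform (not the exact transformers implementation):

    import torch

    def invert_padding_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # (batch, seq) 1/0 padding mask -> (batch, 1, 1, seq) additive bias:
        # 0.0 for visible positions, dtype-min for masked ones.
        extended = mask[:, None, None, :].to(dtype)
        return (1.0 - extended) * torch.finfo(dtype).min

    mask = torch.tensor([[1, 1, 1, 0]])
    print(invert_padding_mask(mask, torch.float16))  # last position ~ -65504, the rest 0.0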
@@ -224,7 +209,6 @@ class GPT2PipelineForwards:
             attention_mask,
             encoder_hidden_states,
             encoder_attention_mask,
-            head_mask
         )


@@ -400,7 +400,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs={
-                        "gather_output": True,
+                        "gather_output": False,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
                     },
                 ),
@@ -418,20 +418,20 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                 target_key=GPT2LMHeadModel,
             )

-        # if self.shard_config.parallel_output:
-        # self.append_or_create_method_replacement(
-        # description={
-        # "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
-        # },
-        # policy=module_policy,
-        # target_key=GPT2LMHeadModel,
-        # )
+        if self.shard_config.parallel_output:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
+                },
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )

         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
-                shard_config=self.shard_config,
+                # shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
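The two policy hunks above work as a pair: lm_head now keeps its output vocab-parallel (gather_output=False), and when shard_config.parallel_output is set the LM forward is replaced so the loss is computed directly on the sharded logits instead of gathering the full vocabulary first. A single-process sketch of why that is enough, using chunked tensors in place of tensor-parallel ranks (the function below is illustrative, not ColossalAI's implementation):

    import torch
    import torch.nn.functional as F

    def sharded_cross_entropy(logit_shards, targets):
        # Cross entropy from vocab-sharded logits: loss = logsumexp(all logits) - target logit.
        lse = torch.logsumexp(torch.stack([torch.logsumexp(s, dim=-1) for s in logit_shards]), dim=0)
        shard_size = logit_shards[0].shape[-1]
        shard_id, local_id = targets // shard_size, targets % shard_size
        # Each target logit lives on exactly one shard; read it from there.
        tgt = torch.stack(
            [logit_shards[s][b, i] for b, (s, i) in enumerate(zip(shard_id.tolist(), local_id.tolist()))]
        )
        return (lse - tgt).mean()

    torch.manual_seed(0)
    logits = torch.randn(4, 12)                  # (batch, vocab)
    targets = torch.randint(0, 12, (4,))
    shards = list(logits.chunk(3, dim=-1))       # 3 "tensor-parallel ranks", 4 vocab ids each
    assert torch.allclose(sharded_cross_entropy(shards, targets), F.cross_entropy(logits, targets), atol=1e-6)

In the real code the replaced gpt2_lmhead_model_forward does the equivalent work; the point here is only that no gather of the full-vocabulary logits is required once the loss understands the shard layout.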
@@ -127,43 +127,43 @@ model_zoo.register(
     loss_fn=loss_fn_for_gpt2_model,
     model_attribute=ModelAttribute(has_control_flow=True),
 )
-# model_zoo.register(
-# name="transformers_gpt_lm",
-# model_fn=lambda: transformers.GPT2LMHeadModel(config),
-# data_gen_fn=data_gen_for_lm,
-# output_transform_fn=output_transform_fn,
-# loss_fn=loss_fn,
-# model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-# name="transformers_gpt_double_heads",
-# model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-# data_gen_fn=date_gen_for_double_heads,
-# output_transform_fn=output_transform_fn,
-# loss_fn=lambda x: x.loss + x.mc_loss,
-# model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-# name="transformers_gpt_for_question_answering",
-# model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
-# data_gen_fn=data_gen_for_question_answering,
-# output_transform_fn=output_transform_fn,
-# loss_fn=loss_fn,
-# model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-# name="transformers_gpt_for_token_classification",
-# model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
-# data_gen_fn=data_gen_for_token_classification,
-# output_transform_fn=output_transform_fn,
-# loss_fn=loss_fn,
-# model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-# name="transformers_gpt_for_sequence_classification",
-# model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
-# data_gen_fn=data_gen_for_sequence_classification,
-# output_transform_fn=output_transform_fn,
-# loss_fn=loss_fn,
-# model_attribute=ModelAttribute(has_control_flow=True),
-# )
+model_zoo.register(
+    name="transformers_gpt_lm",
+    model_fn=lambda: transformers.GPT2LMHeadModel(config),
+    data_gen_fn=data_gen_for_lm,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_double_heads",
+    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
+    data_gen_fn=date_gen_for_double_heads,
+    output_transform_fn=output_transform_fn,
+    loss_fn=lambda x: x.loss + x.mc_loss,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_question_answering",
+    model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
+    data_gen_fn=data_gen_for_question_answering,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_token_classification",
+    model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
+    data_gen_fn=data_gen_for_token_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_sequence_classification",
+    model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
+    data_gen_fn=data_gen_for_sequence_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
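The hunk above re-enables the GPT-2 model-zoo entries that had been commented out, so the shardformer GPT-2 tests pick them up again. For orientation, the test suite typically pulls these registrations back out by name prefix, roughly like this (the registry accessor and its return layout are assumed from how the tests use it elsewhere, not shown in this diff):

    from tests.kit.model_zoo import model_zoo  # assumed import path inside the ColossalAI test suite

    for name, entry in model_zoo.get_sub_registry("transformers_gpt").items():
        # entry bundles the registered callables: model_fn, data_gen_fn,
        # output_transform_fn, loss_fn and the ModelAttribute flags.
        print(name)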
@@ -157,18 +157,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         "enable_sequence_parallelism": True,
         "sequence_parallelism_mode": "ring_attn",
         "num_microbatches": 1,
-        # "enable_all_optimization": True,
+        "enable_all_optimization": True,
         "use_lazy_init": True,
         "precision": "fp16",
         "initial_scale": 1,
     },
     {
-        "tp_size": 2,
+        "tp_size": 4,
         "pp_size": 1,
-        "num_microbatches": 1,
-        "enable_sequence_parallelism": False,
-        # "sequence_parallelism_mode": "split_gather",
-        "enable_flash_attention": False,
+        "num_microbatches": 2,
+        "enable_sequence_parallelism": True,
+        "sequence_parallelism_mode": "split_gather",
+        "enable_flash_attention": True,
         "use_lazy_init": False,
         "precision": "fp16",
         "initial_scale": 1,
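The second config above now exercises split_gather sequence parallelism with tp_size=4 and flash attention enabled. The name describes the data movement: token-local parts of the block operate on one sequence shard per tensor-parallel rank, and the shards are gathered back wherever the full sequence is needed (attention). A toy single-process illustration of the split/gather, with plain chunk/cat standing in for the scatter and all-gather collectives:

    import torch

    tp_size = 4
    x = torch.randn(2, 16, 8)                 # (batch, seq, hidden)

    shards = list(x.chunk(tp_size, dim=1))    # "split": each rank keeps seq/tp_size tokens -> (2, 4, 8)
    # ... per-rank, token-local work happens here (e.g. layernorm, dropout) ...
    gathered = torch.cat(shards, dim=1)       # "gather": reassemble the full sequence before attention

    assert torch.equal(gathered, x)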
@@ -238,7 +238,7 @@ def run_gpt2_test(test_config):
         "tp_size": 2,
         "pp_size": 2,
         "num_microbatches": 4,
-        "enable_all_optimization": False,
+        "enable_all_optimization": True,
         "use_lazy_init": False,
         "precision": "fp32",
         "initial_scale": 1,
@@ -247,7 +247,7 @@ def run_gpt2_test(test_config):
         "tp_size": 2,
         "pp_size": 2,
         "num_microbatches": 4,
-        "enable_all_optimization": False,
+        "enable_all_optimization": True,
         "use_lazy_init": False,
         "precision": "fp16",
         "zero_stage": 1,