Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-02 20:35:29 +00:00)
fix
This commit is contained in: parent c6291be1b1, commit fdcc3691fa
@@ -39,6 +39,7 @@ def _get_attention_mask(
     attention_mask: Optional[torch.FloatTensor],
     encoder_hidden_states: Optional[torch.Tensor],
     encoder_attention_mask: Optional[torch.FloatTensor],
+    head_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
     # Received input is already split for non-first pipeline stages,
     # but attn mask isn't
@@ -48,6 +49,9 @@ def _get_attention_mask(
     sp_mode = shard_config.sequence_parallelism_mode
     # If a 2D or 3D attention mask is provided for the cross-attention
     # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+    _use_sdpa = self._attn_implementation == "sdpa"
+    print("_use_sdpa", _use_sdpa)
+    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
     if self.config.add_cross_attention and encoder_hidden_states is not None:
         assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
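For reference, the `_attn_implementation` flag read by the new `_use_sdpa` check is chosen by transformers when the model is built. A minimal sketch, assuming a recent transformers release in which GPT-2 supports SDPA and `from_pretrained` accepts `attn_implementation`:

import transformers

# Ask transformers to wire up the SDPA attention path; the choice is recorded on the config.
model = transformers.GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="sdpa")
print(model.config._attn_implementation)  # expected: "sdpa"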
@@ -63,7 +67,12 @@ def _get_attention_mask(
         encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
         if encoder_attention_mask is None:
             encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
-        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        if _use_sdpa:
+            encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                mask=encoder_attention_mask, dtype=hidden_states.dtype, tgt_len=encoder_hidden_shape[-1]
+            )
+        elif not self._attn_implementation == "flash_attention_2":
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
     else:
         if shard_config.enable_flash_attention:
             encoder_attention_mask = {"attention_mask": None}
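To see what the new cross-attention branch produces in isolation, here is a minimal, standalone sketch of the `_prepare_4d_attention_mask_for_sdpa` helper imported in the hunk above (2D padding mask in, additive 4D mask in the query dtype out); shapes and values are illustrative only:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

encoder_batch_size, encoder_sequence_length, tgt_len = 2, 6, 4
hidden_states = torch.randn(encoder_batch_size, tgt_len, 16, dtype=torch.float16)

# 2D encoder padding mask: 1 = attend, 0 = padded.
encoder_attention_mask = torch.ones(encoder_batch_size, encoder_sequence_length, dtype=torch.long)
encoder_attention_mask[1, -2:] = 0

# Broadcast to [batch, 1, tgt_len, src_len]; padded key positions become a large negative bias.
mask_4d = _prepare_4d_attention_mask_for_sdpa(
    mask=encoder_attention_mask, dtype=hidden_states.dtype, tgt_len=tgt_len
)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 6])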
@@ -77,7 +86,6 @@ def _get_attention_mask(
     if shard_config.enable_flash_attention:
         if attention_mask is not None:
             attention_mask = attention_mask.view(batch_size, -1)

         attention_mask = ColoAttention.prepare_attn_kwargs(
             (batch_size, 1, seq_len, seq_len + past_key_values_length),
             hidden_states.dtype,
@@ -85,6 +93,15 @@ def _get_attention_mask(
             attention_mask,
             is_causal=True,
         )
+    elif self._attn_implementation == "flash_attention_2":
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+    elif _use_sdpa:
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, hidden_states[-1]),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_key_values_length,
+        )
     elif attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
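The two new decoder self-attention branches can be exercised on their own. A minimal sketch, using the same transformers helper imported earlier (shapes and values are illustrative, not taken from the modified function):

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

batch_size, seq_len, hidden = 2, 8, 16
hidden_states = torch.randn(batch_size, seq_len, hidden, dtype=torch.float16)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[1, -3:] = 0  # pad the tail of the second sequence

# SDPA path: expand the 2D padding mask to an additive 4D causal mask
# (the helper may return None when plain is_causal handling suffices).
sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_len),
    inputs_embeds=hidden_states,
    past_key_values_length=0,
)

# flash_attention_2 path: the kernel handles causality itself, so only the
# padding mask is forwarded, and only when some position is actually masked out.
fa2_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
print(None if sdpa_mask is None else sdpa_mask.shape, fa2_mask is not None)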
@@ -207,8 +224,10 @@ class GPT2PipelineForwards:
             attention_mask,
             encoder_hidden_states,
             encoder_attention_mask,
+            head_mask
         )

         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -835,9 +854,12 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
             attention_mask = encoder_attention_mask
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
+        shape_q = (*query.shape[:-1], -1, self.head_dim)
+        shape_kv = (*key.shape[:-1], -1, self.head_dim)
+        query = query.view(shape_q).transpose(1, 2)
+        key = key.view(shape_kv).transpose(1, 2)
+        value = value.view(shape_kv).transpose(1, 2)

         if layer_past is not None:
             past_key, past_value = layer_past
@@ -871,7 +893,9 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
+        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
         outputs = (attn_output, present, None)
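The reshaping above swaps GPT2Attention's `_split_heads`/`_merge_heads` helpers for plain view/transpose calls. A small standalone check that the two round-trips agree (illustrative sizes, no ColossalAI dependencies):

import torch

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# Split: [B, S, H*D] -> [B, H, S, D], as done with shape_q/shape_kv above.
split = hidden.view(*hidden.shape[:-1], num_heads, head_dim).transpose(1, 2)

# Merge: [B, H, S, D] -> [B, S, H*D], as done with permute + reshape above.
merged = split.permute(0, 2, 1, 3).contiguous().reshape(batch, seq_len, -1)

assert torch.equal(merged, hidden)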
@@ -38,13 +38,8 @@ class GPT2Policy(Policy):
     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model

-        ATTN_IMPLEMENTATION = {
-            "eager": GPT2Attention,
-        }
-
         policy = {}

-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]

         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
@@ -53,6 +48,10 @@ class GPT2Policy(Policy):
             if self.tie_weight:
                 embedding_cls = col_nn.PaddingEmbedding

+        print("embedding_cls", embedding_cls)
+
         if self.shard_config.enable_fused_normalization:
             norm_cls = col_nn.FusedLayerNorm
         else:
@@ -224,6 +223,7 @@ class GPT2Policy(Policy):

         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
+            print("embedding_cls", embedding_cls)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="wte",
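The comment above refers to padding the vocabulary so the embedding can be split evenly across tensor-parallel ranks. A minimal sketch of the rounding rule, assuming the usual `make_vocab_size_divisible_by` semantics (this is not the `col_nn.PaddingEmbedding` implementation):

import torch.nn as nn

def pad_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int) -> int:
    # Round up to the next multiple so each shard gets an equal slice.
    multiple = make_vocab_size_divisible_by
    return ((vocab_size + multiple - 1) // multiple) * multiple

orig_vocab = 50257                       # GPT-2 vocabulary
padded = pad_vocab_size(orig_vocab, 64)  # -> 50304
wte = nn.Embedding(padded, 768)          # extra rows are never indexed by real token ids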
@@ -280,7 +280,7 @@ class GPT2Policy(Policy):
                     "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=GPT2Attention,
             )

         if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
@@ -394,12 +394,13 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         module_policy = super().module_policy()
         module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
         if self.shard_config.enable_tensor_parallelism:
+            print("self.shard_config.enable_tensor_parallelism", self.shard_config.enable_tensor_parallelism)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs={
-                        "gather_output": False,
+                        "gather_output": True,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
                     },
                 ),
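On the `gather_output` switch above: with a column-parallel LM head each rank holds a slice of the vocabulary dimension, and `gather_output=True` concatenates the per-rank logits back into full-vocabulary logits. A single-process sketch of the idea (illustrative only, not the `col_nn.VocabParallelLMHead1D` implementation):

import torch

hidden, vocab, world_size = 16, 32, 2
x = torch.randn(4, hidden)
full_weight = torch.randn(vocab, hidden)
shards = full_weight.chunk(world_size, dim=0)        # each rank holds vocab/world_size rows

partial_logits = [x @ w.t() for w in shards]         # per-rank output, shape [4, vocab/2]
gathered_logits = torch.cat(partial_logits, dim=-1)  # gather_output=True: full [4, vocab]

assert torch.allclose(gathered_logits, x @ full_weight.t(), atol=1e-5)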
@@ -417,19 +418,20 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                 target_key=GPT2LMHeadModel,
             )

-        if self.shard_config.parallel_output:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
-                },
-                policy=module_policy,
-                target_key=GPT2LMHeadModel,
-            )
+        # if self.shard_config.parallel_output:
+        #     self.append_or_create_method_replacement(
+        #         description={
+        #             "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
+        #         },
+        #         policy=module_policy,
+        #         target_key=GPT2LMHeadModel,
+        #     )

         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
                 shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
@@ -127,43 +127,43 @@ model_zoo.register(
     loss_fn=loss_fn_for_gpt2_model,
     model_attribute=ModelAttribute(has_control_flow=True),
 )
-model_zoo.register(
-    name="transformers_gpt_lm",
-    model_fn=lambda: transformers.GPT2LMHeadModel(config),
-    data_gen_fn=data_gen_for_lm,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_double_heads",
-    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-    data_gen_fn=date_gen_for_double_heads,
-    output_transform_fn=output_transform_fn,
-    loss_fn=lambda x: x.loss + x.mc_loss,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_question_answering",
-    model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
-    data_gen_fn=data_gen_for_question_answering,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_token_classification",
-    model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
-    data_gen_fn=data_gen_for_token_classification,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_sequence_classification",
-    model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
-    data_gen_fn=data_gen_for_sequence_classification,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
+# model_zoo.register(
+#     name="transformers_gpt_lm",
+#     model_fn=lambda: transformers.GPT2LMHeadModel(config),
+#     data_gen_fn=data_gen_for_lm,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_double_heads",
+#     model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
+#     data_gen_fn=date_gen_for_double_heads,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=lambda x: x.loss + x.mc_loss,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_question_answering",
+#     model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
+#     data_gen_fn=data_gen_for_question_answering,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_token_classification",
+#     model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
+#     data_gen_fn=data_gen_for_token_classification,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_sequence_classification",
+#     model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
+#     data_gen_fn=data_gen_for_sequence_classification,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
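For context on what each (now commented-out) registration wires together, namely a model factory, a batch generator, and a loss reduction, here is a hedged, self-contained sketch of the same pattern with an illustrative small config (not the test suite's actual `config` or data generators):

import torch
import transformers

config = transformers.GPT2Config(n_layer=2, n_head=4, n_embd=128)
model = transformers.GPT2LMHeadModel(config)           # plays the role of model_fn()
input_ids = torch.randint(0, config.vocab_size, (2, 16))  # plays the role of data_gen_fn()
outputs = model(input_ids=input_ids, labels=input_ids)    # HF computes the LM loss internally
loss = outputs.loss                                       # plays the role of loss_fn(outputs)
loss.backward()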
@@ -157,19 +157,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
             "num_microbatches": 1,
-            "enable_all_optimization": True,
+            # "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
         },
         {
-            "tp_size": 4,
+            "tp_size": 2,
             "pp_size": 1,
             "num_microbatches": 1,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
+            "enable_sequence_parallelism": False,
+            # "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
         },
@@ -180,7 +180,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
             "enable_flash_attention": True,
-            "use_lazy_init": True,
+            "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
         },