flybird11111 authored 2025-05-01 01:04:34 +00:00, committed by GitHub
commit 2ee4abdfa0
3 changed files with 21 additions and 16 deletions

View File

@@ -48,6 +48,7 @@ def _get_attention_mask(
     sp_mode = shard_config.sequence_parallelism_mode
     # If a 2D or 3D attention mask is provided for the cross-attention
     # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
     if self.config.add_cross_attention and encoder_hidden_states is not None:
         assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
@@ -55,7 +56,7 @@ def _get_attention_mask(
         encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
             (encoder_batch_size, 1, seq_len, encoder_sequence_length),
             dtype=hidden_states.dtype,
-            dtype2=encoder_hidden_states.dtype,
+            device=encoder_hidden_states.device,
             q_padding_mask=attention_mask,
             kv_padding_mask=encoder_attention_mask,
         )
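Note on the dtype/device fix above: the mask kwargs need to be materialised in the hidden states' dtype and on the encoder states' device, which the old dtype2 keyword did not do. The snippet below is an illustrative stand-in in plain PyTorch (not ColoAttention's actual implementation) showing what those two arguments control when building a [batch, 1, q_len, kv_len] cross-attention bias:

    import torch

    def build_cross_attn_bias(q_padding_mask, kv_padding_mask, dtype, device):
        # Illustrative only: expand boolean padding masks into an additive bias of
        # shape [batch, 1, q_len, kv_len], created with the hidden states' dtype and
        # the encoder states' device -- the two arguments the hunk above corrects.
        q = q_padding_mask.to(device=device, dtype=torch.bool)
        kv = kv_padding_mask.to(device=device, dtype=torch.bool)
        visible = q.unsqueeze(-1) & kv.unsqueeze(-2)          # [batch, q_len, kv_len]
        bias = torch.zeros(visible.shape, dtype=dtype, device=device)
        bias.masked_fill_(~visible, torch.finfo(dtype).min)   # block padded positions
        return bias.unsqueeze(1)                              # [batch, 1, q_len, kv_len]

    bias = build_cross_attn_bias(torch.ones(2, 4), torch.ones(2, 6), torch.float16, "cpu")
    print(bias.shape)  # torch.Size([2, 1, 4, 6])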
@@ -77,7 +78,6 @@ def _get_attention_mask(
     if shard_config.enable_flash_attention:
         if attention_mask is not None:
             attention_mask = attention_mask.view(batch_size, -1)
-
         attention_mask = ColoAttention.prepare_attn_kwargs(
             (batch_size, 1, seq_len, seq_len + past_key_values_length),
             hidden_states.dtype,
@@ -835,9 +835,12 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
             attention_mask = encoder_attention_mask
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
+        shape_q = (*query.shape[:-1], -1, self.head_dim)
+        shape_kv = (*key.shape[:-1], -1, self.head_dim)
+        query = query.view(shape_q).transpose(1, 2)
+        key = key.view(shape_kv).transpose(1, 2)
+        value = value.view(shape_kv).transpose(1, 2)

         if layer_past is not None:
             past_key, past_value = layer_past
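The three removed calls rely on the private GPT2Attention._split_heads helper, which recent transformers refactors dropped; the added view/transpose pair reproduces the same layout. A minimal, self-contained check of that equivalence (tensor sizes are illustrative):

    import torch

    batch, seq, num_heads, head_dim = 2, 5, 4, 8
    x = torch.randn(batch, seq, num_heads * head_dim)

    # New code path: reshape the last dimension into heads, then swap the seq and head axes.
    shape_q = (*x.shape[:-1], -1, head_dim)
    split_new = x.view(shape_q).transpose(1, 2)                              # [batch, heads, seq, head_dim]

    # What the removed GPT2Attention._split_heads helper used to produce.
    split_old = x.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)

    assert torch.equal(split_new, split_old)
    print(split_new.shape)  # torch.Size([2, 4, 5, 8])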
@@ -871,7 +874,9 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
+        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
         outputs = (attn_output, present, None)
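Likewise on the output side: permute + reshape replaces the removed _merge_heads helper, folding [batch, heads, seq, head_dim] back into [batch, seq, hidden] before c_proj. A small sketch of the shape round-trip (sizes are illustrative):

    import torch

    batch, num_heads, seq, head_dim = 2, 4, 5, 8
    attn_output = torch.randn(batch, num_heads, seq, head_dim)

    # Move the head axis next to head_dim, then flatten the two into the hidden size.
    merged = attn_output.permute(0, 2, 1, 3).contiguous()
    merged = merged.reshape(*merged.shape[:-2], -1).contiguous()
    print(merged.shape)  # torch.Size([2, 5, 32])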

View File

@@ -38,14 +38,8 @@ class GPT2Policy(Policy):
     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model

-        ATTN_IMPLEMENTATION = {
-            "eager": GPT2Attention,
-        }
         policy = {}
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]

         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -53,6 +47,11 @@ class GPT2Policy(Policy):
             if self.tie_weight:
                 embedding_cls = col_nn.PaddingEmbedding
         if self.shard_config.enable_fused_normalization:
             norm_cls = col_nn.FusedLayerNorm
         else:
@@ -280,7 +279,7 @@ class GPT2Policy(Policy):
                     "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=GPT2Attention,
             )

         if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
@@ -430,6 +429,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
+                shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
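With the eager-only ATTN_IMPLEMENTATION lookup gone, the policy targets GPT2Attention directly and binds the flash-attention forward onto it. The sketch below illustrates only the underlying method-replacement technique with stand-in classes, not ColossalAI's Policy API:

    import types

    import torch
    import torch.nn as nn

    class DummyAttention(nn.Module):  # stand-in for GPT2Attention
        def forward(self, x):
            return x

    def flash_forward(self, x):  # stand-in for get_gpt2_flash_attention_forward(...)
        return x * 2

    attn = DummyAttention()
    # Bind the replacement forward to this instance, as a method replacement does.
    attn.forward = types.MethodType(flash_forward, attn)
    print(attn(torch.ones(1)))  # tensor([2.])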

View File

@@ -180,7 +180,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
             "enable_flash_attention": True,
-            "use_lazy_init": True,
+            "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
         },
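The split_gather + flash-attention config now runs with "use_lazy_init": False, i.e. the model is built eagerly rather than under lazy initialisation. A hedged sketch of what that flag toggles in the test harness, assuming ColossalAI's LazyInitContext at colossalai.lazy and using a stand-in model builder:

    from contextlib import nullcontext

    import torch.nn as nn
    from colossalai.lazy import LazyInitContext  # assumed import path

    def model_fn():  # stand-in for the model-zoo builder used by the test
        return nn.Linear(8, 8)

    use_lazy_init = False  # value pinned by this commit for the split_gather + flash-attn config
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        model = model_fn()  # weights are created eagerly when use_lazy_init is False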
@@ -238,7 +238,7 @@ def run_gpt2_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp32",
             "initial_scale": 1,
@@ -247,7 +247,7 @@ def run_gpt2_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "zero_stage": 1,