mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
[upgrade]Upgrade qwen2 (#6302)
* upgrade qwen2 * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a4c6e189fa
commit
f118146564
@ -57,6 +57,7 @@ class Qwen2PipelineForwards:
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
@ -131,14 +132,6 @@ class Qwen2PipelineForwards:
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if shard_config.enable_flash_attention:
|
||||
@ -152,16 +145,16 @@ class Qwen2PipelineForwards:
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._attn_implementation == "flash_attention_2":
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||
elif self.config._attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
@ -195,6 +188,8 @@ class Qwen2PipelineForwards:
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
@ -214,7 +209,7 @@ class Qwen2PipelineForwards:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
@ -225,15 +220,19 @@ class Qwen2PipelineForwards:
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -491,11 +490,10 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
def forward(
|
||||
self: Qwen2Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if sp_mode is not None:
|
||||
@ -519,9 +517,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
@ -533,9 +531,8 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
@ -563,7 +560,7 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
attention_mask = attention_mask[:, slicing_tokens:]
|
||||
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
||||
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
@ -605,11 +602,11 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
return attn_output, None
|
||||
|
||||
return forward
|
||||
|
||||
@ -627,6 +624,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
@ -648,6 +646,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
@ -664,9 +665,6 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -700,6 +698,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
hidden_states = split_forward_gather_backward(
|
||||
@ -723,22 +722,23 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
|
@ -65,15 +65,9 @@ class Qwen2Policy(Policy):
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": Qwen2Attention,
|
||||
"flash_attention_2": Qwen2FlashAttention2,
|
||||
"sdpa": Qwen2SdpaAttention,
|
||||
}
|
||||
|
||||
policy = {}
|
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
@ -93,7 +87,7 @@ class Qwen2Policy(Policy):
|
||||
if getattr(self.model.config, "num_key_value_heads", False):
|
||||
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
|
||||
|
||||
policy[attn_cls] = ModulePolicyDescription(
|
||||
policy[Qwen2Attention] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
)
|
||||
|
||||
@ -301,12 +295,13 @@ class Qwen2Policy(Policy):
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
|
||||
print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
target_key=Qwen2Attention,
|
||||
)
|
||||
if self.pipeline_stage_manager is None:
|
||||
# replace qwen2 model forward method
|
||||
@ -370,6 +365,7 @@ class Qwen2Policy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
held_layers.append(module.rotary_emb)
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
|
Loading…
Reference in New Issue
Block a user