diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 580f3618c..dcb832639 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -58,7 +58,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, - ): + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: # TODO(jianghai): add explaination of the output here. r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -1037,6 +1037,90 @@ def get_jit_fused_bert_output_forward(): return forward +# Fix the tgt_len size in sequence parallel attention: +# same with the one in BertSdpaSelfAttention forward in v4.51.3 transformers except the +# _, _, tgt_len, _ = query_layer.shape +def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig): + from transformers.models.bert.modeling_bert import BertSdpaSelfAttention + + def forward( + self: BertSdpaSelfAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + _, _, tgt_len, _ = query_layer.shape + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + return forward + + def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): def forward( self, diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 7e8e50d9b..5ca8f9869 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -6,7 +6,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -21,6 +21,7 @@ from transformers.models.bloom.modeling_bloom import ( BloomForSequenceClassification, BloomForTokenClassification, BloomModel, + dropout_add, ) from transformers.utils import logging @@ -108,7 +109,7 @@ class BloomPipelineForwards: def bloom_model_forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, @@ -116,6 +117,7 @@ class BloomPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -151,6 +153,8 @@ class BloomPipelineForwards: if use_cache: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False + past_key_values = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N @@ -161,46 +165,60 @@ class BloomPipelineForwards: # case: First stage of training if stage_manager.is_first_stage(): # check input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - # initialize in the first stage and then pass to the next stage - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - - # extra recording tensor should be generated in the first stage - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + batch_size, seq_length, _ = inputs_embeds.shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=hidden_states.device) + + # extra recording tensor should be generated in the first stage + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage + past_length = 0 + seq_length_with_past = seq_length + past_length - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: @@ -209,13 +227,10 @@ class BloomPipelineForwards: alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = _prepare_4d_causal_attention_mask( - attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=hidden_states, - past_key_values_length=past_key_values_length, + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions ) - causal_mask = causal_mask.bool() + # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config and shard_config.enable_sequence_parallelism: @@ -228,9 +243,7 @@ class BloomPipelineForwards: ) start_idx, end_idx = stage_index[0], stage_index[1] - for i, (block, layer_past) in enumerate( - zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx - ): + for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -240,26 +253,28 @@ class BloomPipelineForwards: hidden_states, alibi, causal_mask, - layer_past, + past_key_values, head_mask[i], use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, ) hidden_states = outputs[0] + if use_cache: + next_decoder_cache = outputs[1] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -277,20 +292,23 @@ class BloomPipelineForwards: # Add last hidden state hidden_states = self.ln_f(hidden_states) - # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None ) # attention_mask is not returned ; presents = past_key_values return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) @@ -718,35 +736,24 @@ def get_jit_fused_bloom_attention_forward(): head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ): + batch_size, q_length, _ = hidden_states.shape fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, num_heads, seq_length, head_dim] + query_layer, key_layer, value_layer = self._reshape(fused_qkv) - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, q_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) - value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size * self.num_heads, head_dim, kv_length] - # - value: [batch_size * self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=2) - value_layer = torch.cat((past_value, value_layer), dim=1) + cache_kwargs = {"cache_position": cache_position} + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) - _, _, kv_length = key_layer.shape - - if use_cache is True: - present = (key_layer, value_layer) - else: - present = None + # reshape qkv for further computations + query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2) + value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) # [batch_size * num_heads, q_length, kv_length] - # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 - matmul_result = alibi.baddbmm( + attention_scores = alibi.baddbmm( batch1=query_layer, batch2=key_layer, beta=self.beta, @@ -754,15 +761,13 @@ def get_jit_fused_bloom_attention_forward(): ) # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) + attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]] + attn_weights = attn_weights + causal_mask - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16: - attention_scores = attention_scores.to(torch.float) - attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) - attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype) # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) @@ -771,12 +776,12 @@ def get_jit_fused_bloom_attention_forward(): attention_probs = attention_probs * head_mask # change view [batch_size x num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1) # matmul: [batch_size * num_heads, q_length, head_dim] context_layer = torch.bmm(attention_probs_reshaped, value_layer) - # change view [batch_size, num_heads, q_length, head_dim] + # change view [batch_size, q_length, num_heads * head_dim] context_layer = self._merge_heads(context_layer) # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 @@ -791,9 +796,9 @@ def get_jit_fused_bloom_attention_forward(): else: output_tensor = self.dense(context_layer) - output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) - outputs = (output_tensor, present) + outputs = (output_tensor, layer_past) if output_attentions: outputs += (attention_probs,) @@ -839,13 +844,99 @@ def get_jit_fused_bloom_gelu_forward(): return forward +# Fixed the q_length args when doing the sequence parallelism in bloom model. +def get_bloom_sequence_parallel_attention_forward(shard_config: ShardConfig): + from transformers.models.bloom.modeling_bloom import BloomAttention + + def forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Cache] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + batch_size, q_length, _ = hidden_states.shape + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, num_heads, seq_length, head_dim] + query_layer, key_layer, value_layer = self._reshape(fused_qkv) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) + + # reshape qkv for further computations + query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2) + value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + + # [batch_size * num_heads, q_length, kv_length] + attention_scores = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + if shard_config.enable_sequence_parallelism: + _, q_length, _ = query_layer.shape + # change view to [batch_size, num_heads, q_length, kv_length] + attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]] + attn_weights = attn_weights + causal_mask + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype) + + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs_reshaped, value_layer) + + # change view [batch_size, q_length, num_heads * head_dim] + context_layer = self._merge_heads(context_layer) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + outputs = (output_tensor, layer_past) + if output_attentions: + outputs += (attention_probs,) + + return outputs + + return forward + + def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): from transformers import BloomModel def forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, @@ -853,6 +944,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: @@ -864,7 +956,6 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -872,62 +963,60 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + batch_size, seq_length, _ = inputs_embeds.shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length_with_past = seq_length + past_length + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - presents = () if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - - causal_mask = _prepare_4d_causal_attention_mask( - attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=hidden_states, - past_key_values_length=past_key_values_length, + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - causal_mask = causal_mask.bool() - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward( hidden_states, dim=1, @@ -935,7 +1024,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): fp8_communication=shard_config.fp8_communication, ) - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -945,25 +1034,27 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, alibi, causal_mask, - layer_past, + past_key_values, head_mask[i], use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) + if use_cache: + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -975,18 +1066,25 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication, ) + # Add last hidden state hidden_states = self.ln_f(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index d06f8db2c..870c9272b 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -21,6 +21,7 @@ from transformers.models.falcon.modeling_falcon import ( build_alibi_tensor, ) from transformers.utils import logging +import warnings from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig @@ -103,7 +104,7 @@ def get_tp_falcon_decoder_layer_forward(): alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, @@ -113,7 +114,10 @@ def get_tp_falcon_decoder_layer_forward(): ] = None, # Add cache_position and position_embeddings args for v4.51.3 transformers **kwargs, ): - + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) residual = hidden_states # same as v4.51.3 transformers @@ -130,12 +134,12 @@ def get_tp_falcon_decoder_layer_forward(): attention_mask=attention_mask, position_ids=position_ids, alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) attention_output = attn_outputs[0] @@ -170,7 +174,7 @@ def get_tp_falcon_decoder_layer_forward(): else: outputs = (output,) + outputs[1:] - return outputs # hidden_states, past_kv, attentions + return outputs # hidden_states, present, attentions return forward @@ -226,7 +230,7 @@ class FalconPipelineForwards: elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) hidden_states = inputs_embeds @@ -236,7 +240,7 @@ class FalconPipelineForwards: if self.gradient_checkpointing and self.training: if use_cache: - logger.warning_once( + logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -290,35 +294,35 @@ class FalconPipelineForwards: if self.gradient_checkpointing and self.training: outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - alibi, - causal_mask, - position_ids, - head_mask[i], - past_key_values, - use_cache, - output_attentions, - cache_position, - position_embeddings, - ) + block.__call__, + hidden_states, + alibi, + causal_mask, + position_ids, + head_mask[i], + past_key_values, + use_cache, + output_attentions, + cache_position, + position_embeddings, + ) else: outputs = block( - hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = outputs[0] if use_cache is True: - outputs[1] + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index d550484da..d7d182762 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -48,6 +48,7 @@ def _get_attention_mask( sp_mode = shard_config.sequence_parallelism_mode # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only." encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() @@ -55,7 +56,7 @@ def _get_attention_mask( encoder_attention_mask = ColoAttention.prepare_attn_kwargs( (encoder_batch_size, 1, seq_len, encoder_sequence_length), dtype=hidden_states.dtype, - dtype2=encoder_hidden_states.dtype, + device=encoder_hidden_states.device, q_padding_mask=attention_mask, kv_padding_mask=encoder_attention_mask, ) @@ -77,7 +78,6 @@ def _get_attention_mask( if shard_config.enable_flash_attention: if attention_mask is not None: attention_mask = attention_mask.view(batch_size, -1) - attention_mask = ColoAttention.prepare_attn_kwargs( (batch_size, 1, seq_len, seq_len + past_key_values_length), hidden_states.dtype, @@ -835,9 +835,12 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) attention_mask = encoder_attention_mask else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + + shape_q = (*query.shape[:-1], -1, self.head_dim) + shape_kv = (*key.shape[:-1], -1, self.head_dim) + query = query.view(shape_q).transpose(1, 2) + key = key.view(shape_kv).transpose(1, 2) + value = value.view(shape_kv).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past @@ -871,7 +874,9 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + + attn_output = attn_output.permute(0, 2, 1, 3).contiguous() + attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index d1ad84604..fe102eecf 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch import torch.distributed -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -141,7 +140,9 @@ class LlamaPipelineForwards: invert=(sp_mode != "ring_attn"), ) else: - attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attn_kwargs: torch.Tensor = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values + ) # Support SP + PP. Later stages have already received the split input. split_input = disable_pp or stage_manager.is_first_stage() @@ -177,6 +178,7 @@ class LlamaPipelineForwards: all_self_attns = () if output_attentions else None next_decoder_cache = None start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1]) + position_embeddings = self.rotary_emb(hidden_states, position_ids) num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: @@ -204,6 +206,7 @@ class LlamaPipelineForwards: output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -214,6 +217,7 @@ class LlamaPipelineForwards: output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] @@ -486,8 +490,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[Union[torch.Tensor, Dict]] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, @@ -505,30 +509,14 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s ) bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] # sp: modify sp_len when sequence parallel mode is ring if is_share_sp_tp(sp_mode): q_len *= sp_size - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": @@ -537,9 +525,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -552,7 +540,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -610,17 +598,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication ) else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights return forward diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 569fc4a45..7bdf1e65f 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -57,6 +57,7 @@ class Qwen2PipelineForwards: use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, @@ -131,14 +132,6 @@ class Qwen2PipelineForwards: else: position_ids = position_ids.view(-1, seq_length).long() - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: @@ -152,16 +145,16 @@ class Qwen2PipelineForwards: is_causal=True, ) else: - if self._attn_implementation == "flash_attention_2": + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: + elif self.config._attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), - inputs_embeds, + hidden_states, past_key_values_length, ) else: @@ -195,6 +188,8 @@ class Qwen2PipelineForwards: all_self_attns = () if output_attentions else None next_decoder_cache = None + position_embeddings = self.rotary_emb(hidden_states, position_ids) + start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: @@ -214,7 +209,7 @@ class Qwen2PipelineForwards: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None + past_key_values[idx] if past_key_values is not None else None if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( @@ -225,15 +220,19 @@ class Qwen2PipelineForwards: past_key_values, output_attentions, use_cache, + cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, ) hidden_states = layer_outputs[0] @@ -491,11 +490,10 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s def forward( self: Qwen2Attention, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if sp_mode is not None: @@ -519,9 +517,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -533,9 +531,8 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute @@ -563,7 +560,7 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s attention_mask = attention_mask[:, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -605,11 +602,11 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication ) else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None return forward @@ -627,6 +624,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -648,6 +646,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + seq_length_with_past = seq_length past_key_values_length = 0 @@ -664,9 +665,6 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions hidden_states = inputs_embeds @@ -700,6 +698,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + position_embeddings = self.rotary_emb(hidden_states, position_ids) if sp_mode in ["ring", "split_gather"]: hidden_states = split_forward_gather_backward( @@ -723,22 +722,23 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No past_key_values, output_attentions, use_cache, + cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 63cd49280..fd4e020b0 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.bert import ( BertPipelineForwards, bert_sequence_parallel_forward_fn, + get_bert_sequence_parallel_attention_forward, get_jit_fused_bert_intermediate_forward, get_jit_fused_bert_output_forward, get_jit_fused_bert_self_output_forward, @@ -48,6 +49,7 @@ class BertPolicy(Policy): BertLayer, BertModel, BertOutput, + BertSdpaSelfAttention, BertSelfOutput, ) @@ -77,6 +79,16 @@ class BertPolicy(Policy): use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_sequence_parallelism: + # Fix the tgt_len size in bert sequence parallel attention forward. + self.append_or_create_method_replacement( + description={ + "forward": get_bert_sequence_parallel_attention_forward(self.shard_config), + }, + policy=policy, + target_key=BertSdpaSelfAttention, + ) + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index c7691698b..af49a4d19 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.bloom import ( BloomPipelineForwards, build_bloom_alibi_tensor_fn, + get_bloom_sequence_parallel_attention_forward, get_bloom_sequence_parallel_forward_fn, get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, @@ -61,6 +62,15 @@ class BloomPolicy(Policy): use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_sequence_parallelism: + self.append_or_create_method_replacement( + description={ + "forward": get_bloom_sequence_parallel_attention_forward(self.shard_config), + }, + policy=policy, + target_key=BloomAttention, + ) + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.n_head % self.shard_config.tensor_parallel_size == 0 diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 362f33176..68a548aee 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -246,7 +246,6 @@ class FalconPolicy(Policy): module = self.model.transformer stage_manager = self.pipeline_stage_manager held_layers = [] - held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.h)) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index c57d33826..ba7a5c5bc 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -38,14 +38,8 @@ class GPT2Policy(Policy): def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model - ATTN_IMPLEMENTATION = { - "eager": GPT2Attention, - } - policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] - embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -280,7 +274,7 @@ class GPT2Policy(Policy): "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config), }, policy=policy, - target_key=attn_cls, + target_key=GPT2Attention, ) if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index e8f9471f9..9ad63dd7f 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -33,22 +33,9 @@ class LlamaPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaFlashAttention2, - LlamaModel, - LlamaSdpaAttention, - ) + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel - ATTN_IMPLEMENTATION = { - "eager": LlamaAttention, - "flash_attention_2": LlamaFlashAttention2, - "sdpa": LlamaSdpaAttention, - } policy = {} - - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -82,7 +69,7 @@ class LlamaPolicy(Policy): num_kv_heads //= sp_size decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads - policy[attn_cls] = ModulePolicyDescription( + policy[LlamaAttention] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: @@ -91,7 +78,7 @@ class LlamaPolicy(Policy): "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, - target_key=attn_cls, + target_key=LlamaAttention, ) if self.pipeline_stage_manager is None: @@ -354,6 +341,7 @@ class LlamaPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 84d2b2fdb..7f8a35e46 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -65,15 +65,9 @@ class Qwen2Policy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - ATTN_IMPLEMENTATION = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, - } policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -93,7 +87,7 @@ class Qwen2Policy(Policy): if getattr(self.model.config, "num_key_value_heads", False): decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size - policy[attn_cls] = ModulePolicyDescription( + policy[Qwen2Attention] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) @@ -301,12 +295,13 @@ class Qwen2Policy(Policy): ) if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: + print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention) self.append_or_create_method_replacement( description={ "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, - target_key=attn_cls, + target_key=Qwen2Attention, ) if self.pipeline_stage_manager is None: # replace qwen2 model forward method @@ -370,6 +365,7 @@ class Qwen2Policy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 393f7ffca..b67c494a6 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -180,7 +180,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", "enable_flash_attention": True, - "use_lazy_init": True, + "use_lazy_init": False, "precision": "fp16", "initial_scale": 1, }, @@ -238,7 +238,7 @@ def run_gpt2_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32", "initial_scale": 1, @@ -247,7 +247,7 @@ def run_gpt2_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp16", "zero_stage": 1,