From 5480b811c5ee4c90891e61d0f35c4e909dd8a17a Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 6 May 2025 15:58:53 +0800 Subject: [PATCH] upgrade_bloom --- colossalai/shardformer/modeling/bloom.py | 224 ++++++++++++----------- 1 file changed, 120 insertions(+), 104 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 7e8e50d9b..1e8b8b3e2 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -6,7 +6,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -108,7 +108,7 @@ class BloomPipelineForwards: def bloom_model_forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, @@ -116,6 +116,7 @@ class BloomPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -151,6 +152,8 @@ class BloomPipelineForwards: if use_cache: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False + past_key_values = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N @@ -161,46 +164,60 @@ class BloomPipelineForwards: # case: First stage of training if stage_manager.is_first_stage(): # check input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - # initialize in the first stage and then pass to the next stage - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - - # extra recording tensor should be generated in the first stage - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + batch_size, seq_length, _ = inputs_embeds.shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=hidden_states.device) + + # extra recording tensor should be generated in the first stage + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage + past_length = 0 + seq_length_with_past = seq_length + past_length - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: @@ -209,13 +226,10 @@ class BloomPipelineForwards: alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = _prepare_4d_causal_attention_mask( - attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=hidden_states, - past_key_values_length=past_key_values_length, + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions ) - causal_mask = causal_mask.bool() + # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config and shard_config.enable_sequence_parallelism: @@ -228,9 +242,7 @@ class BloomPipelineForwards: ) start_idx, end_idx = stage_index[0], stage_index[1] - for i, (block, layer_past) in enumerate( - zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx - ): + for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -240,26 +252,28 @@ class BloomPipelineForwards: hidden_states, alibi, causal_mask, - layer_past, + past_key_values, head_mask[i], use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, ) hidden_states = outputs[0] + if use_cache: + next_decoder_cache = outputs[1] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -277,20 +291,23 @@ class BloomPipelineForwards: # Add last hidden state hidden_states = self.ln_f(hidden_states) - # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None ) # attention_mask is not returned ; presents = past_key_values return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) @@ -845,7 +862,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): def forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, @@ -853,6 +870,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: @@ -864,7 +882,6 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -872,62 +889,60 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + batch_size, seq_length, _ = inputs_embeds.shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length_with_past = seq_length + past_length + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - presents = () if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - - causal_mask = _prepare_4d_causal_attention_mask( - attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=hidden_states, - past_key_values_length=past_key_values_length, + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - causal_mask = causal_mask.bool() - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward( hidden_states, dim=1, @@ -935,7 +950,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): fp8_communication=shard_config.fp8_communication, ) - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -945,48 +960,49 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, alibi, causal_mask, - layer_past, + past_key_values, head_mask[i], use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) + if use_cache: + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) # Add last hidden state hidden_states = self.ln_f(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, )