Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-31 13:59:23 +00:00

Merge 885210dc27 into 46ed5d856b
@@ -1,16 +1,9 @@
-import math
-import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -110,19 +103,18 @@ def get_tp_falcon_decoder_layer_forward():
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
         position_ids: Optional[torch.LongTensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        layer_past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ):
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
         residual = hidden_states
 
-        if self.config.new_decoder_architecture:
+        if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
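The widened `layer_past` annotation and the two added parameters follow the transformers-4.x Falcon layer interface: the caller now hands each layer the whole past structure plus precomputed positions, rather than slicing out a per-layer (key, value) pair. A minimal sketch of the two layouts, using plain tensors instead of the real cache classes:

import torch

# Legacy layout: one (key, value) pair per decoder layer, unpacked by the caller.
num_layers, bsz, heads, seq, dim = 2, 1, 4, 3, 8
legacy_past = tuple(
    (torch.zeros(bsz, heads, seq, dim), torch.zeros(bsz, heads, seq, dim))
    for _ in range(num_layers)
)
key0, value0 = legacy_past[0]

# Newer layout: every layer receives the whole structure and indexes itself.
def layer_forward(layer_idx: int, past):
    key, value = past[layer_idx]
    return key.shape, value.shape

print(layer_forward(1, legacy_past))  # (torch.Size([1, 4, 3, 8]), torch.Size([1, 4, 3, 8]))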
@@ -138,7 +130,8 @@ def get_tp_falcon_decoder_layer_forward():
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            **kwargs,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
         )
 
         attention_output = attn_outputs[0]
@@ -152,6 +145,13 @@ def get_tp_falcon_decoder_layer_forward():
                 )
                 mlp_layernorm_out = self.post_attention_layernorm(residual)
 
+        if (
+            self.config.new_decoder_architecture
+            and self.config.parallel_attn
+            and self.config.num_ln_in_parallel_attn == 1
+        ):
+            mlp_layernorm_out = attention_layernorm_out
+
         outputs = attn_outputs[1:]
 
         # MLP.
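The new branch covers Falcon configs that run attention and MLP in parallel behind a single layernorm (`num_ln_in_parallel_attn == 1`): the MLP branch simply reuses the attention branch's normalized input. A minimal sketch of that rule, with a `SimpleNamespace` standing in for the real `FalconConfig`:

from types import SimpleNamespace

import torch

# Stand-in config; the real model reads these flags from FalconConfig.
cfg = SimpleNamespace(new_decoder_architecture=True, parallel_attn=True, num_ln_in_parallel_attn=1)
ln = torch.nn.LayerNorm(8)
hidden = torch.randn(2, 4, 8)

attention_layernorm_out = ln(hidden)
mlp_layernorm_out = ln(hidden)  # with two LNs this would come from a separate ln_mlp

if cfg.new_decoder_architecture and cfg.parallel_attn and cfg.num_ln_in_parallel_attn == 1:
    # Single shared LN: both parallel branches read the same normalized input.
    mlp_layernorm_out = attention_layernorm_out

assert mlp_layernorm_out is attention_layernorm_out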
@@ -167,7 +167,7 @@ def get_tp_falcon_decoder_layer_forward():
         else:
             outputs = (output,) + outputs[1:]
 
-        return outputs  # hidden_states, present, attentions
+        return outputs  # hidden_states, past_kv, attentions
 
     return forward
 
@@ -190,6 +190,7 @@ class FalconPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
@@ -206,7 +207,6 @@ class FalconPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
 
-        if past_key_values is not None:
-            logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
-            past_key_values = None
+        logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
+        past_key_values = None
 
@@ -221,7 +221,7 @@ class FalconPipelineForwards:
             elif inputs_embeds is not None:
                 batch_size, seq_length, _ = inputs_embeds.shape
             else:
-                raise ValueError("You have to specify either input_ids or inputs_embeds")
+                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
             if inputs_embeds is None:
                 inputs_embeds = self.word_embeddings(input_ids)
             hidden_states = inputs_embeds
@@ -229,12 +229,9 @@ class FalconPipelineForwards:
             input_shape = hidden_states.shape[:-1]
             batch_size, seq_length = input_shape
 
-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
@@ -243,10 +240,10 @@ class FalconPipelineForwards:
         all_hidden_states = () if output_hidden_states else None
 
         # Compute alibi tensor: check build_alibi_tensor documentation
+        alibi = None
         past_key_values_length = 0
-        if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[-2]
 
+        batch_size, seq_length, _ = hidden_states.shape
         if self.use_alibi:
             mask = (
                 torch.ones(
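Initializing `alibi = None` up front and re-deriving `batch_size, seq_length` from `hidden_states` keeps this code valid on intermediate pipeline stages, where only hidden states (not `input_ids`) are available. For orientation, a toy version of what an ALiBi bias encodes, one slope per head scaled by key position; this is a simplified stand-in, not the transformers `build_alibi_tensor`:

import math

import torch

def toy_alibi(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    # Geometric slopes from the ALiBi paper; num_heads is assumed a power of 2.
    base = 2 ** (-(2 ** -(math.log2(num_heads) - 3)))
    slopes = torch.tensor([base**i for i in range(1, num_heads + 1)])
    # Key positions, counted over non-padding tokens only.
    positions = (attention_mask.cumsum(dim=-1) - 1) * attention_mask
    return slopes[None, :, None] * positions[:, None, :]  # (bsz, heads, seq)

mask = torch.ones(2, 5)
print(toy_alibi(mask, num_heads=4).shape)  # torch.Size([2, 4, 5])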
@@ -256,61 +253,17 @@ class FalconPipelineForwards:
                 else attention_mask
             )
             alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
-        else:
-            alibi = None
+
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
+            )
 
         if position_ids is None:
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-                position_ids = torch.arange(
-                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-                )
-                position_ids = position_ids.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0)
 
-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            if alibi is None:
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            elif head_mask is None:
-                alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
-
-                attention_mask_2d = attention_mask
-                # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-                )
-
-                # We take care to integrate alibi bias in the attention_mask here.
-                if attention_mask_2d is None:
-                    attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
-                else:
-                    min_dtype = torch.finfo(alibi.dtype).min
-                    attention_mask = torch.masked_fill(
-                        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
-                        attention_mask < -1,
-                        min_dtype,
-                    )
-
-                    # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
-                    # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-                    if seq_length > 1 and attention_mask.device.type == "cuda":
-                        attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
-            else:
-                # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-                )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
+        causal_mask = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi
+        )
 
         # Prepare head mask if needed
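The hand-rolled flash-attention/SDPA/eager mask branches collapse into a single `self._update_causal_mask(...)` call, with positions derived from `cache_position`. The position bookkeeping is easy to check in isolation; `past_key_values_length` stays 0 on this path because the pipeline forward disables caching:

import torch

hidden_states = torch.randn(2, 5, 8)  # (batch, seq_len, hidden)
past_key_values_length = 0            # the pipeline path runs without a KV cache
seq_length = hidden_states.shape[1]

cache_position = torch.arange(
    past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
)
position_ids = cache_position.unsqueeze(0)  # (1, seq_len), broadcast over the batch
print(cache_position.tolist())  # [0, 1, 2, 3, 4]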
@@ -319,10 +272,11 @@ class FalconPipelineForwards:
         # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         start_idx, end_idx = stage_index[0], stage_index[1]
-        for i, (block, layer_past) in enumerate(
-            zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
-        ):
+        for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
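Computing `position_embeddings` once in the model body and passing the pair to every block avoids re-deriving the rotary tables per layer; the `(cos, sin)` tensors depend only on `position_ids`, never on the layer index. A self-contained sketch with a hypothetical `ToyRotary` module standing in for the model's `rotary_emb`:

import torch

class ToyRotary(torch.nn.Module):
    # Hypothetical stand-in for the model-level rotary module the patch calls
    # as `self.rotary_emb(hidden_states, position_ids)`.
    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        freqs = position_ids[..., None].float() * self.inv_freq  # (bsz, seq, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

rotary = ToyRotary(head_dim=64)
x = torch.randn(2, 5, 64)
cos, sin = rotary(x, torch.arange(5).unsqueeze(0))  # computed once, reused by all layers
print(cos.shape)  # torch.Size([1, 5, 64])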
@@ -331,28 +285,32 @@ class FalconPipelineForwards:
                     block.__call__,
                     hidden_states,
                     alibi,
-                    attention_mask,
+                    causal_mask,
                     position_ids,
                     head_mask[i],
-                    layer_past,
+                    past_key_values,
                     use_cache,
                     output_attentions,
+                    cache_position,
+                    position_embeddings,
                 )
             else:
                 outputs = block(
                     hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
+                    layer_past=past_key_values,
+                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                     alibi=alibi,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
 
             hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (outputs[1],)
+                outputs[1]
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@@ -365,6 +323,7 @@ class FalconPipelineForwards:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if stage_manager.is_last_stage():
+
             if not return_dict:
                 return tuple(
                     v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None

@@ -246,6 +246,7 @@ class FalconPolicy(Policy):
             module = self.model.transformer
         stage_manager = self.pipeline_stage_manager
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.h))
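Appending `module.rotary_emb` to `held_layers` keeps the rotary module on every pipeline stage, since each stage's decoder blocks now consume the shared `(cos, sin)` pair. A toy illustration of the held-layers idea, with `distribute_layers` as a simplified stand-in for the stage manager's method:

def distribute_layers(num_layers: int, num_stages: int):
    # Simplified stand-in: split num_layers into contiguous per-stage ranges.
    base, rem = divmod(num_layers, num_stages)
    bounds, start = [], 0
    for s in range(num_stages):
        size = base + (1 if s < rem else 0)
        bounds.append((start, start + size))
        start += size
    return bounds

blocks = [f"block_{i}" for i in range(8)]
for stage, (lo, hi) in enumerate(distribute_layers(len(blocks), 3)):
    held = ["rotary_emb"] + blocks[lo:hi]  # rotary_emb is held on every stage
    print(stage, held)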