diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py
index 51b228712..168e554f9 100644
--- a/colossalai/shardformer/modeling/gptj.py
+++ b/colossalai/shardformer/modeling/gptj.py
@@ -2,6 +2,7 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -79,7 +80,7 @@ class GPTJPipelineForwards:
     def gptj_model_forward(
         self: GPTJModel,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -89,12 +90,13 @@ class GPTJPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ) -> Union[Dict, Tuple, BaseModelOutputWithPast]:
-        # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward.
+        # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward in transformers v4.51.3.
         # Please refer to original code of transformers for more details.
 
         # GPTJ has no cross attention in comparison to GPT2
@@ -118,8 +120,8 @@ class GPTJPipelineForwards:
             use_cache = False
 
         if stage_manager.is_first_stage():
-            if input_ids is not None and inputs_embeds is not None:
-                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            if (input_ids is None) ^ (inputs_embeds is not None):
+                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
             elif input_ids is not None:
                 batch_size, seq_length = input_ids.shape
                 input_shape = input_ids.size()
@@ -130,17 +132,34 @@ class GPTJPipelineForwards:
                 batch_size = inputs_embeds.shape[0]
             else:
                 raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-            if token_type_ids is not None:
-                token_type_ids = token_type_ids.view(-1, seq_length)
         else:
             if hidden_states is None:
                 raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
             input_shape = hidden_states.size()[:-1]
             batch_size, seq_length = input_shape[0], input_shape[1]
-            device = hidden_states.device
+
+        if stage_manager.is_first_stage():
+            if inputs_embeds is None:
+                inputs_embeds = self.wte(input_ids)
+            hidden_states = inputs_embeds
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids.view(-1, seq_length)
+                token_type_embeds = self.wte(token_type_ids)
+                hidden_states = hidden_states + token_type_embeds
+            hidden_states = self.drop(hidden_states)
+
+        seq_length = hidden_states.shape[1]
+        if cache_position is None:
+            past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+        causal_mask = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+        )
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -148,33 +167,9 @@ class GPTJPipelineForwards:
         # head_mask has shape n_layer x batch x num_attention_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
-        # position id to be assigned not just for the first stage for attn input
-        if position_ids is None:
-            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0)
-        if stage_manager.is_first_stage():
-            if inputs_embeds is None:
-                inputs_embeds = self.wte(input_ids)
-            hidden_states = inputs_embeds
-            if token_type_ids is not None:
-                token_type_embeds = self.wte(token_type_ids)
-                hidden_states = hidden_states + token_type_embeds
-            hidden_states = self.drop(hidden_states)
+        output_shape = (-1, seq_length, hidden_states.size(-1))
 
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        attention_mask = _get_attention_mask(
-            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
-        )
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        presents = () if use_cache else None
+        next_decoder_cache = None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
@@ -207,29 +202,26 @@ class GPTJPipelineForwards:
                     block.__call__,
                     hidden_states,
                     None,
-                    attention_mask,
+                    causal_mask,
                    position_ids,
                     head_mask[i],
                     use_cache,
                     output_attentions,
+                    cache_position,
                 )
             else:
                 outputs = block(
                     hidden_states=hidden_states,
-                    layer_past=None,
-                    attention_mask=attention_mask,
+                    layer_past=past_key_values,
+                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    cache_position=cache_position,
                 )
 
             hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
 
             # When sequence parallelism done, gather the output tensor in forward and split it in backward
             if shard_config.enable_sequence_parallelism:
@@ -248,22 +240,17 @@ class GPTJPipelineForwards:
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
+        next_cache = next_decoder_cache if use_cache else None
+
         if stage_manager.is_last_stage():
            if not return_dict:
                 return tuple(
-                    v
-                    for v in [
-                        hidden_states,
-                        presents,
-                        all_hidden_states,
-                        all_self_attentions,
-                    ]
-                    if v is not None
+                    v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
                 )
 
             return BaseModelOutputWithPast(
                 last_hidden_state=hidden_states,
-                past_key_values=presents,
+                past_key_values=next_cache,
                 hidden_states=all_hidden_states,
                 attentions=all_self_attentions,
             )
@@ -275,7 +262,7 @@ class GPTJPipelineForwards:
     def gptj_causallm_model_forward(
         self: GPTJForCausalLM,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -286,6 +273,7 @@ class GPTJPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
@@ -315,6 +303,7 @@ class GPTJPipelineForwards:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
@@ -326,18 +315,28 @@ class GPTJPipelineForwards:
             return {"hidden_states": transformer_outputs["hidden_states"]}
 
         hidden_states = transformer_outputs[0]
-        lm_logits = self.lm_head(hidden_states)
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.transformer.first_device)
+            hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+        # transformers v4.51.3 loss calculation
+        # make sure sampling in fp16 works correctly and
+        # compute loss in fp32 to match with mesh-tf version
+        # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
+        lm_logits = self.lm_head(hidden_states).to(torch.float32)
 
         loss = None
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+            )
 
             loss = loss.to(hidden_states.dtype)
 
@@ -357,7 +356,7 @@ class GPTJPipelineForwards:
     def gptj_for_sequence_classification_forward(
         self: GPTJForSequenceClassification,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -379,7 +378,7 @@ class GPTJPipelineForwards:
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
             If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
 
-        # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward.
+        # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward in transformers v4.51.3.
         # Please refer to original code of transformers for more details.
         """
         logger = logging.get_logger(__name__)
@@ -581,6 +580,8 @@ def get_gptj_flash_attention_forward():
         Tuple[torch.Tensor, Tuple[torch.Tensor]],
         Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
     ]:
+        # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJAttention.forward in transformers v4.51.3.
+        # Please refer to original code of transformers for more details.
         assert head_mask is None, "head_mask is not supported for FlashAttention"
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)