diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index ffba27a50..5d6f47636 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -1,11 +1,11 @@
 import logging
 from functools import partial
 from types import MethodType
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from torch import Tensor
-from torch.nn import Module
+from torch import Tensor, nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 import colossalai.shardformer.layer as col_nn
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -48,6 +48,10 @@ class GPT2Policy(Policy):
                     suffix="wte",
                     target_module=col_nn.VocabParallelEmbedding1D,
                 ),
+                SubModuleReplacementDescription(
+                    suffix="drop",
+                    target_module=col_nn.DropoutForParallelInput,
+                ),
             ])
             policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
                 "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -120,6 +124,45 @@ class GPT2Policy(Policy):
     def postprocess(self):
         return self.model
+    def get_held_layers(self) -> List[nn.Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None
+
+        if self.model.__class__.__name__ == 'GPT2Model':
+            module = self.model
+        else:
+            module = self.model.transformer
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.wte)
+            held_layers.append(module.wpe)
+            held_layers.append(module.drop)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.h[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.ln_f)
+        return held_layers
+
+    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+        """If under the pipeline-parallel setting, replace the original HuggingFace forward method
+        with the customized forward method, and register this change in the policy."""
+        if self.pipeline_stage_manager:
+            stage_manager = self.pipeline_stage_manager
+            if self.model.__class__.__name__ == 'GPT2Model':
+                module = self.model
+            else:
+                module = self.model.transformer
+
+            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=model_cls)
+
 
 # GPT2Model
 class GPT2ModelPolicy(GPT2Policy):
@@ -131,40 +174,16 @@ class GPT2ModelPolicy(GPT2Policy):
         from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 
         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            # set None as default
-            stage_manager = self.pipeline_stage_manager
-            layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            method_replacement = {
-                'forward':
-                    partial(GPT2PipelineForwards.gpt2_model_forward,
-                            stage_manager=stage_manager,
-                            stage_index=stage_index)
-            }
-            self.append_or_create_method_replacement(description=method_replacement,
-                                                     policy=policy,
-                                                     target_key=GPT2Model)
+        self.set_pipeline_forward(model_cls=GPT2Model,
+                                  new_forward=GPT2PipelineForwards.gpt2_model_forward,
+                                  policy=policy)
         return policy
 
-    def get_held_layers(self) -> List[Module]:
-        """Get pipeline layers for current stage."""
-        module = self.model
-        stage_manager = self.pipeline_stage_manager
-        held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        if stage_manager.is_first_stage():
-            held_layers.append(module.wte)
-            held_layers.append(module.wpe)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
-        held_layers.extend(module.h[start_idx:end_idx])
-        if stage_manager.is_last_stage():
-            held_layers.append(module.ln_f)
-        return held_layers
+    def get_held_layers(self) -> List[nn.Module]:
+        return super().get_held_layers()
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        # TODO: check whether there is shared param in gpt2model
-        """No shared params in gpt2 model."""
+        """No shared params in GPT2Model."""
         return []
 
@@ -188,10 +207,31 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                     ])
             }
             module_policy.update(addon_module)
+
+        self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
+                                  new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
+                                  policy=module_policy)
         return module_policy
 
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        '''The weights of wte and lm_head are shared.'''
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
+            first_stage, last_stage = 0, stage_manager.num_stages - 1
+            return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+        else:
+            return []
+
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism \
+                and self.pipeline_stage_manager is None:
             binding_map = {"transformer.wte.weight": "lm_head.weight"}
             for k, v in binding_map.items():
                 param = getattr_(self.model, k)
@@ -199,7 +239,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         return self.model
 
 
-# GPT22DoubleHeadsModel
+# GPT2DoubleHeadsModel
 class GPT2DoubleHeadsModelPolicy(GPT2Policy):
 
     def __init__(self) -> None:
@@ -219,10 +259,38 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
                     ])
             }
             module_policy.update(addon_module)
+
+        self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel,
+                                  new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward,
+                                  policy=module_policy)
+
         return module_policy
 
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            multiple_choice_head = self.model.multiple_choice_head
+            held_layers.append(self.model.lm_head)
+            held_layers.append(multiple_choice_head.summary)
+            held_layers.append(multiple_choice_head.activation)
+            held_layers.append(multiple_choice_head.first_dropout)
+            held_layers.append(multiple_choice_head.last_dropout)
+
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        '''The weights of wte and lm_head are shared.'''
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
+            first_stage, last_stage = 0, stage_manager.num_stages - 1
+            return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+        else:
+            return []
+
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism \
+                and self.pipeline_stage_manager is None:
             binding_map = {"transformer.wte.weight": "lm_head.weight"}
             for k, v in binding_map.items():
                 param = getattr_(self.model, k)
@@ -236,6 +304,36 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
     def __init__(self) -> None:
         super().__init__()
 
+    def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
+
+        module_policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            addon_module = {
+                GPT2ForTokenClassification:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
+                    ])
+            }
+            module_policy.update(addon_module)
+
+        self.set_pipeline_forward(model_cls=GPT2ForTokenClassification,
+                                  new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward,
+                                  policy=module_policy)
+        return module_policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.dropout)
+            held_layers.append(self.model.classifier)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in GPT2ForTokenClassification."""
+        return []
+
 
 # GPT2ForSequenceClassification
 class GPT2ForSequenceClassificationPolicy(GPT2Policy):
@@ -243,6 +341,25 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
     def __init__(self) -> None:
         super().__init__()
 
+    def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification
+
+        module_policy = super().module_policy()
+        self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification,
+                                  new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
+                                  policy=module_policy)
+        return module_policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in GPT2ForSequenceClassification."""
+        return []
+
 
 class GPT2PipelineForwards:
     '''
@@ -299,8 +416,7 @@ class GPT2PipelineForwards:
             if token_type_ids is not None:
                 token_type_ids = token_type_ids.view(-1, seq_length)
         else:
-            if hidden_states is None:
-                raise ValueError("hidden_states shouln't be None for stages other than the first stage.")
+            assert hidden_states is not None
             input_shape = hidden_states.size()[:-1]
             batch_size, seq_length = input_shape[0], input_shape[1]
             device = hidden_states.device
@@ -462,3 +578,356 @@ class GPT2PipelineForwards:
         else:
             # always return dict for intermediate stage
             return {'hidden_states': hidden_states}
+
+    @staticmethod
+    def gpt2_lmhead_model_forward(
+            self: 'GPT2LMHeadModel',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            encoder_hidden_states: Optional[torch.Tensor] = None,
+            encoder_attention_mask: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'CausalLMOutputWithCrossAttentions']:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+
+        This function is adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
+        Please refer to the original code in transformers for more details.
+        """
+
+        from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          past_key_values=past_key_values,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          encoder_hidden_states=encoder_hidden_states,
+                                                          encoder_attention_mask=encoder_attention_mask,
+                                                          use_cache=use_cache,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        hidden_states = outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    @staticmethod
+    def gpt2_double_heads_model_forward(
+            self: 'GPT2DoubleHeadsModel',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            mc_token_ids: Optional[torch.LongTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            mc_labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'GPT2DoubleHeadsModelOutput']:
+        r"""
+        mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
+            Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+            1]`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
+            `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
+        mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
+            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
+            where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
+
+        This function is adapted from transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward.
+        Please refer to the original code in transformers for more details.
+        """
+        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          past_key_values=past_key_values,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          use_cache=use_cache,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        hidden_states = outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
+
+        mc_loss = None
+        if mc_labels is not None:
+            loss_fct = CrossEntropyLoss()
+            mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
+        lm_loss = None
+        if labels is not None:
+            labels = labels.to(lm_logits.device)
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits, mc_logits) + outputs[1:]
+            if mc_loss is not None:
+                output = (mc_loss,) + output
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return GPT2DoubleHeadsModelOutput(
+            loss=lm_loss,
+            mc_loss=mc_loss,
+            logits=lm_logits,
+            mc_logits=mc_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    @staticmethod
+    def gpt2_for_token_classification_forward(
+            self: 'GPT2ForTokenClassification',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'TokenClassifierOutput']:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        This function is adapted from transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward.
+        Please refer to the original code in transformers for more details.
+        """
+
+        from transformers.modeling_outputs import TokenClassifierOutput
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          past_key_values=past_key_values,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          use_cache=use_cache,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    @staticmethod
+    def gpt2_for_sequence_classification_forward(
+            self: 'GPT2ForSequenceClassification',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        This function is adapted from transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward.
+        Please refer to the original code in transformers for more details.
+        """
+        from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+
+        if input_ids is not None:
+            batch_size, _ = input_ids.shape[:2]
+        else:
+            batch_size, _ = hidden_states.shape[:2]
+        assert (self.config.pad_token_id is not None
+                or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined."
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          past_key_values=past_key_values,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          use_cache=use_cache,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        hidden_states = outputs[0]
+        logits = self.score(hidden_states)
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+                logging.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`")
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py
index 5f92f638f..dd439a394 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py
@@ -5,15 +5,9 @@ import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
+from tests.test_shardformer.test_model._utils import build_pipeline_model
 
 
 def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     pass
 
 
-@parameterize('enable_fused_normalization', [False])
 @parameterize('enable_tensor_parallelism', [False])
+@parameterize('enable_fused_normalization', [False])
 @parameterize('use_lazy_init', [False])
 #TODO: merge this into test_shard_gpt2
 def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
@@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name != "transformers_gpt":
-            continue
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
         inputs = data_gen_fn()
         inputs = {k: v.cuda() for k, v in inputs.items()}
+        input_ids, _ = inputs['input_ids'], inputs['attention_mask']
+        batch_size, seq_len = input_ids.shape
+        hidden_size = 768
+        hidden_state_shape = (batch_size, seq_len, hidden_size)
 
-        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                        enable_tensor_parallelism, use_lazy_init)
-        org_model.train()
-        org_output = org_model(**inputs)
-        hidden_state_shape = org_output['last_hidden_state'].shape
-
-        if stage_manager.is_first_stage():
-            output = sharded_model(**inputs)
-            assert output['hidden_states'].shape == hidden_state_shape
-        else:
-            attention_mask = inputs['attention_mask']
+        if not stage_manager.is_first_stage():
+            # change inputs if not the first stage
             hidden_states = torch.zeros(*hidden_state_shape).cuda()
-            output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask)
-            if stage_manager.is_last_stage():
-                assert output['last_hidden_state'].shape == hidden_state_shape
-            else:
-                assert output['hidden_states'].shape == hidden_state_shape
+            inputs['input_ids'] = None
+            inputs['hidden_states'] = hidden_states
+
+        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                enable_tensor_parallelism, use_lazy_init)
+        sharded_model.train()
+        output = sharded_model(**inputs)
+        if stage_manager.is_last_stage():
+            if name != 'transformers_gpt':
+                assert output.loss is not None
+        else:
+            assert output['hidden_states'].shape == hidden_state_shape
 
     torch.cuda.empty_cache()
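
Review note: the refactor above moves the pipeline wiring into the shared `GPT2Policy.get_held_layers` and `GPT2Policy.set_pipeline_forward` helpers, so each task-specific policy only appends its own head and points at the matching staged forward (which returns `{'hidden_states': ...}` on every stage except the last). The sketch below illustrates that pattern; it is not part of the patch, and `GPT2ForMyTaskPolicy` is a hypothetical name used only for illustration — the calls mirror `GPT2ForSequenceClassificationPolicy` in the diff.

```python
# Illustrative sketch only (not part of this patch). Assumes the helpers and
# staged forwards added in colossalai/shardformer/policies/gpt2.py above.
from typing import Dict, List

from torch import Tensor, nn

from colossalai.shardformer.policies.gpt2 import GPT2PipelineForwards, GPT2Policy


class GPT2ForMyTaskPolicy(GPT2Policy):    # hypothetical subclass name

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification

        policy = super().module_policy()
        # Route forward() through the staged implementation: intermediate stages
        # return {'hidden_states': ...}; only the last stage computes logits/loss.
        self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification,
                                  new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
                                  policy=policy)
        return policy

    def get_held_layers(self) -> List[nn.Module]:
        # The base class already assigns wte/wpe/drop to the first stage, a slice
        # of module.h to every stage, and ln_f to the last stage.
        held_layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            held_layers.append(self.model.score)    # task head lives on the last stage
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        # No tied weights for this head (contrast with the LM-head policies,
        # which tie wte.weight and lm_head.weight across first and last stages).
        return []
```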