diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 09f6e005c..bb493be3e 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -39,7 +39,6 @@ def _get_attention_mask(
     attention_mask: Optional[torch.FloatTensor],
     encoder_hidden_states: Optional[torch.Tensor],
     encoder_attention_mask: Optional[torch.FloatTensor],
-    head_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
     # Received input is already split for non-first pipeline stages,
     # but attn mask isn't
@@ -49,9 +48,7 @@ def _get_attention_mask(
     sp_mode = shard_config.sequence_parallelism_mode
     # If a 2D or 3D attention mask is provided for the cross-attention
     # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-    _use_sdpa = self._attn_implementation == "sdpa"
-    print("_use_sdpa", _use_sdpa)
-    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
+
     if self.config.add_cross_attention and encoder_hidden_states is not None:
         assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
@@ -59,7 +56,7 @@ def _get_attention_mask(
             encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
                 (encoder_batch_size, 1, seq_len, encoder_sequence_length),
                 dtype=hidden_states.dtype,
-                dtype2=encoder_hidden_states.dtype,
+                device=encoder_hidden_states.device,
                 q_padding_mask=attention_mask,
                 kv_padding_mask=encoder_attention_mask,
             )
@@ -67,12 +64,7 @@ def _get_attention_mask(
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
-            if _use_sdpa:
-                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    mask=encoder_attention_mask, dtype=hidden_states.dtype, tgt_len=encoder_hidden_shape[-1]
-                )
-            elif not self._attn_implementation == "flash_attention_2":
-                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
     else:
         if shard_config.enable_flash_attention:
             encoder_attention_mask = {"attention_mask": None}
@@ -95,13 +87,6 @@ def _get_attention_mask(
         )
     elif self._attn_implementation == "flash_attention_2":
         attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-    elif _use_sdpa:
-        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask=attention_mask,
-            input_shape=(batch_size, hidden_states[-1]),
-            inputs_embeds=hidden_states,
-            past_key_values_length=past_key_values_length,
-        )
     elif attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
@@ -224,7 +209,6 @@ class GPT2PipelineForwards:
             attention_mask,
             encoder_hidden_states,
             encoder_attention_mask,
-            head_mask
         )


diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 6806260f5..d5d97fd2d 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -400,7 +400,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                             suffix="lm_head",
                             target_module=col_nn.VocabParallelLMHead1D,
                             kwargs={
-                                "gather_output": True,
+                                "gather_output": False,
                                 "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
                             },
                         ),
@@ -418,20 +418,20 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                 target_key=GPT2LMHeadModel,
             )

-        # if self.shard_config.parallel_output:
-        #     self.append_or_create_method_replacement(
-        #         description={
-        #             "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
-        #         },
-        #         policy=module_policy,
-        #         target_key=GPT2LMHeadModel,
-        #     )
+        if self.shard_config.parallel_output:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
+                },
+                policy=module_policy,
+                target_key=GPT2LMHeadModel,
+            )

         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
-                shard_config=self.shard_config,
+                # shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py
index 8ebe283f9..f2b139bec 100644
--- a/tests/kit/model_zoo/transformers/gpt.py
+++ b/tests/kit/model_zoo/transformers/gpt.py
@@ -127,43 +127,43 @@ model_zoo.register(
     loss_fn=loss_fn_for_gpt2_model,
     model_attribute=ModelAttribute(has_control_flow=True),
 )
-# model_zoo.register(
-#     name="transformers_gpt_lm",
-#     model_fn=lambda: transformers.GPT2LMHeadModel(config),
-#     data_gen_fn=data_gen_for_lm,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=loss_fn,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-#     name="transformers_gpt_double_heads",
-#     model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-#     data_gen_fn=date_gen_for_double_heads,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=lambda x: x.loss + x.mc_loss,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-#     name="transformers_gpt_for_question_answering",
-#     model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
-#     data_gen_fn=data_gen_for_question_answering,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=loss_fn,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-#     name="transformers_gpt_for_token_classification",
-#     model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
-#     data_gen_fn=data_gen_for_token_classification,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=loss_fn,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
-# model_zoo.register(
-#     name="transformers_gpt_for_sequence_classification",
-#     model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
-#     data_gen_fn=data_gen_for_sequence_classification,
-#     output_transform_fn=output_transform_fn,
-#     loss_fn=loss_fn,
-#     model_attribute=ModelAttribute(has_control_flow=True),
-# )
+model_zoo.register(
+    name="transformers_gpt_lm",
+    model_fn=lambda: transformers.GPT2LMHeadModel(config),
+    data_gen_fn=data_gen_for_lm,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_double_heads",
+    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
+    data_gen_fn=date_gen_for_double_heads,
+    output_transform_fn=output_transform_fn,
+    loss_fn=lambda x: x.loss + x.mc_loss,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_question_answering",
+    model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
+    data_gen_fn=data_gen_for_question_answering,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_token_classification",
+    model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
+    data_gen_fn=data_gen_for_token_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_gpt_for_sequence_classification",
+    model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
+    data_gen_fn=data_gen_for_sequence_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 0334f5c2c..3062eaf40 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -157,18 +157,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
             "num_microbatches": 1,
-            # "enable_all_optimization": True,
+            "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
         },
         {
-            "tp_size": 2,
+            "tp_size": 4,
             "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": False,
-            # "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
@@ -238,7 +238,7 @@ def run_gpt2_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp32",
             "initial_scale": 1,
@@ -247,7 +247,7 @@ def run_gpt2_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "zero_stage": 1,
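
Note on the lm_head change above: with gather_output set to False on VocabParallelLMHead1D and parallel_output enabled, each tensor-parallel rank keeps only its slice of the vocabulary logits, so the loss must be computed without first gathering the full logits. The following is a minimal, single-process sketch of that vocab-parallel cross entropy written against plain PyTorch rather than ColossalAI's own loss utilities; the function name vocab_parallel_cross_entropy and the equal-split shard layout are illustrative assumptions, and the two reductions over the vocabulary dimension would become all-reduces over the tensor-parallel group in a real run.

# Minimal single-process sketch (illustrative only, not ColossalAI's implementation).
# Each "shard" stands for the logits slice one TP rank holds after the LM head
# when gather_output=False.
import torch
import torch.nn.functional as F


def vocab_parallel_cross_entropy(logit_shards, targets):
    # Per-rank vocab shard size; assumes equal splits along the vocab dimension.
    shard_size = logit_shards[0].size(-1)
    # Global max over the full vocab for numerical stability
    # (an all-reduce MAX over the TP group in a distributed run).
    global_max = torch.stack([s.max(dim=-1).values for s in logit_shards]).max(dim=0).values
    # Softmax denominator over the full vocab (an all-reduce SUM in a distributed run).
    sum_exp = sum(torch.exp(s - global_max.unsqueeze(-1)).sum(dim=-1) for s in logit_shards)
    # Numerator: fetch the target logit from whichever shard owns that vocab index.
    target_logit = torch.zeros_like(global_max)
    for rank, shard in enumerate(logit_shards):
        lo = rank * shard_size
        mask = (targets >= lo) & (targets < lo + shard_size)
        local_idx = (targets - lo).clamp(0, shard_size - 1)
        picked = shard.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
        target_logit = torch.where(mask, picked, target_logit)
    # loss = log(sum_j exp(z_j)) - z_target, written with the max subtracted.
    return (torch.log(sum_exp) + global_max - target_logit).mean()


# Sanity check against the ordinary full-vocab cross entropy.
torch.manual_seed(0)
full_logits = torch.randn(4, 8, 16)          # (batch, seq, vocab)
targets = torch.randint(0, 16, (4, 8))
shards = list(full_logits.chunk(4, dim=-1))  # simulate tp_size=4 vocab shards
ref = F.cross_entropy(full_logits.view(-1, 16), targets.view(-1))
assert torch.allclose(vocab_parallel_cross_entropy(shards, targets), ref, atol=1e-5)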