diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 8cc76dd3e..c37a6b4df 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -199,7 +199,12 @@ def get_param_info(optim: Optimizer):
 
     if optim is None:
         return {}
-    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
+    param_info = {
+        "param_groups": [],
+        "param2id": {},
+        "id2param": {},
+        "param2shape": {},
+    }
     start_index = 0
     for group in optim.param_groups:
         packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -899,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
         microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
             Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
@@ -939,6 +945,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -1035,6 +1042,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            parallel_output=parallel_output,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 3e5cc6015..1e22d9094 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -25,6 +25,7 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
+from ..layer._operation import gather_forward_split_backward
 
 
 class GPT2PipelineForwards:
@@ -337,6 +338,9 @@ class GPT2PipelineForwards:
                 else:
                     loss = loss_fct(shift_logits, shift_labels)
 
+            if not shard_config.parallel_output:
+                lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
+
             if not return_dict:
                 output = (lm_logits,) + outputs[1:]
                 return ((loss,) + output) if loss is not None else output
@@ -793,11 +797,12 @@ def get_gpt2_flash_attention_forward():
             scale = scale * (1 / float(self.layer_idx + 1))
 
         # use coloattention
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
-        )
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(
+                embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
+            )
 
-        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+        attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
 
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
@@ -1083,6 +1088,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             else:
                 loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f20ceb2d6..eb8e9f748 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -16,7 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ class LlamaPipelineForwards:
                     loss = loss_fct(shift_logits, shift_labels)
 
             if not shard_config.parallel_output:
-                logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+                logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -485,8 +485,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal
 
-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
@@ -593,7 +594,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 loss = loss_fct(shift_logits, shift_labels)
 
         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index 1d2b7a570..9a49b1ba6 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -242,4 +242,4 @@ class Policy(ABC):
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
             stage_indices.append([start_idx, end_idx])
 
-        return stage_indices[0] if num_model_chunks == 1 else stage_indices
+        return stage_indices[0] if num_model_chunks == 1 else stage_indices
\ No newline at end of file
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 415fc6dd5..da27341d9 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -34,8 +34,10 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
-    parallel_output = True
+    parallel_output: bool = True
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+    # TODO padding vocab
+    # make_vocab_size_divisible_by: int = 128
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py
index 285c4866c..38361d803 100644
--- a/tests/test_booster/test_plugin/test_3d_plugin.py
+++ b/tests/test_booster/test_plugin/test_3d_plugin.py
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
         origin_model, origin_optimizer, dataloader=dataloader
     )
     for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
-        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
+        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
 
 
 def run_dist(rank, world_size, port, early_stop: bool = True):
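
Usage note (not part of the diff): below is a minimal sketch of how the new `parallel_output` flag could be exercised through HybridParallelPlugin. Only the `parallel_output` argument itself comes from this change; the model choice, padded vocab size, tensor-parallel degree, and hyperparameters are illustrative assumptions. Launch with e.g. `torchrun --nproc_per_node 2 script.py`.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(
    tp_size=2,              # tensor parallelism over 2 ranks (example value)
    pp_size=1,
    parallel_output=False,  # new flag: gather logits across the TP group instead of keeping them sharded
)
booster = Booster(plugin=plugin)

# vocab_size chosen as a multiple of tp_size so the embedding and lm_head shard evenly
model = GPT2LMHeadModel(GPT2Config(vocab_size=50304))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

input_ids = torch.randint(0, 50304, (2, 32), device=torch.cuda.current_device())
outputs = model(input_ids=input_ids)

# With parallel_output=False, gather_forward_split_backward all-gathers the
# vocab-sharded logits along the last dimension in the forward pass (and splits the
# gradient back in the backward pass), so the logits carry the full vocab dimension
# on every tensor-parallel rank.
print(outputs.logits.shape)  # expected: torch.Size([2, 32, 50304])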