diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 3568a5dda..1b3b765c2 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -28,8 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
-from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
-from colossalai.pipeline.schedule.v_schedule import PipelineGraph
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -1093,10 +1092,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
-            assert (
-                pp_style == "interleaved" or pp_style == "zbv"
-            ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -1106,7 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=(pp_style == "interleaved") or (pp_style == "zbv"),
+                enable_interleave=(pp_style == "interleaved"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )
@@ -1128,31 +1125,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                 )
-            elif pp_style == "zbv":
-                h, a, s = 4096, 32, 1024
-                mem_f = 34 * h + 5 * a * s
-                mem_w = -32 * h
-                mem_b = -mem_w - mem_f
-                zbv_schedule = PipelineGraph(
-                    n_stage=self.pp_size,
-                    n_micro=num_microbatches,
-                    f_cost=1,
-                    b_cost=1,
-                    w_cost=1,
-                    c_cost=1,
-                    f_mem=mem_f,
-                    b_mem=mem_b,
-                    w_mem=mem_w,
-                ).get_v_schedule()
-                self.schedule = ZeroBubbleVPipeScheduler(
-                    schedule=zbv_schedule,
-                    stage_manager=self.stage_manager,
-                    num_model_chunks=num_model_chunks,
-                    num_microbatch=num_microbatches,
-                    microbatch_size=microbatch_size,
-                    enable_metadata_cache=enable_metadata_cache,
-                    overlap_p2p=overlap_p2p,
-                )
             else:
                 raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index b2c988a8b..c1e48d5f7 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -14,7 +14,16 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
+)
 
 
 class MlpModel(nn.Module):
@@ -679,6 +688,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
 
 
 # TODO:4) support Hybrid base 3)
+def run_with_hybridplugin(test_config):
+    pass
+
+
+# TODO:5) support MoEHybrid base 3)
 @parameterize(
     "test_config",
     [
@@ -693,20 +707,55 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
         },
     ],
 )
-def run_with_hybridplugin(test_config):
-    pass
+def run_with_moehybridplugin(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+    test_config["use_lazy_init"] = False
+    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
+    test_config["initial_scale"] = 2**16  # avoid overflow
+    model_list = [
+        "transformers_bert",
+    ]
+    clear_layout_converter()
+    torch.set_default_dtype(torch.bfloat16)
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name in model_list:
+            (
+                org_model,
+                org_optimizer,
+                sharded_model,
+                sharded_optimizer,
+                criterion,
+                booster,
+            ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
 
+            org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+                org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+            )
 
-# TODO:5) support MoEHybrid base 3)
-def run_with_moehybridplugin(
-    rank: int,
-    world_size: int,
-    port: int,
-    num_microbatch: int,
-    batch_size: int,
-    num_model_chunk: int,
-):
-    pass
+            stage_manager = booster.plugin.stage_manager
+            tp_group = booster.plugin.tp_group
+
+            bert = unwrap_model(org_model, "BertModel", "bert")
+            sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
+            weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
+
+            org_optimizer.step()
+            sharded_optimizer.step()
+
+            # check weights
+            if test_config["precision"] == "bf16":
+                atol, rtol = 5e-4, 5e-4
+            else:
+                atol, rtol = 5e-4, 5e-4
+            if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
+                check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
+            # check optim states
+            # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
+
+            clear_layout_converter()
+            Randomizer.reset_index()
+            torch.cuda.empty_cache()
+    print(f"Bert Model Zoo Test Passed")
 
 
 # TODO:6) support booster & Hybrid base 4)
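
For context, a minimal sketch (not part of the patch above) of how a pipeline style is selected through HybridParallelPlugin once "zbv" is removed; the parallel sizes, precision, and microbatch settings below are illustrative assumptions, not values taken from this diff.

# Minimal usage sketch, assumption-based: with this change, pp_style may only be
# "1f1b" or "interleaved", and only "interleaved" enables interleaving in the
# PipelineStageManager. Assumes the distributed environment was already set up
# (e.g. colossalai.launch_from_torch() under torchrun); sizes are illustrative.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    pp_style="interleaved",  # "1f1b" would instead require num_model_chunks == 1
    num_model_chunks=2,
    num_microbatches=4,      # num_microbatches or microbatch_size must be set when pp_size > 1
    precision="bf16",
)
booster = Booster(plugin=plugin)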