diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index bbad921b2..5c25c5bfa 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -430,8 +430,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): with self.stage_manager.switch_model_chunk_id(model_chunk_id): # fwd calculate internal_inputs = {} if input_obj is None else input_obj - # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] - output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs) + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) # last layer in model if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): @@ -449,7 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, - micro_batch: Optional[dict], + # micro_batch: Optional[dict], input_obj: Optional[dict], output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict], @@ -478,11 +478,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_ = [] output_obj_grad_ = [] - # For chunk 0 stage 0, use micro_batch as input_obj_ + # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx. if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj_, _ = tree_flatten(micro_batch) - output_obj_, _ = tree_flatten(output_obj) # y - output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + return None # For loss backward; output_obj is loss; output_obj_grad should be None elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): @@ -497,6 +495,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_, _ = tree_flatten(output_obj) # y output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + # filter item which is not torch.Tensor + input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] + optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad_, @@ -507,9 +510,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Format output_obj_grad input_obj_grad = {} if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - for k, v in micro_batch.items(): - if isinstance(v, torch.Tensor) and v.grad is not None: - input_obj_grad[k] = v.grad + pass else: for k, v in input_obj.items(): if isinstance(v, torch.Tensor) and v.grad is not None: @@ -550,10 +551,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_, _ = tree_flatten(output_obj) # y output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + # filter item which is not torch.Tensor + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] + optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad_, - inputs=list(model_chunk[model_chunk_id].parameters()), + inputs=list(model_chunk.parameters()), retain_graph=False, ) @@ -634,7 +639,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): tree_map(release_tensor_data, output_obj) # add input and output object for backward b - self.input_tensors[model_chunk_id].append((micro_batch, input_obj)) + self.input_tensors[model_chunk_id].append(input_obj) # for bwd b&w, we only need the graph(grad_fn) of output_obj # Do not release_tensor_data loss, release_tensor_data other output_obj; @@ -692,7 +697,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) # get input and output object from buffer; - micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0) + input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) # save output_tensor_grad for dw @@ -708,7 +713,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk=model_chunk, model_chunk_id=model_chunk_id, optimizer=optimizer, - micro_batch=micro_batch, input_obj=input_obj, output_obj=output_obj, output_obj_grad=output_tensor_grad, diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 354f110f0..50cc965bb 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -26,6 +26,7 @@ class PipelineStageManager: pg_mesh: ProcessGroupMesh, pipeline_axis: int, enable_interleave: bool = False, + use_zbv: bool = False, num_model_chunks: int = 1, num_layers_per_stage: Optional[List[int]] = None, ) -> None: @@ -49,6 +50,7 @@ class PipelineStageManager: next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") self.is_interleave = enable_interleave + self.use_zbv = use_zbv # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers self.num_model_chunks: int = num_model_chunks # for shardformer, hold stage indices of model @@ -85,6 +87,16 @@ class PipelineStageManager: num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) stage_indices = [] + if self.use_zbv: + stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]]) + stage_indices.append( + [ + num_layers_per_stage_accumulated[2 * num_stages - stage - 1], + num_layers_per_stage_accumulated[2 * num_stages - stage], + ] + ) + return stage_indices + for model_chunk in range(num_model_chunks): start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index e2f71ff89..f79bdeb3a 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager): self.is_interleave = False self.num_layers_per_stage = None self.num_model_chunks = 1 + self.use_zbv = False @property def num_stages(self): diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index d39c5ea91..722b8fd7c 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager): self.is_interleave = False self.num_layers_per_stage = None self.num_model_chunks = 1 + self.use_zbv = False @property def num_stages(self): diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 14bc3475d..0f2d6c49c 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -1,6 +1,5 @@ from copy import deepcopy from functools import partial -from types import MethodType from typing import Tuple import pytest @@ -22,36 +21,50 @@ from tests.kit.model_zoo import model_zoo class MlpModel(nn.Module): - def __init__(self, in_dim, out_dim, num_layers): + def __init__( + self, + in_dim, + out_dim, + num_layers, + stage_index=None, + stage_mgr: PipelineStageManager = None, + ): super().__init__() - self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) def forward( self, - hidden_states, + data: torch.Tensor = None, + hidden_states: torch.Tensor = None, + stage_index=None, + stage_mgr: PipelineStageManager = None, + model_chunk_id: int = None, ): - for layer in self.layers: - hidden_states = layer(hidden_states) - return hidden_states - - -def pp_linear_fwd( - forward, - data: torch.Tensor = None, - hidden_states: torch.Tensor = None, - stage_mgr: PipelineStageManager = None, - model_chunk_id: int = None, -): - with stage_mgr.switch_model_chunk_id(model_chunk_id): - # fwd end - if stage_mgr.is_first_stage() and model_chunk_id == 1: - return forward(hidden_states) - # fwd start - elif stage_mgr.is_first_stage() and model_chunk_id == 0: - return {"hidden_states": forward(data)} - # fwd middle + if stage_mgr is None: + hidden_states = data + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states else: - return {"hidden_states": forward(hidden_states)} + # Set not used layer to None + held_layers = self.layers[stage_index[0] : stage_index[1]] + + # fwd end + if stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 1: + return held_layers(hidden_states) + # fwd start + elif stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 0: + return {"hidden_states": held_layers(data)} + # fwd middle + else: + return {"hidden_states": held_layers(hidden_states)} + + +def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups): + for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()): + if key_base == key_pp: + if key_base != "params": + assert val_base == val_pp def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: @@ -554,7 +567,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): num_model_chunk = test_config["num_model_chunk"] # stage_manager stage_manager = PipelineStageManager( - pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk, use_zbv=True ) h, a, s = 4096, 32, 1024 @@ -600,67 +613,27 @@ def run_fwd_bwd_vschedule_with_optim(test_config): before_init_memory = torch.cuda.memory_allocated() / 1024**3 print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) - # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)} - # input_base = [t.clone() for t in data_iter] input_base = {k: v.clone() for k, v in data_iter.items()} model_base = deepcopy(model) + model_pp = deepcopy(model) + layers_per_stage = stage_manager.distribute_layers(len(model.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) - if rank == 0: - # layer 0 & 7 to chunk 0 on rank0 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 0 or idx == 7: - sub_model._forward = sub_model.forward - sub_model.forward = MethodType( - partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)), - sub_model._forward, - ) - local_chunk.append(sub_model) - elif rank == 1: - # layer 1 & 6 to chunk 1 on rank1 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 1 or idx == 6: - sub_model._forward = sub_model.forward - sub_model.forward = MethodType( - partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)), - sub_model._forward, - ) - local_chunk.append(sub_model) - elif rank == 2: - # layer 2 & 5 to chunk 2 on rank2 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 2 or idx == 5: - sub_model._forward = sub_model.forward - sub_model.forward = MethodType( - partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)), - sub_model._forward, - ) - local_chunk.append(sub_model) - else: - # layer 3 & 4 to chunk 3 on rank3 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 3 or idx == 4: - sub_model._forward = sub_model.forward - sub_model.forward = MethodType( - partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)), - sub_model._forward, - ) - local_chunk.append(sub_model) + model_pp._forward = model_pp.forward + + model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager) # init optimizer optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5) - optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5)) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5)) after_init_memory = torch.cuda.memory_allocated() / 1024**3 print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") torch.cuda.synchronize() result = scheduler.forward_backward_step( - model_chunk=local_chunk, + model_chunk=model_pp, data_iter=iter([data_iter]), criterion=criterion, optimizer=optimizer_pp, @@ -694,7 +667,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config): # Fwd bwd for base ########################## # fwd & bwd - output_base = model_base(input_base["data"]) + # output_base = model_base(input_base["data"]) + output_base = model_base.forward(data=input_base["data"]) loss_base = criterion_base(output_base) loss_base.backward() optimizer_base.step() @@ -707,63 +681,53 @@ def run_fwd_bwd_vschedule_with_optim(test_config): assert_close(result["loss"], loss_base) assert_close(result["outputs"]["hidden_states"], output_base) - # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ") - ########################## - # assert weight - ########################## - if rank == 0: - # layer 0 - assert_close(local_chunk[0].weight, model_base.layers[0].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) - # layer 7 - assert_close(local_chunk[1].weight, model_base.layers[7].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) - if rank == 1: - # layer 1 - assert_close(local_chunk[0].weight, model_base.layers[1].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) - # layer 6 - assert_close(local_chunk[1].weight, model_base.layers[6].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) - if rank == 2: - # layer 2 - assert_close(local_chunk[0].weight, model_base.layers[2].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) - # layer 5 - assert_close(local_chunk[1].weight, model_base.layers[5].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) - if rank == 3: - # layer 3 - assert_close(local_chunk[0].weight, model_base.layers[3].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) - # layer 4 - assert_close(local_chunk[1].weight, model_base.layers[4].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) - - ########################## - # assert optim state - ########################## + # ########################## + # # assert weight & optim state + # ########################## optim_base_state = optimizer_base.state_dict()["state"] optim_pp_state = optimizer_pp.state_dict()["state"] optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0] optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0] - # if rank == 0: - # print(f"optim_base_state {optim_base_state}") - # assert param group - for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()): - if key_base == key_pp: - if key_base != "params": - assert val_base == val_pp - else: - # BUG: - # param_base: [0, 1, 2, 3, 4, 5, 6, 7]; - # params pp: [0, 1]; - assert val_base[:2] == val_pp + if rank == 0: + # layer 0 + assert_close(model_pp.layers[0].weight, model_base.layers[0].weight) + assert_close(model_pp.layers[0].weight.grad, model_base.layers[0].weight.grad) + assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[0]["momentum_buffer"]) + # layer 7 + assert_close(model_pp.layers[7].weight, model_base.layers[7].weight) + assert_close(model_pp.layers[7].weight.grad, model_base.layers[7].weight.grad) + assert_close(optim_pp_state[7]["momentum_buffer"], optim_base_state[7]["momentum_buffer"]) + if rank == 1: + # layer 1 + assert_close(model_pp.layers[1].weight, model_base.layers[1].weight) + assert_close(model_pp.layers[1].weight.grad, model_base.layers[1].weight.grad) + assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[1]["momentum_buffer"]) + # layer 6 + assert_close(model_pp.layers[6].weight, model_base.layers[6].weight) + assert_close(model_pp.layers[6].weight.grad, model_base.layers[6].weight.grad) + assert_close(optim_pp_state[6]["momentum_buffer"], optim_base_state[6]["momentum_buffer"]) + if rank == 2: + # layer 2 + assert_close(model_pp.layers[2].weight, model_base.layers[2].weight) + assert_close(model_pp.layers[2].weight.grad, model_base.layers[2].weight.grad) + assert_close(optim_pp_state[2]["momentum_buffer"], optim_base_state[2]["momentum_buffer"]) + # layer 5 + assert_close(model_pp.layers[5].weight, model_base.layers[5].weight) + assert_close(model_pp.layers[5].weight.grad, model_base.layers[5].weight.grad) + assert_close(optim_pp_state[5]["momentum_buffer"], optim_base_state[5]["momentum_buffer"]) + if rank == 3: + # layer 3 + assert_close(model_pp.layers[3].weight, model_base.layers[3].weight) + assert_close(model_pp.layers[3].weight.grad, model_base.layers[3].weight.grad) + assert_close(optim_pp_state[3]["momentum_buffer"], optim_base_state[3]["momentum_buffer"]) + # layer 4 + assert_close(model_pp.layers[4].weight, model_base.layers[4].weight) + assert_close(model_pp.layers[4].weight.grad, model_base.layers[4].weight.grad) + assert_close(optim_pp_state[4]["momentum_buffer"], optim_base_state[4]["momentum_buffer"]) - # assert state - assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"]) - assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"]) + # assert optim param_groups + assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups) # TODO:4) support Hybrid base 3)