diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index 6cd74b3b4..a37bef29a 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -58,6 +58,28 @@ class OptimizerWrapper:
     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         torch.autograd.backward(tensor, grad)

+    def backward_b_by_grad(self, tensor: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
+        """
+        Performs the backward pass for dx only; here we compute dx = w * dy.
+
+        Args:
+            tensor (Tensor): y or loss of the current chunk;
+            grad_tensors (Tensor): dy of the current chunk;
+            inputs (Tensor): x of the current chunk;
+            retain_graph (bool): defaults to True; the graph is retained in backward_b so that backward_w can still run.
+        """
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad_tensors,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
+
+    def backward_w_by_grad(self):
+        """
+        Performs the backward pass for dw only; here we compute dw = x * dy.
+        """
+
     def state_dict(self):
         """
         Returns the optimizer state.
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 02ecf5b19..90da38fcd 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -413,7 +413,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -447,7 +447,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
         else:
             # commom bwd step
-            # BUG:output_obj_grad is None
             torch.autograd.backward(
                 tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
             )
@@ -564,7 +563,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         scheduled_node,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         # input_obj: Optional[dict],
         # output_obj: Union[dict, torch.Tensor],
         # output_obj_grad: Optional[dict],
@@ -614,7 +613,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
-            # optimizer: OptimizerWrapper,
+            optimizer=optimizer,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
@@ -715,6 +714,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
+                    optimizer=optimizer,
                 )
             elif scheduled_node.type == "W":
                 self.schedule_w(
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index ea7abc432..d97e60e2f 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -9,6 +9,7 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -625,7 +626,148 @@ def run_fwd_bwd_vschedule_with_optim(
     batch_size: int,
     num_model_chunk: int,
 ):
-    pass
+    # init dist
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+    rank = dist.get_rank()
+    pp_size = world_size
+    pg_mesh = ProcessGroupMesh(pp_size)
+    num_microbatch = num_microbatch
+    # stage_manager
+    stage_manager = PipelineStageManager(
+        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+    )
+
+    h, a, s = 4096, 32, 1024
+    mem_f = 34 * h + 5 * a * s
+    mem_w = -32 * h
+    mem_b = -mem_w - mem_f
+    graph = PipelineGraph(
+        n_stage=world_size,
+        n_micro=num_microbatch,
+        f_cost=6,
+        b_cost=6,
+        w_cost=6,
+        c_cost=6,
+        f_mem=mem_f,
+        b_mem=mem_b,
+        w_mem=mem_w,
+        # max_mem=mem_f * (p * 2 + m_offset),
+    )
+
+    zbv_schedule = graph.get_v_schedule()
+
+    scheduler = ZeroBubbleVPipeScheduler(
+        schedule=zbv_schedule[rank],  # hint: send whole schedule or local schedule only ?
+        stage_manager=stage_manager,
+        num_model_chunks=num_model_chunk,
+        num_microbatch=num_microbatch,
+        overlap_p2p=False,
+    )
+
+    # init loss func
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()
+
+    # init model and input
+    batch_size = batch_size
+    num_layers = 8
+    assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layers cannot be distributed over {num_model_chunk} chunks"
+    in_dim = out_dim = 8
+    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
+    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
+    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
+
+    input_base = [t.clone() for t in data_iter]
+    model_base = deepcopy(model)
+
+    if rank == 0:
+        # layer 0 & 7 to chunk 0 on rank0
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 0 or idx == 7:
+                local_chunk.append(sub_model)
+    elif rank == 1:
+        # layer 1 & 6 to chunk 1 on rank1
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 1 or idx == 6:
+                local_chunk.append(sub_model)
+    elif rank == 2:
+        # layer 2 & 5 to chunk 2 on rank2
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 2 or idx == 5:
+                local_chunk.append(sub_model)
+    else:
+        # layer 3 & 4 to chunk 3 on rank3
+        local_chunk = torch.nn.Sequential().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 3 or idx == 4:
+                local_chunk.append(sub_model)
+
+    # init optimizer
+    optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
+    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
+
+    print(
+        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
+    )
+
+    torch.cuda.synchronize()
+    scheduler.run_forward_backward(
+        model_chunk=local_chunk,
+        data_iter=iter(data_iter),
+        criterion=criterion,
+        optimizer=optimizer_pp,
+        return_loss=None,
+        return_outputs=None,
+    )
+
+    ##########################
+    # Fwd bwd for base
+    ##########################
+    # fwd & bwd
+    output_base = model_base(input_base[0])
+    loss_base = criterion(output_base)
+    loss_base.backward()
+    optimizer_base.step()
+    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
+
+    ##########################
+    # assert weight
+    ##########################
+    if rank == 0:
+        # layer 0
+        assert_close(local_chunk[0].weight, model_base.layers[0].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
+        # layer 7
+        assert_close(local_chunk[1].weight, model_base.layers[7].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
+    if rank == 1:
+        # layer 1
+        assert_close(local_chunk[0].weight, model_base.layers[1].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
+        # layer 6
+        assert_close(local_chunk[1].weight, model_base.layers[6].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
+    if rank == 2:
+        # layer 2
+        assert_close(local_chunk[0].weight, model_base.layers[2].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
+        # layer 5
+        assert_close(local_chunk[1].weight, model_base.layers[5].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
+    if rank == 3:
+        # layer 3
+        assert_close(local_chunk[0].weight, model_base.layers[3].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
+        # layer 4
+        assert_close(local_chunk[1].weight, model_base.layers[4].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
+
+    ##########################
+    # assert optim state
+    ##########################


 @pytest.mark.dist
@@ -634,8 +776,16 @@
 @pytest.mark.parametrize("num_model_chunk", [4])
 @rerun_if_address_is_in_use()
 def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
+    # spawn(
+    #     run_fwd_bwd_with_vschedule,
+    #     nprocs=4,
+    #     num_microbatch=num_microbatch,
+    #     batch_size=batch_size,
+    #     num_model_chunk=num_model_chunk,
+    # )
+
     spawn(
-        run_fwd_bwd_with_vschedule,
+        run_fwd_bwd_vschedule_with_optim,
         nprocs=4,
         num_microbatch=num_microbatch,
         batch_size=batch_size,
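
For context on the new OptimizerWrapper helpers: backward_b_by_grad computes only the input gradients (dx) while keeping the autograd graph alive, and backward_w_by_grad (still a stub in this patch) is meant to compute only the weight gradients (dw) afterwards. The sketch below illustrates that B/W split with plain torch.autograd calls; the layer, tensor names, and the argument choices of the second call are illustrative assumptions, not code from this patch.

import torch
import torch.nn as nn

# Minimal sketch of the zero-bubble B/W backward split (illustrative only).
layer = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
y = layer(x)
dy = torch.ones_like(y)

# B step: only dx = w * dy; retain_graph=True keeps the graph for the W step.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=x, retain_graph=True)

# W step: only dw = x * dy; the graph can be freed afterwards.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=list(layer.parameters()), retain_graph=False)

# x.grad now holds dx, while layer.weight.grad / layer.bias.grad hold the weight gradients.

Keeping the two calls separate is what lets the scheduler defer W steps into pipeline bubbles instead of running one full backward per microbatch.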