diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_poc.py b/tests/test_pipeline/test_schedule/test_zerobubble_poc.py
index ac7ea3f9a..5fa3c62e4 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_poc.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_poc.py
@@ -1,6 +1,5 @@
 import gc
 from copy import deepcopy
-from typing import Tuple
 
 import torch
 import torch.distributed as dist
@@ -13,11 +12,13 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
+# model config
 IN_DIM = 8192
 OUT_DIM = 8192
 NUM_LAYER = 3
 
 
+# A simple MLP
 class MlpModel(nn.Module):
     def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
         super().__init__()
@@ -29,29 +30,10 @@ class MlpModel(nn.Module):
         return x
 
 
-def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
-    num_params = 0
-    num_params_trainable = 0
-    for p in model.parameters():
-        num_params += p.numel()
-        if p.requires_grad:
-            num_params_trainable += p.numel()
-    return num_params, num_params_trainable
-
-
 # Step1: dx = w*dy
 def backward_b(loss, x, model):
     print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB")
-    # print(f"Before x grad {x.grad}")
-    # for name, param in model.named_parameters():
-    #     print(f"Before bwd b \n param {param}\n param gard {param.grad}\n")
-
     torch.autograd.backward(loss, inputs=x, retain_graph=True)
-
-    # for name, param in model.named_parameters():
-    #     print(f"After bwd b \n param {param}\n param gard {param.grad}\n")
-
-    # print(f"After x grad {x.grad}")
     print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
 
 
@@ -64,15 +46,7 @@ def backward_b_not_last(tensors, grad, x, model):
 
 def backward_w(loss, model):
     print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    # for name, param in model.named_parameters():
-    #     print(f"Before bwd w \n param {param}\n param gard {param.grad}\n")
-
     torch.autograd.backward(loss, inputs=list(model.parameters()))
-
-    # for name, param in model.named_parameters():
-    #     print(f"After bwd w \n param {param}\n param gard {param.grad}\n")
-
     print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
 
 
@@ -83,6 +57,7 @@ def backward_w_not_last(tensors, grad, model):
     print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
 
 
+# In this PoC, we check the feasibility of splitting dx and dw in bwd propagation
 def test_dx_dw_split():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
@@ -116,6 +91,8 @@ def test_dx_dw_split():
         assert torch.equal(p1.grad, p2.grad)
 
 
+# In this PoC, we check the unsynced (nsync) split of dx and dw in bwd propagation in the following order:
+# fwd1 --> fwd2 --> dx1 --> dx2 --> dw1 --> dw2
 def test_double_dx_dw_split_nsync():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
@@ -177,16 +154,14 @@ def test_double_dx_dw_split_nsync():
         assert_close(p1.grad, p2.grad)
 
 
+# In this PoC, we check the synced split of dx and dw in bwd propagation in the following order:
+# fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2
 def test_double_dx_dw_split_sync():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
-    # print(f"model numel {get_model_numel(model)}")  # 4GB
     x1 = torch.rand(8, 8).to(device=device)
     x2 = torch.rand(8, 8).to(device=device)
-    # x1 = torch.ones(8, 8).to(device=device)
-    # x2 = torch.ones(8, 8).to(device=device)
-
     ref_model = deepcopy(model)
     ref_x1 = x1.clone()
     ref_x2 = x2.clone()
@@ -239,7 +214,6 @@ def test_double_dx_dw_split_sync():
     ref_loss2 = ref_model(ref_x2).sum()
 
     for p1, p2 in zip(model.parameters(), ref_model.parameters()):
-        # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
         assert_close(p1, p2)
         assert_close(p1.grad, p2.grad)
 
@@ -255,31 +229,13 @@ def test_double_dx_dw_split_sync():
     # assert dx2 & dw2 == bwd 2
     assert_close(x2.grad, ref_x2.grad)
     for p1, p2 in zip(model.parameters(), ref_model.parameters()):
-        # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
         assert_close(p1, p2)
         assert_close(p1.grad, p2.grad)
 
 
-def deallocate_output_tensor(out):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
-    out.data = torch.empty(
-        (1,),
-        device=out.device,
-        dtype=out.dtype,
-    )
-
-
-# del loss and x
+# In this PoC, we check whether a memory leak occurs after deleting the input & loss (with graph)
 def mem_dx_dw():
     device = "cuda:0"
-    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
     model = MlpModel().to(device=device)
     print(f"model numel {get_model_numel(model)}")  # 4GB
@@ -314,8 +270,6 @@ def mem_dx_dw():
 
     # dw1
     backward_w(loss1, model)
-    # deallocate_output_tensor(x1)
-    # deallocate_output_tensor(loss1)
     del loss1, x1
     # del x1
     # del y1
@@ -335,8 +289,6 @@ def mem_dx_dw():
 
     # dw2
     backward_w(loss2, model)
-    # deallocate_output_tensor(x2)
-    # deallocate_output_tensor(loss2)
     del x2, loss2
     # del x2
     # del y2
@@ -356,8 +308,6 @@ def mem_dx_dw():
 
     # dw2
    backward_w(loss3, model)
-    # deallocate_output_tensor(x3)
-    # deallocate_output_tensor(loss3)
     # del x3
     # del y3
     del x3, loss3
@@ -370,7 +320,7 @@ def mem_dx_dw():
             print(obj)
 
 
-# del activation
+# In this PoC, we check whether a memory leak occurs after deleting the input & loss (with graph) & activations
 def activation_dx_dw():
     device = "cuda:0"
     # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
@@ -385,17 +335,6 @@ def activation_dx_dw():
     x3.requires_grad_()
     print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
 
-    # activations = {}
-    # def register_hooks(module):
-    #     def activation_hook(module, input, output):
-    #         activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()
-    #     def bwd_hook(module, grad_input, grad_output):
-    #         del activations[f"{module.__class__.__name__}_{id(module)}"]
-    #     module.register_forward_hook(activation_hook)
-    #     module.register_backward_hook(bwd_hook)
-
-    # model.apply(register_hooks)
-
     ############
     # step1:
     ############
@@ -408,15 +347,9 @@ def activation_dx_dw():
 
     # dx1
     backward_b(loss1, x1, model)
-    # for name, p in model.named_parameters():
-    #     print(f"p grad {p.grad}")
-
     # dw1
     backward_w(loss1, model)
-    # for name, p in model.named_parameters():
-    #     del p.grad
-
-    # del loss1, x1
     del loss1, x1, output1
 
     print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -433,15 +366,9 @@ def activation_dx_dw():
 
     # dx2
    backward_b(loss2, x2, model)
-    # for name, p in model.named_parameters():
-    #     print(f"p grad {p.grad}")
-
     # dw2
     backward_w(loss2, model)
-    # for name, p in model.named_parameters():
-    #     print(f"p grad {p.grad}")
-
-    # del x2, loss2
     del x2, loss2, output2
 
     print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -467,6 +394,7 @@ def activation_dx_dw():
     print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
 
 
+# In this PoC, we use model chunks instead of single layers
 def model_chunk_dx_dw():
     device = "cuda:0"
     num_layers = 4
@@ -555,6 +483,7 @@ def model_chunk_dx_dw():
     print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
 
 
+# In this PoC, we use model chunks and a pp group for communication
 def model_chunk_dx_dw_communication(
     rank: int,
     world_size: int,
@@ -598,9 +527,6 @@ def model_chunk_dx_dw_communication(
     ##########################
     if rank == 0:
         output1 = model_chunk_0(input)
-        # detach output1; then output1 for chunk 0, output1_dt for chunk 1;
-        # output1_dt_rank0 = output1.detach()
-        # output1_dt_rank0.requires_grad_()
         print(
             f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
         )
@@ -689,7 +615,7 @@ def model_chunk_dx_dw_communication(
     print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
 
 
-# Return: output, loss
+# fwd schedule
 def schedule_f(
     stage_manager: PipelineStageManager,
     comm: PipelineP2PCommunication,
@@ -738,6 +664,7 @@ def schedule_f(
         return input, output, None
 
 
+# bwd b schedule
 def schedule_b(
     stage_manager: PipelineStageManager,
     comm: PipelineP2PCommunication,
@@ -759,7 +686,6 @@ def schedule_b(
 
         # bwd step
         backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
-
         backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])
 
         # send bwd to prev
@@ -776,27 +702,17 @@ def schedule_b(
            output_grad = output_grad
         else:
             prev_rank = stage_manager.get_prev_rank()
-            # print(f"prev_rank {prev_rank} curr rank {stage_manager.get_rank()}")
             output_grad, _ = comm.recv_backward(next_rank=prev_rank)
 
         # bwd step
-        # print(f"Before input grad {input.grad}")
-        # for name, param in model_chunk[model_chunk_id].named_parameters():
-        #     print(f"Before {name} grad {param.grad}")
-
         if stage_manager.is_first_stage(ignore_chunk=True):
             backward_b(loss=output_grad, x=input, model=model_chunk[model_chunk_id])
             backward_w(loss=output_grad, model=model_chunk[model_chunk_id])
         else:
             # commom bwd step
-            # print(f"output_grad {output_grad}")
             backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
             backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])
-        # print(f"After input grad {input.grad}")
-        # for name, param in model_chunk[model_chunk_id].named_parameters():
-        #     print(f"After {name} grad {param.grad}")
-
         # send bwd to next
         if stage_manager.is_last_stage(ignore_chunk=True):
             return input.grad
@@ -807,10 +723,12 @@ def schedule_b(
     return input.grad
 
 
+# bwd w schedule (dw is already split in schedule_b)
 def schedule_w():
     pass
 
 
+# In this PoC, we apply a scheduling method on each rank: schedule_f --> schedule_b --> schedule_w
 def model_chunk_dx_dw_comm_interleaved(
     rank: int,
     world_size: int,
@@ -858,21 +776,9 @@ def model_chunk_dx_dw_comm_interleaved(
         if idx == 3 or idx == 4:
             chunk_3.append(sub_model)
 
-    # # test checkpoint
-    # check_fn = lambda submodule: isinstance(submodule, (Linear))
-    # non_reentrant_wrapper = partial(
-    #     checkpoint_wrapper,
-    #     # checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-    #     checkpoint_impl=CheckpointImpl.REENTRANT,
-    # )
-    # apply_activation_checkpointing(
-    #     model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
-    # )
-
     print(
         f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
     )
-    # set_checkpoint_early_stop(False)
 
     # buffer use to save input and output
 
     ##########################
@@ -1051,7 +957,6 @@ def model_chunk_dx_dw_comm_interleaved(
             model_chunk=chunk_3,
             model_chunk_id=chunk_id,
         )
-        # print(f"input_grad4 {input_grad4}")
 
     ######
     # bwd rank 1->4
     ######
@@ -1069,7 +974,6 @@ def model_chunk_dx_dw_comm_interleaved(
             model_chunk=chunk_3,
             model_chunk_id=chunk_id,
         )
-        # print(f"input_grad3 {input_grad3}")
 
     # chunk 2 id 0 (layer 2) bwd
     if rank == 2:
@@ -1083,7 +987,6 @@ def model_chunk_dx_dw_comm_interleaved(
             model_chunk=chunk_2,
             model_chunk_id=chunk_id,
        )
-        # print(f"input_grad2 {input_grad2}")
 
     # chunk 1 id 0 (layer 1) bwd
     if rank == 1:
@@ -1110,7 +1013,6 @@ def model_chunk_dx_dw_comm_interleaved(
             model_chunk=chunk_0,
             model_chunk_id=chunk_id,
         )
-        # print(f"input_grad0 {input_grad0}")
 
     ##########################
     # Fwd bwd for base
     ##########################
@@ -1169,8 +1071,6 @@ def model_chunk_dx_dw_comm_interleaved(
     if rank == 2:
         del input2, output2, input_grad2, input5, output5, input_grad5
     if rank == 3:
         del input3, output3, input_grad3, input4, output4, input_grad4
-    # print(f"After del device: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
-
     del loss_base, output_base
     print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
@@ -1185,11 +1085,4 @@ def test_dx_dw_dist():
 
 
 if __name__ == "__main__":
-    # test_dx_dw_split()
-    # test_double_dx_dw_split_nsync()
-    # test_double_dx_dw_split_sync()
-    # mem_dx_dw()
-    # activation_dx_dw()
-    # model_chunk_dx_dw()
-
     test_dx_dw_dist()
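
For reviewers who want to try the core trick outside the test file: the sketch below mirrors what test_dx_dw_split() checks, using only the torch.autograd.backward(..., inputs=...) calls already present in backward_b/backward_w above. It is a minimal standalone example, not part of this patch; names such as toy_model and ref_model are illustrative, and it falls back to CPU when CUDA is unavailable.

import torch
import torch.nn as nn
from copy import deepcopy
from torch.testing import assert_close

# hypothetical standalone sketch of the dx/dw split exercised by test_dx_dw_split()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
toy_model = nn.Linear(8, 8, bias=None).to(device)
ref_model = deepcopy(toy_model)

x = torch.rand(8, 8, device=device, requires_grad=True)
ref_x = x.detach().clone().requires_grad_()

# split backward: dx first (graph retained), then dw
loss = toy_model(x).sum()
torch.autograd.backward(loss, inputs=[x], retain_graph=True)        # dx only
torch.autograd.backward(loss, inputs=list(toy_model.parameters()))  # dw only

# reference: an ordinary backward computes dx and dw together
ref_model(ref_x).sum().backward()

assert_close(x.grad, ref_x.grad)
for p, ref_p in zip(toy_model.parameters(), ref_model.parameters()):
    assert_close(p.grad, ref_p.grad)

Because the dx step retains the graph, the dw step can be issued at any later point, which is the property the schedule_f/schedule_b/schedule_w PoCs above build on when reordering backward work across pipeline stages.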