From b1419ef76a24c8bca0da1032331717017bd79ca7 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Wed, 28 Aug 2024 05:47:53 +0000
Subject: [PATCH] [fix] fix poc test; add comments in poc;

---
 .../test_schedule/test_zerobubble_poc.py | 29 +++++++++++++------
 .../test_schedule/test_zerobubble_pp.py  | 16 ++++++++--
 2 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_poc.py b/tests/test_pipeline/test_schedule/test_zerobubble_poc.py
index 5fa3c62e4..737e19aa8 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_poc.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_poc.py
@@ -1,5 +1,6 @@
 import gc
 from copy import deepcopy
+from typing import Tuple
 
 import torch
 import torch.distributed as dist
@@ -18,6 +19,16 @@ OUT_DIM = 8192
 NUM_LAYER = 3
 
 
+def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
+    num_params = 0
+    num_params_trainable = 0
+    for p in model.parameters():
+        num_params += p.numel()
+        if p.requires_grad:
+            num_params_trainable += p.numel()
+    return num_params, num_params_trainable
+
+
 # A simple MLP
 class MlpModel(nn.Module):
     def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
@@ -58,7 +69,7 @@ def backward_w_not_last(tensors, grad, model):
 
 
 # In this poc, we check feasibility of spliting dx and dw in bwd propagation
-def test_dx_dw_split():
+def run_dx_dw_split():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
     print(f"model numel {get_model_numel(model)}")  # 4GB
@@ -93,7 +104,7 @@ def test_dx_dw_split():
 
 # In this poc, we check nsync of spliting dx and dw in bwd propagation in following order:
 # fwd1 --> fwd2 --> dx1 --> dx2 --> dw1 --> dw2
-def test_double_dx_dw_split_nsync():
+def run_double_dx_dw_split_nsync():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
     # print(f"model numel {get_model_numel(model)}") # 4GB
@@ -156,7 +167,7 @@ def test_double_dx_dw_split_nsync():
 
 # In this poc, we check sync of spliting dx and dw in bwd propagation in following order:
 # fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2
-def test_double_dx_dw_split_sync():
+def run_double_dx_dw_split_sync():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
     x1 = torch.rand(8, 8).to(device=device)
@@ -234,7 +245,7 @@ def test_double_dx_dw_split_sync():
 
 
 # In this poc, we check if a memory leak has occurred after del input & loss(with graph)
-def mem_dx_dw():
+def run_mem_dx_dw():
     device = "cuda:0"
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
     model = MlpModel().to(device=device)
@@ -321,7 +332,7 @@ def mem_dx_dw():
 
 
 # In this poc, we check if a memory leak has occurred after del input & loss(with graph) & activation
-def activation_dx_dw():
+def run_activation_dx_dw():
     device = "cuda:0"
     # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -395,7 +406,7 @@ def activation_dx_dw():
 
 
 # In this poc, we apply model chunk instead of layer
-def model_chunk_dx_dw():
+def run_model_chunk_dx_dw():
     device = "cuda:0"
     num_layers = 4
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -484,7 +495,7 @@ def model_chunk_dx_dw():
 
 
 # In this poc, we apply model chunk and a pp group for communication
-def model_chunk_dx_dw_communication(
+def run_model_chunk_dx_dw_communication(
     rank: int,
     world_size: int,
     port: int,
@@ -729,7 +740,7 @@ def schedule_w():
 
 
 # In this poc, we apply a scheduling method for each rank: schedule_f --> schedule_b --> schedule_w
-def model_chunk_dx_dw_comm_interleaved(
+def run_model_chunk_dx_dw_comm_interleaved(
     rank: int,
     world_size: int,
     port: int,
@@ -1079,7 +1090,7 @@ def model_chunk_dx_dw_comm_interleaved(
 @rerun_if_address_is_in_use()
 def test_dx_dw_dist():
     spawn(
-        model_chunk_dx_dw_comm_interleaved,
+        run_model_chunk_dx_dw_comm_interleaved,
         nprocs=4,
     )
 
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 7f02ca477..ea7abc432 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -36,7 +36,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
     return num_params, num_params_trainable
 
 
-# Test manual v_schedule with multiple microbatch
+# 1) Test manual v_schedule with multiple microbatch
 def run_fwd_bwd_iter_input(
     rank: int,
     world_size: int,
@@ -474,7 +474,7 @@ def run_fwd_bwd_iter_input(
         assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
 
 
-# Test v_schedule generated by graph with multiple microbatch
+# 2) Test v_schedule generated by graph with multiple microbatch
 def run_fwd_bwd_with_vschedule(
     rank: int,
     world_size: int,
@@ -616,6 +616,18 @@ def run_fwd_bwd_with_vschedule(
         assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
 
 
+# 3) add optimizer base 2)
+def run_fwd_bwd_vschedule_with_optim(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
+    pass
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("num_microbatch", [4])
 @pytest.mark.parametrize("batch_size", [4])
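
Editor's note: in this commit, run_fwd_bwd_vschedule_with_optim is still a stub (its body is `pass`), and the pytest entry point that would eventually drive it is not shown beyond the trailing parametrize decorators. The sketch below is only an illustration of how such a runner is typically launched with colossalai.testing.spawn, mirroring the test_dx_dw_dist() pattern from test_zerobubble_poc.py; the test name, the num_model_chunk parametrization, and the kwargs forwarding are assumptions, not part of the patch.

# Hypothetical wiring, for illustration only (not part of the patch). Assumes it
# lives in the same module as run_fwd_bwd_vschedule_with_optim, so that name is
# already defined above.
import pytest
from colossalai.testing import rerun_if_address_is_in_use, spawn


@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_model_chunk", [4])  # assumed value, not shown in the diff
@rerun_if_address_is_in_use()
def test_pp_with_optim(num_microbatch, batch_size, num_model_chunk):
    # spawn() launches nprocs ranks and supplies (rank, world_size, port) to the
    # target function; the remaining keyword arguments are forwarded to it.
    spawn(
        run_fwd_bwd_vschedule_with_optim,
        nprocs=4,
        num_microbatch=num_microbatch,
        batch_size=batch_size,
        num_model_chunk=num_model_chunk,
    )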