[fix] fix poc test; add comments in poc;

duanjunwen 2024-08-28 05:47:53 +00:00
parent 582ba0d6ff
commit b1419ef76a
2 changed files with 34 additions and 11 deletions


@@ -1,5 +1,6 @@
import gc
from copy import deepcopy
from typing import Tuple
import torch
import torch.distributed as dist
@@ -18,6 +19,16 @@ OUT_DIM = 8192
NUM_LAYER = 3
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
num_params = 0
num_params_trainable = 0
for p in model.parameters():
num_params += p.numel()
if p.requires_grad:
num_params_trainable += p.numel()
return num_params, num_params_trainable
# A simple MLP
class MlpModel(nn.Module):
def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
@@ -58,7 +69,7 @@ def backward_w_not_last(tensors, grad, model):
# In this poc, we check the feasibility of splitting dx and dw in bwd propagation
def test_dx_dw_split():
def run_dx_dw_split():
device = "cuda:0"
model = nn.Linear(8, 8, bias=None).to(device=device)
print(f"model numel {get_model_numel(model)}") # 4GB
@@ -93,7 +104,7 @@ def test_dx_dw_split():
# In this poc, we check the nsync case of splitting dx and dw in bwd propagation in the following order:
# fwd1 --> fwd2 --> dx1 --> dx2 --> dw1 --> dw2
def test_double_dx_dw_split_nsync():
def run_double_dx_dw_split_nsync():
device = "cuda:0"
model = nn.Linear(8, 8, bias=None).to(device=device)
# print(f"model numel {get_model_numel(model)}") # 4GB
@@ -156,7 +167,7 @@ def test_double_dx_dw_split_nsync():
# In this poc, we check the sync case of splitting dx and dw in bwd propagation in the following order:
# fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2
def test_double_dx_dw_split_sync():
def run_double_dx_dw_split_sync():
device = "cuda:0"
model = nn.Linear(8, 8, bias=None).to(device=device)
x1 = torch.rand(8, 8).to(device=device)
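The nsync and sync pocs above only differ in when each microbatch's dw is materialized (dx1 --> dx2 --> dw1 --> dw2 versus dx1 --> dw1 --> dx2 --> dw2); either way the accumulated weight gradient should match a fused backward. A hedged two-microbatch sketch of the sync order (names are illustrative):

```python
import torch
import torch.nn as nn

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 8, bias=None).to(device)
x1 = torch.rand(8, 8, device=device, requires_grad=True)
x2 = torch.rand(8, 8, device=device, requires_grad=True)

loss1 = model(x1).sum()   # fwd1
loss2 = model(x2).sum()   # fwd2

# Sync order: dx1 --> dw1 --> dx2 --> dw2
# (swapping the middle two calls gives the nsync order).
dx1 = torch.autograd.grad(loss1, x1, retain_graph=True)[0]
dw1 = torch.autograd.grad(loss1, model.weight)[0]
dx2 = torch.autograd.grad(loss2, x2, retain_graph=True)[0]
dw2 = torch.autograd.grad(loss2, model.weight)[0]

# Accumulated dw should equal the grad from a fused backward over both losses.
(model(x1).sum() + model(x2).sum()).backward()
torch.testing.assert_close(dw1 + dw2, model.weight.grad)
```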
@@ -234,7 +245,7 @@ def test_double_dx_dw_split_sync():
# In this poc, we check if a memory leak has occurred after del input & loss (with graph)
def mem_dx_dw():
def run_mem_dx_dw():
device = "cuda:0"
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
model = MlpModel().to(device=device)
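The leak check in this poc (and the activation variant that follows) boils down to comparing torch.cuda.memory_allocated() before and after dropping the references that keep the graph alive; roughly the following pattern (a hedged sketch, sizes are illustrative and CUDA is assumed):

```python
import gc
import torch
import torch.nn as nn

device = "cuda:0"  # CUDA assumed, as in the poc
model = nn.Linear(4096, 4096, bias=None).to(device)
x = torch.rand(4096, 4096, device=device, requires_grad=True)

loss = model(x).sum()
loss.backward()

# Drop the input and loss (and the graph hanging off them), then collect.
del x, loss
gc.collect()
torch.cuda.empty_cache()

# Ideally only the parameters and their grads stay resident.
resident = torch.cuda.memory_allocated(device)
params_and_grads = 2 * sum(p.numel() * p.element_size() for p in model.parameters())
print(f"resident {resident/1024**3:.3f} GB vs params+grads {params_and_grads/1024**3:.3f} GB")
```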
@@ -321,7 +332,7 @@ def mem_dx_dw():
# In this poc, we check if a memory leak has occurred after del input & loss (with graph) & activation
def activation_dx_dw():
def run_activation_dx_dw():
device = "cuda:0"
# model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -395,7 +406,7 @@ def activation_dx_dw():
# In this poc, we apply model chunks instead of single layers
def model_chunk_dx_dw():
def run_model_chunk_dx_dw():
device = "cuda:0"
num_layers = 4
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -484,7 +495,7 @@ def model_chunk_dx_dw():
# In this poc, we apply model chunks and a pp group for communication
def model_chunk_dx_dw_communication(
def run_model_chunk_dx_dw_communication(
rank: int,
world_size: int,
port: int,
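The pp group in this poc is an ordinary torch.distributed process group used for point-to-point exchange of activations and gradients between adjacent stages. A hedged two-rank sketch of that send/recv pattern (backend, port, and tensor contents are illustrative; the poc uses its own spawn/port handling):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def stage(rank, world_size):
    # gloo backend so the sketch also runs on CPU-only machines.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    if rank == 0:
        act = torch.rand(8, 8)
        dist.send(act, dst=1)        # ship stage-0 activation forward
        grad = torch.empty(8, 8)
        dist.recv(grad, src=1)       # receive dx coming back from stage 1
    else:
        act = torch.empty(8, 8)
        dist.recv(act, src=0)        # receive activation from the previous stage
        dist.send(act * 2, dst=0)    # send a stand-in gradient back
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(stage, args=(2,), nprocs=2)
```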
@@ -729,7 +740,7 @@ def schedule_w():
# In this poc, we apply a scheduling method for each rank: schedule_f --> schedule_b --> schedule_w
def model_chunk_dx_dw_comm_interleaved(
def run_model_chunk_dx_dw_comm_interleaved(
rank: int,
world_size: int,
port: int,
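The interleaved poc's per-rank schedule reduces to three phases run in a fixed order: schedule_f (forward), schedule_b (backward for dx only), schedule_w (backward for dw). A hedged single-device sketch of that decomposition (the queueing and communication of the real interleaved version are omitted):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8, bias=None)

def schedule_f(x):
    # Forward; the output is kept so backward can be split later.
    return model(x)

def schedule_b(x, out, grad_out):
    # Backward for the input only (dx); retain the graph for the dw phase.
    return torch.autograd.grad(out, x, grad_outputs=grad_out, retain_graph=True)[0]

def schedule_w(out, grad_out):
    # Backward for the weight only (dw), run later and independently of dx.
    return torch.autograd.grad(out, model.weight, grad_outputs=grad_out)[0]

x = torch.rand(8, 8, requires_grad=True)
out = schedule_f(x)
grad_out = torch.ones_like(out)
dx = schedule_b(x, out, grad_out)
dw = schedule_w(out, grad_out)
print(dx.shape, dw.shape)
```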
@@ -1079,7 +1090,7 @@ def model_chunk_dx_dw_comm_interleaved(
@rerun_if_address_is_in_use()
def test_dx_dw_dist():
spawn(
model_chunk_dx_dw_comm_interleaved,
run_model_chunk_dx_dw_comm_interleaved,
nprocs=4,
)


@@ -36,7 +36,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
return num_params, num_params_trainable
# Test manual v_schedule with multiple microbatch
# 1) Test manual v_schedule with multiple microbatch
def run_fwd_bwd_iter_input(
rank: int,
world_size: int,
@@ -474,7 +474,7 @@ def run_fwd_bwd_iter_input(
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
# Test v_schedule generated by graph with multiple microbatch
# 2) Test v_schedule generated by graph with multiple microbatch
def run_fwd_bwd_with_vschedule(
rank: int,
world_size: int,
@@ -616,6 +616,18 @@ def run_fwd_bwd_with_vschedule(
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
# 3) add optimizer based on 2)
def run_fwd_bwd_vschedule_with_optim(
rank: int,
world_size: int,
port: int,
num_microbatch: int,
batch_size: int,
num_model_chunk: int,
):
pass
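The new run_fwd_bwd_vschedule_with_optim stub only reserves the name; "add optimizer" on top of 2) amounts to stepping an optimizer once the microbatch gradients from the schedule have accumulated. A plain single-device sketch of that step (this is an assumption about the eventual test body, not the poc's code; the scheduler plumbing is omitted):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8, bias=None)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(4):              # e.g. num_microbatch = 4
    x = torch.rand(4, 8)        # e.g. batch_size = 4
    model(x).sum().backward()   # grads accumulate across microbatches

optimizer.step()                # single optimizer step after all microbatches
optimizer.zero_grad()
```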
@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4])
@pytest.mark.parametrize("batch_size", [4])