mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 23:11:55 +00:00
[fix] fix poc test; add comments in poc;
This commit is contained in:
parent
582ba0d6ff
commit
b1419ef76a
@ -1,5 +1,6 @@
|
||||
import gc
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -18,6 +19,16 @@ OUT_DIM = 8192
|
||||
NUM_LAYER = 3
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
num_params = 0
|
||||
num_params_trainable = 0
|
||||
for p in model.parameters():
|
||||
num_params += p.numel()
|
||||
if p.requires_grad:
|
||||
num_params_trainable += p.numel()
|
||||
return num_params, num_params_trainable
|
||||
|
||||
|
||||
# A simple MLP
|
||||
class MlpModel(nn.Module):
|
||||
def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
|
||||
@ -58,7 +69,7 @@ def backward_w_not_last(tensors, grad, model):
|
||||
|
||||
|
||||
# In this poc, we check feasibility of spliting dx and dw in bwd propagation
|
||||
def test_dx_dw_split():
|
||||
def run_dx_dw_split():
|
||||
device = "cuda:0"
|
||||
model = nn.Linear(8, 8, bias=None).to(device=device)
|
||||
print(f"model numel {get_model_numel(model)}") # 4GB
|
||||
@ -93,7 +104,7 @@ def test_dx_dw_split():
|
||||
|
||||
# In this poc, we check nsync of spliting dx and dw in bwd propagation in following order:
|
||||
# fwd1 --> fwd2 --> dx1 --> dx2 --> dw1 --> dw2
|
||||
def test_double_dx_dw_split_nsync():
|
||||
def run_double_dx_dw_split_nsync():
|
||||
device = "cuda:0"
|
||||
model = nn.Linear(8, 8, bias=None).to(device=device)
|
||||
# print(f"model numel {get_model_numel(model)}") # 4GB
|
||||
@ -156,7 +167,7 @@ def test_double_dx_dw_split_nsync():
|
||||
|
||||
# In this poc, we check sync of spliting dx and dw in bwd propagation in following order:
|
||||
# fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2
|
||||
def test_double_dx_dw_split_sync():
|
||||
def run_double_dx_dw_split_sync():
|
||||
device = "cuda:0"
|
||||
model = nn.Linear(8, 8, bias=None).to(device=device)
|
||||
x1 = torch.rand(8, 8).to(device=device)
|
||||
@ -234,7 +245,7 @@ def test_double_dx_dw_split_sync():
|
||||
|
||||
|
||||
# In this poc, we check if a memory leak has occurred after del input & loss(with graph)
|
||||
def mem_dx_dw():
|
||||
def run_mem_dx_dw():
|
||||
device = "cuda:0"
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
||||
model = MlpModel().to(device=device)
|
||||
@ -321,7 +332,7 @@ def mem_dx_dw():
|
||||
|
||||
|
||||
# In this poc, we check if a memory leak has occurred after del input & loss(with graph) & activation
|
||||
def activation_dx_dw():
|
||||
def run_activation_dx_dw():
|
||||
device = "cuda:0"
|
||||
# model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
||||
@ -395,7 +406,7 @@ def activation_dx_dw():
|
||||
|
||||
|
||||
# In this poc, we apply model chunk instead of layer
|
||||
def model_chunk_dx_dw():
|
||||
def run_model_chunk_dx_dw():
|
||||
device = "cuda:0"
|
||||
num_layers = 4
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
||||
@ -484,7 +495,7 @@ def model_chunk_dx_dw():
|
||||
|
||||
|
||||
# In this poc, we apply model chunk and a pp group for communication
|
||||
def model_chunk_dx_dw_communication(
|
||||
def run_model_chunk_dx_dw_communication(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
@ -729,7 +740,7 @@ def schedule_w():
|
||||
|
||||
|
||||
# In this poc, we apply a scheduling method for each rank: schedule_f --> schedule_b --> schedule_w
|
||||
def model_chunk_dx_dw_comm_interleaved(
|
||||
def run_model_chunk_dx_dw_comm_interleaved(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
@ -1079,7 +1090,7 @@ def model_chunk_dx_dw_comm_interleaved(
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_dx_dw_dist():
|
||||
spawn(
|
||||
model_chunk_dx_dw_comm_interleaved,
|
||||
run_model_chunk_dx_dw_comm_interleaved,
|
||||
nprocs=4,
|
||||
)
|
||||
|
||||
|
@ -36,7 +36,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return num_params, num_params_trainable
|
||||
|
||||
|
||||
# Test manual v_schedule with multiple microbatch
|
||||
# 1) Test manual v_schedule with multiple microbatch
|
||||
def run_fwd_bwd_iter_input(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
@ -474,7 +474,7 @@ def run_fwd_bwd_iter_input(
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
||||
|
||||
|
||||
# Test v_schedule generated by graph with multiple microbatch
|
||||
# 2) Test v_schedule generated by graph with multiple microbatch
|
||||
def run_fwd_bwd_with_vschedule(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
@ -616,6 +616,18 @@ def run_fwd_bwd_with_vschedule(
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
||||
|
||||
|
||||
# 3) add optimizer base 2)
|
||||
def run_fwd_bwd_vschedule_with_optim(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
num_microbatch: int,
|
||||
batch_size: int,
|
||||
num_model_chunk: int,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("num_microbatch", [4])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
|
Loading…
Reference in New Issue
Block a user