Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-26 15:32:22 +00:00)
[fix] fix poc test; add comments in poc;

This commit is contained in:
parent 582ba0d6ff
commit b1419ef76a
@@ -1,5 +1,6 @@
 import gc
 from copy import deepcopy
+from typing import Tuple

 import torch
 import torch.distributed as dist
@@ -18,6 +19,16 @@ OUT_DIM = 8192
 NUM_LAYER = 3


+def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
+    num_params = 0
+    num_params_trainable = 0
+    for p in model.parameters():
+        num_params += p.numel()
+        if p.requires_grad:
+            num_params_trainable += p.numel()
+    return num_params, num_params_trainable
+
+
 # A simple MLP
 class MlpModel(nn.Module):
     def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
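
The helper added above simply counts total and trainable parameters. A minimal usage sketch, assuming get_model_numel from the hunk above is in scope (the model is just a stand-in):

    import torch.nn as nn

    model = nn.Linear(8, 8, bias=None)
    num_params, num_params_trainable = get_model_numel(model)
    print(num_params, num_params_trainable)  # 64 64 for this toy Linear
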
@@ -58,7 +69,7 @@ def backward_w_not_last(tensors, grad, model):


 # In this poc, we check feasibility of spliting dx and dw in bwd propagation
-def test_dx_dw_split():
+def run_dx_dw_split():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
     print(f"model numel {get_model_numel(model)}") # 4GB
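
As background for the hunk above: the dx/dw split this poc checks can be reproduced with plain autograd by asking for the input gradient and the weight gradient in two separate calls. A minimal sketch (not the file's actual code):

    import torch
    import torch.nn as nn

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(8, 8, bias=None).to(device)
    x = torch.rand(8, 8, device=device, requires_grad=True)
    y = model(x).sum()

    # dx: gradient w.r.t. the input only; keep the graph alive for the later dw step.
    (dx,) = torch.autograd.grad(y, x, retain_graph=True)
    # dw: gradient w.r.t. the weight, computed in a separate, later call.
    (dw,) = torch.autograd.grad(y, model.weight)
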
@@ -93,7 +104,7 @@ def test_dx_dw_split():

 # In this poc, we check nsync of spliting dx and dw in bwd propagation in following order:
 # fwd1 --> fwd2 --> dx1 --> dx2 --> dw1 --> dw2
-def test_double_dx_dw_split_nsync():
+def run_double_dx_dw_split_nsync():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
     # print(f"model numel {get_model_numel(model)}") # 4GB
@@ -156,7 +167,7 @@ def test_double_dx_dw_split_nsync():

 # In this poc, we check sync of spliting dx and dw in bwd propagation in following order:
 # fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2
-def test_double_dx_dw_split_sync():
+def run_double_dx_dw_split_sync():
     device = "cuda:0"
     model = nn.Linear(8, 8, bias=None).to(device=device)
     x1 = torch.rand(8, 8).to(device=device)
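
The ordering named in the comment above (fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2) amounts to a toy sequence like the following, with two independent micro-batches. This is an illustrative sketch, not the poc's implementation:

    import torch
    import torch.nn as nn

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(8, 8, bias=None).to(device)
    x1 = torch.rand(8, 8, device=device, requires_grad=True)
    x2 = torch.rand(8, 8, device=device, requires_grad=True)

    y1 = model(x1).sum()                                      # fwd1
    y2 = model(x2).sum()                                      # fwd2
    (dx1,) = torch.autograd.grad(y1, x1, retain_graph=True)   # dx1
    (dw1,) = torch.autograd.grad(y1, model.weight)            # dw1
    (dx2,) = torch.autograd.grad(y2, x2, retain_graph=True)   # dx2
    (dw2,) = torch.autograd.grad(y2, model.weight)            # dw2
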
@@ -234,7 +245,7 @@ def test_double_dx_dw_split_sync():


 # In this poc, we check if a memory leak has occurred after del input & loss(with graph)
-def mem_dx_dw():
+def run_mem_dx_dw():
     device = "cuda:0"
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
     model = MlpModel().to(device=device)
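
The leak check renamed above boils down to the pattern below: allocate, backward, delete the tensors that own the graph, then confirm allocated memory falls back to roughly the model's footprint. A sketch assuming MlpModel and IN_DIM from this file are in scope; the batch size is illustrative:

    import gc
    import torch

    def report(tag):
        print(f"{tag}: {torch.cuda.memory_allocated() / 1024**3:.3f} GB")

    report("before init")
    model = MlpModel().to("cuda:0")
    x = torch.rand(64, IN_DIM, device="cuda:0", requires_grad=True)
    loss = model(x).sum()
    loss.backward()
    report("after bwd")

    del x, loss                 # drop the input and the loss (which holds the graph)
    gc.collect()
    torch.cuda.empty_cache()
    report("after del")         # should fall back to roughly the model's footprint
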
@@ -321,7 +332,7 @@ def mem_dx_dw():


 # In this poc, we check if a memory leak has occurred after del input & loss(with graph) & activation
-def activation_dx_dw():
+def run_activation_dx_dw():
     device = "cuda:0"
     # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -395,7 +406,7 @@ def activation_dx_dw():


 # In this poc, we apply model chunk instead of layer
-def model_chunk_dx_dw():
+def run_model_chunk_dx_dw():
     device = "cuda:0"
     num_layers = 4
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
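
"Model chunk instead of layer" means grouping consecutive layers into a pipeline unit and running forward/backward per chunk rather than per layer. In spirit (illustrative sketch only, not the poc's chunking code):

    import torch.nn as nn

    num_layers = 4
    layers = [nn.Linear(8, 8, bias=None) for _ in range(num_layers)]
    chunk_0 = nn.Sequential(*layers[:2])   # first pipeline chunk (layers 0-1)
    chunk_1 = nn.Sequential(*layers[2:])   # second pipeline chunk (layers 2-3)
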
@@ -484,7 +495,7 @@ def model_chunk_dx_dw():


 # In this poc, we apply model chunk and a pp group for communication
-def model_chunk_dx_dw_communication(
+def run_model_chunk_dx_dw_communication(
     rank: int,
     world_size: int,
     port: int,
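
The communication PoCs are launched through spawn(...) (see the last hunk of this file). A bare-bones stand-alone equivalent of the point-to-point activation passing between two pipeline stages would look roughly like this; the backend and port below are hypothetical choices, not the poc's setup:

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def worker(rank, world_size=2):
        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29511", rank=rank, world_size=world_size
        )
        t = torch.zeros(8, 8)
        if rank == 0:
            t += 1.0
            dist.send(t, dst=1)   # stage 0 pushes its activation downstream
        else:
            dist.recv(t, src=0)   # stage 1 receives the activation
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(worker, nprocs=2)
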
@@ -729,7 +740,7 @@ def schedule_w():


 # In this poc, we apply a scheduling method for each rank: schedule_f --> schedule_b --> schedule_w
-def model_chunk_dx_dw_comm_interleaved(
+def run_model_chunk_dx_dw_comm_interleaved(
     rank: int,
     world_size: int,
     port: int,
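
The per-rank ordering named above (schedule_f --> schedule_b --> schedule_w) is the zero-bubble idea of deferring the weight gradient until after the input gradient has been sent upstream. A toy, self-contained illustration; these functions are stand-ins, not the file's schedule_f/schedule_b/schedule_w:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8, bias=None)
    pending_w = []  # graphs whose weight-grad (W) work has been deferred

    def schedule_f(x):
        x.requires_grad_(True)
        return x, model(x)                 # forward for one micro-batch

    def schedule_b(x, y):
        loss = y.sum()
        (dx,) = torch.autograd.grad(loss, x, retain_graph=True)  # input grad only
        pending_w.append(loss)             # keep the graph so dw can run later
        return dx

    def schedule_w():
        loss = pending_w.pop(0)
        (dw,) = torch.autograd.grad(loss, model.weight)           # deferred weight grad
        return dw

    x, y = schedule_f(torch.rand(8, 8))
    dx = schedule_b(x, y)
    dw = schedule_w()
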
@@ -1079,7 +1090,7 @@ def model_chunk_dx_dw_comm_interleaved(
 @rerun_if_address_is_in_use()
 def test_dx_dw_dist():
     spawn(
-        model_chunk_dx_dw_comm_interleaved,
+        run_model_chunk_dx_dw_comm_interleaved,
         nprocs=4,
     )

@@ -36,7 +36,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
     return num_params, num_params_trainable


-# Test manual v_schedule with multiple microbatch
+# 1) Test manual v_schedule with multiple microbatch
 def run_fwd_bwd_iter_input(
     rank: int,
     world_size: int,
@@ -474,7 +474,7 @@ def run_fwd_bwd_iter_input(
     assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)


-# Test v_schedule generated by graph with multiple microbatch
+# 2) Test v_schedule generated by graph with multiple microbatch
 def run_fwd_bwd_with_vschedule(
     rank: int,
     world_size: int,
@@ -616,6 +616,18 @@ def run_fwd_bwd_with_vschedule(
     assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)


+# 3) add optimizer base 2)
+def run_fwd_bwd_vschedule_with_optim(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
+    pass
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("num_microbatch", [4])
 @pytest.mark.parametrize("batch_size", [4])
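
The new run_fwd_bwd_vschedule_with_optim is still an empty stub in this commit. The usual pattern it would presumably grow into is "run the scheduled fwd/bwd so .grad is populated, then step an optimizer"; purely illustrative, with placeholder objects rather than the poc's actual model chunks:

    import torch
    import torch.nn as nn

    local_chunk = nn.Linear(8, 8)                      # placeholder for a model chunk
    optimizer = torch.optim.Adam(local_chunk.parameters(), lr=1e-3)

    out = local_chunk(torch.rand(4, 8))
    out.sum().backward()                               # in the poc: the v-schedule B/W passes
    optimizer.step()
    optimizer.zero_grad()
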