mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[zero] support extra dp (#6123)
* [zero] support extra dp * [zero] update checkpoint * fix bugs * fix bugs
This commit is contained in:
42
tests/test_zero/test_low_level/test_coll_nd.py
Normal file
42
tests/test_zero/test_low_level/test_coll_nd.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
|
||||
|
||||
|
||||
def check_all_gather_2d():
|
||||
seed_all(1024)
|
||||
tensor = torch.rand(128, device=get_current_device())
|
||||
extra_dp_size, inner_dp_size = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
|
||||
extra_dp_group = pg_mesh.get_group_along_axis(0)
|
||||
inner_dp_group = pg_mesh.get_group_along_axis(1)
|
||||
ranks = [dist.get_rank(extra_dp_group), dist.get_rank(inner_dp_group)]
|
||||
sizes = [dist.get_world_size(extra_dp_group), dist.get_world_size(inner_dp_group)]
|
||||
chunk = tensor.chunk(dist.get_world_size())[np.ravel_multi_index(ranks, sizes)].clone()
|
||||
out = torch.zeros_like(tensor)
|
||||
all_gather_into_flat_tensor_nd(out, chunk, group=(extra_dp_group, inner_dp_group))
|
||||
assert torch.equal(out, tensor)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
|
||||
check_all_gather_2d()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_comm_nd():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_comm_nd()
|
@@ -2,11 +2,13 @@ import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
@@ -123,7 +125,8 @@ def exam_zero_1_2(fp8_communication: bool):
|
||||
|
||||
@parameterize("dtype", [torch.float16, torch.bfloat16])
|
||||
@parameterize("master_weights", [True, False])
|
||||
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
||||
@parameterize("extra_dp_size", [1, 2])
|
||||
def exam_zero_1_torch_ddp(dtype: torch.dtype, master_weights: bool, extra_dp_size: int):
|
||||
"""
|
||||
In this test, two pairs of model and optimizers are created.
|
||||
1. zero: use sharded optimizer and fp16 parameters
|
||||
@@ -132,6 +135,15 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
||||
We feed these two sets of models with the same input and check if the
|
||||
differences in model output and updated parameters are within tolerance.
|
||||
"""
|
||||
if extra_dp_size > 1 and dtype != torch.bfloat16:
|
||||
return
|
||||
if extra_dp_size > 1:
|
||||
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
|
||||
extra_dp_group = pg_mesh.get_group_along_axis(0)
|
||||
dp_group = pg_mesh.get_group_along_axis(1)
|
||||
else:
|
||||
extra_dp_group = None
|
||||
dp_group = None
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(1453)
|
||||
|
||||
@@ -153,6 +165,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
||||
initial_scale=1,
|
||||
reduce_bucket_size=1024 * 1024,
|
||||
master_weights=master_weights,
|
||||
dp_process_group=dp_group,
|
||||
extra_dp_group=extra_dp_group,
|
||||
)
|
||||
|
||||
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||
@@ -200,14 +214,14 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
|
||||
exam_zero_1_torch_ddp(world_size=world_size)
|
||||
exam_zero_1_torch_ddp()
|
||||
exam_zero_1_2()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_1_2():
|
||||
spawn(run_dist, 2)
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -2,12 +2,14 @@ import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
||||
@@ -40,11 +42,19 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
||||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def exam_zero_1_torch_ddp_ckpt():
|
||||
@parameterize("extra_dp_size", [1, 2])
|
||||
def exam_zero_1_torch_ddp_ckpt(extra_dp_size: int):
|
||||
"""
|
||||
We examine the state_dict of zero and DDP.
|
||||
Moreover, we examine the zero's loading checkpoint of a torch ckpt.
|
||||
"""
|
||||
if extra_dp_size > 1:
|
||||
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
|
||||
extra_dp_group = pg_mesh.get_group_along_axis(0)
|
||||
dp_group = pg_mesh.get_group_along_axis(1)
|
||||
else:
|
||||
dp_group = None
|
||||
extra_dp_group = None
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(1453)
|
||||
|
||||
@@ -60,7 +70,12 @@ def exam_zero_1_torch_ddp_ckpt():
|
||||
# we only test stage 1 here
|
||||
# the state dicts of stage 1 and stage 2 are the same
|
||||
zero_optimizer = LowLevelZeroOptimizer(
|
||||
zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144
|
||||
zero_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=1,
|
||||
reduce_bucket_size=262144,
|
||||
dp_process_group=dp_group,
|
||||
extra_dp_group=extra_dp_group,
|
||||
)
|
||||
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
|
||||
@@ -111,7 +126,7 @@ def run_dist(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_ckpt():
|
||||
spawn(run_dist, 2)
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user