[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
Hongxin Liu
2023-09-18 16:31:06 +08:00
committed by GitHub
parent 32e7f99416
commit b5f9e37c70
342 changed files with 2919 additions and 4182 deletions


@@ -1,124 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
import torch.nn.functional as F

from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.random import add_seed, reset_seeds, seed, set_mode
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils.activation_checkpoint import checkpoint


def forward(x, weight):
    out = torch.matmul(x, weight)
    with seed(ParallelMode.DATA):
        out_ = F.dropout(out, p=0.4, training=True)
    return out_


def forward_inplace_ckpt(x, weight, cpu_offload=False):
    out = torch.matmul(x, weight)
    bn = torch.nn.BatchNorm1d(4, affine=False)
    bn = bn.to(device="cuda")
    out = bn(out)

    def ckpt0(x):
        return F.relu(x, inplace=True)

    out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False)
    return out


def forward_inplace(x, weight):
    out = torch.matmul(x, weight)
    bn = torch.nn.BatchNorm1d(4, affine=False)
    bn = bn.to(device="cuda")
    out = bn(out)
    out = F.relu(out, inplace=True)
    return out


@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload, use_reentrant):
    # The seed manager is a singleton, so reset the seeds here in case
    # other tests have already touched them.
    reset_seeds()

    # Initialize the tensors here to avoid changing the CUDA RNG state below.
    inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
    weight = torch.rand(2, 4, requires_grad=True, device='cuda')

    # Make copies of the input tensors.
    inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda')
    inputs_.data.copy_(inputs.data)
    weight_ = torch.empty(2, 4, requires_grad=True, device='cuda')
    weight_.data.copy_(weight.data)

    add_seed(ParallelMode.GLOBAL, 1024)
    add_seed(ParallelMode.DATA, 1026)
    set_mode(ParallelMode.GLOBAL)
    global_cuda_rng_state = torch.cuda.get_rng_state()
    set_mode(ParallelMode.DATA)
    data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
    set_mode(ParallelMode.GLOBAL)

    out = forward(inputs, weight)
    loss = out.sum()
    loss.backward()

    # Recover the CUDA RNG states.
    set_mode(ParallelMode.GLOBAL)
    torch.cuda.set_rng_state(global_cuda_rng_state)
    set_mode(ParallelMode.DATA)
    torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
    set_mode(ParallelMode.GLOBAL)

    out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
    loss = out.sum()
    loss.backward()

    assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
    torch.cuda.empty_cache()

    # Extra test for use_reentrant=False
    if not use_reentrant:
        # Recover the CUDA RNG states.
        set_mode(ParallelMode.GLOBAL)
        torch.cuda.set_rng_state(global_cuda_rng_state)
        set_mode(ParallelMode.DATA)
        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
        set_mode(ParallelMode.GLOBAL)

        out = forward_inplace(inputs, weight)
        loss = out.sum()
        loss.backward()

        # Recover the CUDA RNG states.
        set_mode(ParallelMode.GLOBAL)
        torch.cuda.set_rng_state(global_cuda_rng_state)
        set_mode(ParallelMode.DATA)
        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
        set_mode(ParallelMode.GLOBAL)

        out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
        loss = out.sum()
        loss.backward()

        assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
        torch.cuda.empty_cache()

    # Reset the seeds again so that other tests running in the same process
    # are not affected by the seeds set above (the seed manager is a singleton).
    reset_seeds()


if __name__ == "__main__":
    test_activation_checkpointing(False, False)
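
As context for the test above (not part of this diff): the recompute-in-backward behavior it exercises can be sketched with vanilla PyTorch's torch.utils.checkpoint, which is the same idea the ColossalAI wrapper extends with CUDA RNG/seed handling; the toy shapes below are made up for illustration.

# Minimal plain-PyTorch sketch of activation checkpointing (illustration only).
# The wrapped function's activations are not stored; they are recomputed during backward.
import torch
from torch.utils.checkpoint import checkpoint

def block(x, weight):
    return torch.relu(x @ weight)

x = torch.rand(2, 2, requires_grad=True)
w = torch.rand(2, 4, requires_grad=True)
out = checkpoint(block, x, w, use_reentrant=False)    # non-reentrant variant
out.sum().backward()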


@@ -1,77 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_1d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)
    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()

    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_1d():
    spawn(check_checkpoint_1d, 8)


if __name__ == "__main__":
    test_checkpoint_1d()
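
This test, and the nearly identical 2D, 2.5D, and 3D variants that follow, all exercise the same round trip: save a plain nn.Sequential, load it into a tensor/pipeline-parallel counterpart, and compare state dicts. Stripped of all parallelism, the round trip reduces to the following illustrative sketch (not part of this diff).

# Plain-PyTorch state-dict round trip (illustration only, no parallelism).
import torch
import torch.nn as nn

m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
torch.save(m1.state_dict(), "test.pt")                 # save the reference weights

m2 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))   # a fresh, differently initialized copy
m2.load_state_dict(torch.load("test.pt"))              # load the saved weights

for k, v in m1.state_dict().items():
    assert torch.allclose(v, m2.state_dict()[k], rtol=1e-3, atol=1e-2)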


@@ -1,77 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)
    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()

    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2d():
    spawn(check_checkpoint_2d, 8)


if __name__ == "__main__":
    test_checkpoint_2d()


@@ -1,77 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2p5d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)
    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()

    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2p5d():
    spawn(check_checkpoint_2p5d, 8)


if __name__ == "__main__":
    test_checkpoint_2p5d()


@@ -1,77 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_3d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)
    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()

    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_3d():
    spawn(check_checkpoint_3d, 8)


if __name__ == "__main__":
    test_checkpoint_3d()


@@ -1,41 +0,0 @@
import torch

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.zero.legacy.sharded_param import ShardedTensor


def run_tensor_move(rank, world_size, port):
    colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl')

    src_t = torch.ones(2, 3).cuda()
    tgt_t = torch.zeros(2, 3)

    colo_model_data_tensor_move(src_t, tgt_t)
    assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t)} vs. 6.0"

    src_t = torch.ones(2, 3)
    tgt_t = torch.zeros(2, 3).cuda().half()

    colo_model_data_tensor_move(src_t, tgt_t)
    # the payload of src_t has been released
    assert (src_t.numel() == 0)
    assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t)} vs. 6.0"

    src_t = ShardedTensor(torch.ones(2, 3))
    tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half())

    colo_model_data_tensor_move(src_t, tgt_t)
    assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
    assert (tgt_t.device.type == 'cuda')

    colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu'))
    assert (tgt_t.device.type == 'cpu')


@rerun_if_address_is_in_use()
def test_tensor_move():
    spawn(run_tensor_move, 1)


if __name__ == '__main__':
    test_tensor_move()
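
For orientation (not part of this diff): stripped of the ShardedTensor bookkeeping, the move helpers exercised above behave roughly like the sketch below, where the payload is copied to the target and the source storage is released. The helper names here are made up for illustration and are not the ColossalAI implementation.

# Rough sketch of an in-place tensor "move" (illustration only, hypothetical helpers).
import torch

def tensor_move_inline(t: torch.Tensor, device: torch.device) -> None:
    t.data = t.data.to(device)    # rebind the payload on the target device

def tensor_move(src: torch.Tensor, tgt: torch.Tensor) -> None:
    tgt.data.copy_(src.data)      # copy the payload into the target tensor
    src.data = torch.empty(0, device=src.device, dtype=src.dtype)    # release the source

src_t = torch.ones(2, 3)
tgt_t = torch.zeros(2, 3)
tensor_move(src_t, tgt_t)
assert src_t.numel() == 0 and torch.sum(tgt_t) == 6.0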


@@ -1,28 +0,0 @@
import pytest

import colossalai
from colossalai.testing import spawn
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction


def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
    frac1 = colo_device_memory_capacity(get_current_device())
    colo_set_process_memory_fraction(0.5)
    frac2 = colo_device_memory_capacity(get_current_device())
    assert frac2 * 2 == frac1


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [3, 4])
def test_memory_utils(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_memory_utils(world_size=2)
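
For reference (not part of this diff): the behaviour asserted above, where capping the memory fraction at 0.5 halves the reported capacity, can be approximated with standard PyTorch CUDA calls as in the sketch below.

# Plain-PyTorch sketch of capping per-process GPU memory (illustration only).
import torch

if torch.cuda.is_available():
    device = torch.cuda.current_device()
    total = torch.cuda.get_device_properties(device).total_memory    # capacity in bytes
    torch.cuda.set_per_process_memory_fraction(0.5, device)          # allocator may use at most 50%
    # Allocations beyond 0.5 * total now raise an out-of-memory error, which is why
    # the reported capacity in the test is expected to halve after setting the fraction.
    print(f"total={total} bytes, usable after cap={int(total * 0.5)} bytes")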


@@ -1,78 +0,0 @@
import pytest
import torch
from torch.nn.parameter import Parameter
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.common import clip_grad_norm


def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
    return abs(num - other) <= atol + rtol * other


def shard_param(p: ColoParameter) -> None:
    pg = p.get_process_group()
    p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
    p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()


def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
    pg = colo_p.get_process_group()
    if p.shape != colo_p.shape:
        grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()]
    else:
        grad = p.grad
    assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}'


@parameterize('dtype', [torch.float])
@parameterize('device', ['mixed', 'cuda', 'cpu'])
@parameterize('norm_type', [2.0, 3.0, float('inf')])
def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
    print(f'{world_size}, {dtype}, {device}, {norm_type}')
    cuda_device = get_current_device()
    devices = [cuda_device] * 4
    if device == 'cpu':
        devices = [torch.device('cpu')] * 4
    elif device == 'mixed':
        devices = [cuda_device] * 2 + [torch.device('cpu')] * 2
    pg = ProcessGroup(tp_degree=world_size)
    params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)]
    colo_params = [
        ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4)
    ]
    for p, colo_p in zip(params, colo_params):
        grad = torch.rand_like(p)
        p.grad = grad
        colo_p.grad = grad.clone().detach()
    shard_param(colo_params[0])
    shard_param(colo_params[2])
    torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type)
    colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type)
    assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm - colo_norm)}'
    for p, colo_p in zip(params, colo_params):
        check_grad_equal(p, colo_p)


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_grad_clip_norm(world_size=world_size)


@pytest.mark.skip("this needs to be updated")
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_clip_grad(world_size: int):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_clip_grad(2)
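
As a reminder of the single-process baseline this test compares against (not part of this diff), gradient norm clipping with vanilla PyTorch looks like the sketch below; the distributed version is expected to reproduce the same total norm even when some gradients are sharded across ranks.

# Minimal sketch of gradient norm clipping with vanilla PyTorch (illustration only).
import torch
from torch.nn.utils import clip_grad_norm_

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(4)]
for p in params:
    p.grad = torch.rand_like(p)

# clip_grad_norm_ computes the total norm over all gradients and, if it exceeds
# max_norm, scales every gradient by max_norm / (total_norm + eps). It returns the
# total norm, which is what the test compares against the ColoTensor implementation.
total_norm = clip_grad_norm_(params, max_norm=1.0, norm_type=2.0)
print(float(total_norm))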


@@ -1,111 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import checkpoint, clip_grad_norm_fp32
from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward, False)
    return module


class Net(nn.Module):

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()
    clip_grad(model, norm_type)
    optimizer.step()


def clip_grad(model, norm_type):
    if isinstance(model, DDP):
        clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
    else:
        clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def check_grads(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(4)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_shard_padding = zero_p.zero_shard_padding
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(4)
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_shard_padding > 0:
            zero_p = zero_p[:-zero_shard_padding]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
    world_size = 4
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_clip_grad()