[utils] correct cpu memory used and capacity in the context of multi-process (#726)

Jiarui Fang
2022-04-12 14:57:54 +08:00
committed by GitHub
parent 7db3ccc79b
commit 53cb584808
17 changed files with 52 additions and 20 deletions
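Judging by the title, the fix makes CPU memory "used" and "capacity" per-process aware when several ranks share one host, instead of reporting host-wide numbers. A minimal sketch of that kind of accounting, for orientation only (not the actual patch; cpu_memory_capacity and cpu_memory_used are hypothetical helpers, and psutil is assumed to be available):

import psutil


def cpu_memory_capacity(processes_per_node: int) -> int:
    # assumed helper: each rank's share of the host RAM, in bytes
    return psutil.virtual_memory().total // processes_per_node


def cpu_memory_used() -> int:
    # assumed helper: resident set size of the current process, in bytes
    return psutil.Process().memory_info().rss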

tests/test_zero/common.py
@@ -0,0 +1,141 @@
from functools import partial

import torch
import torch.distributed as dist

from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

LOGGER = get_dist_logger('zero_test')

MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None)))

_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                          fp32_reduce_scatter=False,
                          offload_config=None,
                          gradient_predivide_factor=1.0,
                          use_memory_tracer=False,
                          shard_strategy=TensorShardStrategy(),
                          reuse_fp16_shard=False)

_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
                              initial_scale=2**5,
                              min_scale=1,
                              growth_factor=2,
                              backoff_factor=0.5,
                              growth_interval=1000,
                              hysteresis=2,
                              max_scale=2**32)

ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                            zero=dict(
                                model_config=_ZERO_MODEL_CONFIG,
                                optimizer_config=_ZERO_OPTIMIZER_CONFIG,
                            ),
                            parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))

CONFIG = dict(fp16=dict(mode=None,),
              zero=dict(level=3,
                        verbose=False,
                        offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
                        offload_param_config=dict(device='cpu',
                                                  pin_memory=True,
                                                  buffer_count=5,
                                                  buffer_size=1e8,
                                                  max_in_cpu=1e9)),
              parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def check_grads(model, zero_model, loose=False):
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        grad = p.grad.float()
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params(model, zero_model, loose=False):
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_p = zero_p.clone().to(p.device)
        # assert p.dtype == zero_p.dtype
        assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}"


def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
        # zero_grad = zero_p.grad.clone().to(p.device)
        if zero_p.colo_attr.is_replicated:
            zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
            chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
            if rank >= len(chunks):
                continue
            grad = chunks[rank].float()
            if zero_grad.size(0) > grad.size(0):
                zero_grad = zero_grad[:grad.size(0)]
        else:
            zero_grad = zero_p.colo_attr.saved_grad.payload
            grad = p.grad.to(zero_grad.dtype)

        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'


def check_params_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_p.size(0) > p.size(0):
            zero_p = zero_p[:p.size(0)]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
    rank = dist.get_rank()
    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
        if zero_p.colo_attr.param_is_sharded:
            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
            chunks = torch.flatten(p).chunk(dist.get_world_size())
            if rank >= len(chunks):
                continue
            p = chunks[rank].float()
            if zero_p.size(0) > p.size(0):
                zero_p = zero_p[:p.size(0)]
        else:
            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)

        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'

@@ -0,0 +1,76 @@
from functools import partial

import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_zero.test_sharded_optim_v2 import _run_step

from common import CONFIG


@parameterize("cpu_offload", [True, False])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
    test_models = ['repeated_computed_layers']
    shard_strategy = shard_strategy_class()

    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
                             shard_strategy=shard_strategy,
                             shard_param=True):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(
            zero_model,
            shard_strategy,
            offload_config=dict(device='cpu') if cpu_offload else None,
            use_memory_tracer=gpu_margin_mem_ratio > 0.0,
            reuse_fp16_shard=True,
        )

        sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3)
        sharded_optim = ShardedOptimizerV2(zero_model,
                                           sharded_optim,
                                           cpu_offload=cpu_offload,
                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio)

        for i, (data, label) in enumerate(train_dataloader):
            if i > 1:
                break
            assert zero_model.overflow_counter == 0
            data, label = data.cuda(), label.cuda()
            _run_step(zero_model, sharded_optim, data, label, criterion, False)
            for param in zero_model.parameters():
                assert not has_inf_or_nan(param.colo_attr.sharded_data_tensor.payload)


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_test_found_inf()


# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_found_inf(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_found_inf(world_size=2)

@@ -0,0 +1,74 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
    colo_model_mem_usage
from colossalai.utils.memory import colo_device_memory_used
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs

from common import CONFIG


@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(init_device_type, shard_strategy_class):
    logger = get_dist_logger("test_zero_init")

    for get_components_func in non_distributed_component_funcs:
        model_builder, _, _, _, _ = get_components_func()
        if init_device_type == 'cuda':
            init_device = get_current_device()
        elif init_device_type == 'cpu':
            init_device = torch.device("cpu")
        else:
            continue

        model_numel_tensor = torch.zeros(1, dtype=torch.int)
        with ZeroInitContext(target_device=init_device,
                             shard_strategy=shard_strategy_class(),
                             shard_param=True,
                             model_numel_tensor=model_numel_tensor):
            model = model_builder(checkpoint=True)

        for param in model.parameters():
            assert hasattr(param, 'colo_attr')
            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
            assert param.colo_attr.sharded_data_tensor.is_sharded
            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'

        cuda_mem_use, _ = colo_model_mem_usage(model)
        model_data_cuda_mem_MB = cuda_mem_use / 1e6
        logger.info(f"Exiting ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
        sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6
        logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
        logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_model_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_zero_init_context(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_init_context(4)

@@ -0,0 +1,69 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import CONFIG, check_grads_padding, run_fwd_bwd


@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
    shard_strategy = shard_strategy_class()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, _, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                             shard_strategy=shard_strategy,
                             shard_param=True):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)

        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)
        model = model.cuda()
        model = DDP(model)

        for i, (data, label) in enumerate(train_dataloader):
            if i > 5:
                break
            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
            run_fwd_bwd(model, data, label, criterion, enable_autocast)
            run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
            check_grads_padding(model, zero_model, loose=True)


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_model_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_shard_model_v2(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2(world_size=2)

@@ -0,0 +1,96 @@
from copy import deepcopy
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor

from tests.test_zero.common import CONFIG, allclose


@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_shard_tensor_with_strategy(shard_strategy_class, world_size):
    t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
    assert list(t.origin_shape) == [world_size * 2, 3]
    assert list(t.shape) == [world_size * 2, 3]

    shard_strategy = shard_strategy_class()

    # test shard strategy
    shard_strategy.shard([t])
    assert list(t.shape) == [6], f"{list(t.shape)} vs 6"
    shard_strategy.gather([t])
    assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"


def _run_shard_tensor(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_shard_tensor_with_strategy(world_size=world_size)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_shard_tensor(world_size):
    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


def _run_shard_param_v2(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    param = torch.nn.Parameter(torch.randn(2, 3))
    param_ref = deepcopy(param)
    sparam = ShardedParamV2(param=param)

    allclose(sparam.sharded_data_tensor.payload, param_ref.data)

    # Test get memory usage
    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    # 2 * 3 fp32 elements (4 bytes each) for the data tensor, plus the same again for saved_grad
    assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"

    sparam.remove_torch_payload()
    assert (param.data.numel() == 0)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    # 4 is size of dummy tensor of param.data
    assert cpu_mem_use == 2 * 3 * 4 * 2

    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
    sparam.remove_torch_payload()
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2
    assert cuda_mem_use == 0

    # append a grad to torch param
    param.data = sparam.sharded_data_tensor.payload
    param.grad = torch.randn(2, 3)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
    assert cuda_mem_use == 0

    # reuse torch grad for sparam
    sparam.saved_grad = StatefulTensor(param.grad)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2
    assert cuda_mem_use == 0


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_shard_param_v2(world_size):
    run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # test_shard_tensor(2)
    test_shard_param_v2(2)

@@ -0,0 +1,117 @@
from functools import partial

import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import CONFIG, check_sharded_model_params


def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
    model.train()
    optimizer.zero_grad()

    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()

    if isinstance(model, ShardedModelV2):
        optimizer.backward(loss)
    else:
        loss.backward()
    optimizer.step()


@parameterize("cpu_offload", [True, False])
@parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
    shard_strategy = shard_strategy_class()

    if use_cpuadam and cpu_offload is False:
        return
    if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
        return

    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
                             shard_strategy=shard_strategy,
                             shard_param=True):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(
            zero_model,
            shard_strategy,
            offload_config=dict(device='cpu') if cpu_offload else None,
            use_memory_tracer=gpu_margin_mem_ratio > 0.0,
            reuse_fp16_shard=use_cpuadam,
        )

        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)
        model = model.cuda().float()

        if use_cpuadam:
            optimizer_class = CPUAdam
        optim = optimizer_class(model.parameters(), lr=1e-3)
        sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
        sharded_optim = ShardedOptimizerV2(zero_model,
                                           sharded_optim,
                                           cpu_offload=cpu_offload,
                                           initial_scale=2**5,
                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio)

        amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
        apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
        if dist.get_world_size() > 1:
            apex_model = DDP(apex_model)

        for i, (data, label) in enumerate(train_dataloader):
            if i > 5:
                break
            data, label = data.cuda(), label.cuda()
            _run_step(apex_model, apex_optimizer, data, label, criterion, False)
            _run_step(zero_model, sharded_optim, data, label, criterion, False)
            check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
            for param in model.parameters():
                assert not has_inf_or_nan(param)


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_test_sharded_optim_v2()


# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_sharded_optim_v2(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim_v2(world_size=2)

@@ -0,0 +1,92 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from torchvision.models import resnet50


def run_dist(rank, world_size, port):
    # this test runs on resnet50, which contains batch normalization layers;
    # cudnn is configured to be deterministic so that the randomness of
    # convolution layers is disabled
    zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
    colossalai.launch(config=dict(zero=zero_config, cudnn_deterministic=True, cudnn_benchmark=False),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    with ZeroInitContext(target_device=torch.cuda.current_device(),
                         shard_strategy=gpc.config.zero.model_config.shard_strategy,
                         shard_param=True):
        model = resnet50()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    engine, *args = colossalai.initialize(model, optimizer, criterion)

    # train for dummy iterations
    engine.train()
    for _ in range(2):
        data = torch.rand(4, 3, 128, 128).cuda().half()
        label = torch.randint(0, 10, size=(4,)).cuda()
        engine.zero_grad()
        out = engine(data)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()

    # test
    # need to make sure the batch norm stats are synchronized
    # so that given the same input, the model will produce the same
    # output on different ranks
    engine.eval()
    data = torch.rand(4, 3, 128, 128).cuda().half()
    dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))

    # predict
    out = engine(data)

    # test if results are equal
    tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
    tensor_list.insert(rank, out)
    dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))

    assert torch.all(tensor_list[0] == tensor_list[1]), \
        'expected the output from different ranks to be the same, but got different values'


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_sharded_optim_with_sync_bn():
    """
    This test is to make sure that buffers are synchronized between ranks
    when using ZeRO. An example of a module buffer is the running stats of
    a BatchNormalization layer, i.e. its mean and var.

    If the buffers are not synchronized, the model will produce different
    outputs even though the input and parameters are the same. This is not
    wanted if we are doing predictions.
    """
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim_with_sync_bn()
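
For reference, the kind of buffer synchronization this test checks can also be done by hand with a plain broadcast of every module buffer from rank 0. A minimal sketch of the idea only (not how colossalai handles it internally):

import torch.distributed as dist


def broadcast_module_buffers(module, src=0):
    # sketch: copy rank-0's buffers (e.g. BatchNorm running mean/var) to every rank
    for buf in module.buffers():
        dist.broadcast(buf, src=src)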

@@ -0,0 +1,59 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from copy import deepcopy
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs

from common import CONFIG


@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_zero_state_dict(shard_strategy_class):
    test_models = ['repeated_computed_layers', 'resnet18']
    shard_strategy = shard_strategy_class()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                             shard_strategy=shard_strategy,
                             shard_param=True):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(zero_model, shard_strategy)

        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)
        model = model.cuda()

        zero_state_dict = zero_model.state_dict()
        for key, val in model.state_dict().items():
            assert torch.equal(val, zero_state_dict[key])


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_zero_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_zero_state_dict(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_state_dict(2)

@@ -0,0 +1,112 @@
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.utils import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from torch.nn.parameter import Parameter
from typing import List
from functools import partial


class Net(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        # each parameter is 512 MB
        self.p0 = Parameter(torch.empty(1024, 1024, 128))
        self.p1 = Parameter(torch.empty(1024, 1024, 128))
        self.p2 = Parameter(torch.empty(1024, 1024, 128))


def run_stm():
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    fraction = (1.4 * 1024**3) / cuda_capacity
    # limit max memory to 1.4GB
    # which means only 2 parameters can be on CUDA
    colo_set_process_memory_fraction(fraction)
    model = Net()
    for p in model.parameters():
        p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
    GLOBAL_MODEL_DATA_TRACER.register_model(model)
    mem_collector = MemStatsCollector()
    stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
    for p in model.parameters():
        stateful_tensor_mgr.register_stateful_param(p.colo_attr)

    mem_collector.start_collection()

    # Compute order: 0 1 2 0 1
    # warmup
    # use naive eviction strategy
    apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    mem_collector.finish_collection()
    stateful_tensor_mgr.reset()

    # warmup done
    # use OPT-like eviction strategy
    apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()


def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
                 stateful_tensor_mgr: StatefulTensorMgr):
    compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE)
    for p in model.parameters():
        if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD:
            p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD)
    stateful_tensor_mgr.adjust_layout()
    print_stats(model)
    device = torch.device(torch.cuda.current_device())
    cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust]
    for n, p in model.named_parameters():
        if hash(p) in cuda_param_after_adjust:
            assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}'
        else:
            assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu')


def print_stats(model: torch.nn.Module):
    msgs = []
    for n, p in model.named_parameters():
        msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})')
    print(f'[ {", ".join(msgs)} ]')


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_stm()


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_stateful_tensor_manager(world_size=1):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_stateful_tensor_manager()
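
The "OPT-like eviction strategy" mentioned in run_stm refers to a Belady-style policy: when a tensor has to leave CUDA, evict the resident one whose next use lies farthest in the future. A standalone sketch of that choice rule, illustrative only and not StatefulTensorMgr's actual implementation:

from typing import Dict, List


def pick_victim_opt(resident: List[str], future_accesses: List[str]) -> str:
    """Return the resident tensor whose next access is farthest away (or never)."""
    next_use: Dict[str, int] = {}
    for name in resident:
        try:
            next_use[name] = future_accesses.index(name)
        except ValueError:
            next_use[name] = len(future_accesses)    # never used again: evict first
    return max(resident, key=lambda name: next_use[name])


# e.g. resident {'p0', 'p2'}, upcoming compute order ['p0', 'p1'] -> evict 'p2'
assert pick_victim_opt(['p0', 'p2'], ['p0', 'p1']) == 'p2'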

@@ -0,0 +1,93 @@
import pytest

import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
                                           colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
                                           colo_model_tensor_clone)
from colossalai.utils import free_port

import torch
from functools import partial
import torch.multiprocessing as mp


def _run_colo_tensor_mem_usage():
    for i in range(1):
        if i == 1:
            t1 = StatefulTensor(torch.randn(2, 2))
            t2 = StatefulTensor(torch.randn(4, 4))
            c1, g1 = colo_tensor_mem_usage(t1)
            c2, g2 = colo_tensor_mem_usage(t2)
            assert c1 * 4 == c2
            assert g1 * 4 == g2
        else:
            t1 = torch.randn(2, 2)
            t2 = torch.randn(4, 4)
            c1, g1 = colo_tensor_mem_usage(t1)
            c2, g2 = colo_tensor_mem_usage(t2)
            assert c1 * 4 == c2
            assert g1 * 4 == g2


def _run_colo_model_data_tensor_move_inline():
    for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]:
        colo_model_data_tensor_move_inline(t, get_current_device())
        assert t.device == get_current_device()


def _run_colo_model_data_tensor_move():
    for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))),
              (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]:
        cpu_t, cuda_t = t
        colo_model_data_tensor_move(cpu_t, cuda_t)
        assert cuda_t.device == get_current_device()


def _run_colo_model_data_move_to_cpu():
    for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]:
        colo_model_data_move_to_cpu(t)
        assert t.device == torch.device("cpu")


def _run_colo_model_tensor_clone():
    for t in [
            StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())),
            torch.randn(4, 4).cuda(torch.cuda.current_device())
    ]:
        if issubclass(type(t), StatefulTensor):
            assert t.payload.device == get_current_device()
        else:
            assert t.device == get_current_device()
        p = colo_model_tensor_clone(t, get_current_device())
        assert p.device == get_current_device()

        for i in range(2):
            for j in range(2):
                if issubclass(type(t), StatefulTensor):
                    assert t.payload.device == p.device
                    assert t.payload[i][j] == p[i][j]
                else:
                    assert t.device == p.device
                    assert t[i][j] == p[i][j]


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    _run_colo_tensor_mem_usage()
    _run_colo_model_data_tensor_move_inline()
    _run_colo_model_data_tensor_move()
    _run_colo_model_data_move_to_cpu()
    _run_colo_model_tensor_clone()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4, 5])
def test_zero_tensor_utils(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_tensor_utils(world_size=2)

@@ -0,0 +1,114 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)


def run_dist(rank, world_size, port, parallel_config):
    colossalai.launch(config=parallel_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.cuda.current_device(),
                             shard_strategy=gpc.config.zero.model_config.shard_strategy,
                             shard_param=True):
            colo_model = model_builder(checkpoint=True)

        colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
        engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                               optimizer=colo_optimizer,
                                                               criterion=criterion,
                                                               train_dataloader=train_dataloader)
        torch_model = model_builder(checkpoint=True).half()
        col_model_deepcopy(engine.model, torch_model)
        torch_model = torch_model.cuda().float()

        engine.train()
        torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)

        if dist.get_world_size() > 1:
            torch_model = DDP(torch_model)

        i = 0
        for data, label in train_dataloader:
            if i > 4:
                break

            data, label = data.cuda(), label.cuda()

            engine.zero_grad()
            torch_optimizer.zero_grad()

            if criterion:
                output = engine(data)
                loss = engine.criterion(output, label)

                torch_output = torch_model(data)
                torch_loss = engine.criterion(torch_output, label)
            else:
                loss = engine(data, label)
                torch_loss = torch_model(data, label)

            engine.backward(loss)
            engine.step()

            torch_loss.backward()

            for param in torch_model.parameters():
                if param.grad is not None:
                    assert not has_inf_or_nan(param.grad)

            torch_optimizer.step()
            i += 1

        if parallel_config == MP_PARALLEL_CONFIG:
            check_params(torch_model, colo_model, loose=True)
        elif parallel_config == ZERO_PARALLEL_CONFIG:
            check_sharded_model_params(torch_model, colo_model, loose=True)


# FIXME: enable this test in next PR
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_mp_engine(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_zero_engine(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_engine(world_size=4)