mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 13:11:05 +00:00
[utils] correct cpu memory used and capacity in the context of multi-process (#726)
This commit is contained in:
141
tests/test_zero/common.py
Normal file
141
tests/test_zero/common.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
|
||||
LOGGER = get_dist_logger('zero_test')
|
||||
|
||||
MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None)))
|
||||
|
||||
_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
|
||||
fp32_reduce_scatter=False,
|
||||
offload_config=None,
|
||||
gradient_predivide_factor=1.0,
|
||||
use_memory_tracer=False,
|
||||
shard_strategy=TensorShardStrategy(),
|
||||
reuse_fp16_shard=False)
|
||||
|
||||
_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
|
||||
initial_scale=2**5,
|
||||
min_scale=1,
|
||||
growth_factor=2,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=1000,
|
||||
hysteresis=2,
|
||||
max_scale=2**32)
|
||||
|
||||
ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
|
||||
zero=dict(
|
||||
model_config=_ZERO_MODEL_CONFIG,
|
||||
optimizer_config=_ZERO_OPTIMIZER_CONFIG,
|
||||
),
|
||||
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
|
||||
|
||||
CONFIG = dict(fp16=dict(mode=None,),
|
||||
zero=dict(level=3,
|
||||
verbose=False,
|
||||
offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
|
||||
offload_param_config=dict(device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
buffer_size=1e8,
|
||||
max_in_cpu=1e9)),
|
||||
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
if isinstance(model, ShardedModelV2):
|
||||
model.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
|
||||
def checkpoint_wrapper(module, enable=True):
|
||||
if enable:
|
||||
module.forward = partial(checkpoint, module.forward)
|
||||
return module
|
||||
|
||||
|
||||
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
|
||||
if loose:
|
||||
return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
|
||||
return torch.allclose(tensor_a, tensor_b)
|
||||
|
||||
|
||||
def check_grads(model, zero_model, loose=False):
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_grad = zero_p.grad.clone().to(p.device)
|
||||
grad = p.grad.float()
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
assert allclose(grad, zero_grad, loose=loose)
|
||||
|
||||
|
||||
def check_params(model, zero_model, loose=False):
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_p = zero_p.clone().to(p.device)
|
||||
# assert p.dtype == zero_p.dtype
|
||||
assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}"
|
||||
|
||||
|
||||
def check_grads_padding(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
|
||||
# zero_grad = zero_p.grad.clone().to(p.device)
|
||||
if zero_p.colo_attr.is_replicated:
|
||||
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
|
||||
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
grad = chunks[rank].float()
|
||||
if zero_grad.size(0) > grad.size(0):
|
||||
zero_grad = zero_grad[:grad.size(0)]
|
||||
else:
|
||||
zero_grad = zero_p.colo_attr.saved_grad.payload
|
||||
grad = p.grad.to(zero_grad.dtype)
|
||||
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'
|
||||
|
||||
|
||||
def check_params_padding(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_p = zero_p.clone().to(p.device)
|
||||
chunks = torch.flatten(p).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
p = chunks[rank]
|
||||
if zero_p.size(0) > p.size(0):
|
||||
zero_p = zero_p[:p.size(0)]
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose)
|
||||
|
||||
|
||||
def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
|
||||
rank = dist.get_rank()
|
||||
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
|
||||
if zero_p.colo_attr.param_is_sharded:
|
||||
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
|
||||
chunks = torch.flatten(p).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
p = chunks[rank].float()
|
||||
if zero_p.size(0) > p.size(0):
|
||||
zero_p = zero_p[:p.size(0)]
|
||||
else:
|
||||
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
|
||||
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
|
76
tests/test_zero/test_found_inf.py
Normal file
76
tests/test_zero/test_found_inf.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import parameterize, rerun_on_exception
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_optim import ShardedOptimizerV2
|
||||
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_zero.test_sharded_optim_v2 import _run_step
|
||||
|
||||
from common import CONFIG
|
||||
|
||||
|
||||
@parameterize("cpu_offload", [True, False])
|
||||
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
|
||||
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
|
||||
def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
|
||||
test_models = ['repeated_computed_layers']
|
||||
shard_strategy = shard_strategy_class()
|
||||
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
zero_model = model_builder(checkpoint=True)
|
||||
zero_model = ShardedModelV2(
|
||||
zero_model,
|
||||
shard_strategy,
|
||||
offload_config=dict(device='cpu') if cpu_offload else None,
|
||||
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
|
||||
reuse_fp16_shard=True,
|
||||
)
|
||||
|
||||
sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3)
|
||||
sharded_optim = ShardedOptimizerV2(zero_model,
|
||||
sharded_optim,
|
||||
cpu_offload=cpu_offload,
|
||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 1:
|
||||
break
|
||||
assert zero_model.overflow_counter == 0
|
||||
data, label = data.cuda(), label.cuda()
|
||||
_run_step(zero_model, sharded_optim, data, label, criterion, False)
|
||||
for param in zero_model.parameters():
|
||||
assert not has_inf_or_nan(param.colo_attr.sharded_data_tensor.payload)
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
_run_test_found_inf()
|
||||
|
||||
|
||||
# use_cpuadam = True can be used with cpu_offload = False
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_found_inf(world_size):
|
||||
run_func = partial(_run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_found_inf(world_size=2)
|
74
tests/test_zero/test_init_context.py
Normal file
74
tests/test_zero/test_init_context.py
Normal file
@@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.testing import parameterize, rerun_on_exception
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
colo_model_mem_usage
|
||||
from colossalai.utils.memory import colo_device_memory_used
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from common import CONFIG
|
||||
|
||||
|
||||
@parameterize("init_device_type", ['cpu', 'cuda'])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_model_test(init_device_type, shard_strategy_class):
|
||||
logger = get_dist_logger("test_zero_init")
|
||||
|
||||
for get_components_func in non_distributed_component_funcs:
|
||||
model_builder, _, _, _, _ = get_components_func()
|
||||
if init_device_type == 'cuda':
|
||||
init_device = get_current_device()
|
||||
elif init_device_type == 'cpu':
|
||||
init_device = torch.device("cpu")
|
||||
else:
|
||||
continue
|
||||
|
||||
model_numel_tensor = torch.zeros(1, dtype=torch.int)
|
||||
with ZeroInitContext(target_device=init_device,
|
||||
shard_strategy=shard_strategy_class(),
|
||||
shard_param=True,
|
||||
model_numel_tensor=model_numel_tensor):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
for param in model.parameters():
|
||||
assert hasattr(param, 'colo_attr')
|
||||
assert param.colo_attr.sharded_data_tensor.dtype == torch.half
|
||||
assert param.colo_attr.sharded_data_tensor.is_sharded
|
||||
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
|
||||
cuda_mem_use, _ = colo_model_mem_usage(model)
|
||||
model_data_cuda_mem_MB = cuda_mem_use / 1e6
|
||||
logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
|
||||
sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6
|
||||
logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
|
||||
logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_model_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 4])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_zero_init_context(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_init_context(4)
|
69
tests/test_zero/test_shard_model_v2.py
Normal file
69
tests/test_zero/test_shard_model_v2.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import parameterize, rerun_on_exception
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.sharded_model.utils import col_model_deepcopy
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from common import CONFIG, check_grads_padding, run_fwd_bwd
|
||||
|
||||
|
||||
@parameterize("enable_autocast", [True])
|
||||
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
|
||||
def run_model_test(enable_autocast, shard_strategy_class):
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
|
||||
shard_strategy = shard_strategy_class()
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, _, criterion = get_components_func()
|
||||
|
||||
with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
zero_model = model_builder(checkpoint=True)
|
||||
zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
|
||||
|
||||
model = model_builder(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
model = model.cuda()
|
||||
|
||||
model = DDP(model)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 5:
|
||||
break
|
||||
|
||||
data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
|
||||
run_fwd_bwd(model, data, label, criterion, enable_autocast)
|
||||
run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
|
||||
|
||||
check_grads_padding(model, zero_model, loose=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_model_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_shard_model_v2(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_shard_model_v2(world_size=2)
|
96
tests/test_zero/test_shard_param.py
Normal file
96
tests/test_zero/test_shard_param.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.testing import rerun_on_exception
|
||||
from tests.test_zero.common import CONFIG, allclose
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
|
||||
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_shard_tensor_with_strategy(shard_strategy_class, world_size):
|
||||
t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
|
||||
assert list(t.origin_shape) == [world_size * 2, 3]
|
||||
assert list(t.shape) == [world_size * 2, 3]
|
||||
|
||||
shard_strategy = shard_strategy_class()
|
||||
|
||||
# test shard strategy
|
||||
shard_strategy.shard([t])
|
||||
assert list(t.shape) == [6], f"{list(t.shape)} vs 6"
|
||||
shard_strategy.gather([t])
|
||||
assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"
|
||||
|
||||
|
||||
def _run_shard_tensor(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_shard_tensor_with_strategy(world_size=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_shard_tensor(world_size):
|
||||
run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
def _run_shard_param_v2(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
param = torch.nn.Parameter(torch.randn(2, 3))
|
||||
param_ref = deepcopy(param)
|
||||
sparam = ShardedParamV2(param=param)
|
||||
|
||||
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
|
||||
|
||||
# Test get memory usage
|
||||
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
|
||||
|
||||
sparam.remove_torch_payload()
|
||||
assert (param.data.numel() == 0)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
# 4 is size of dummy tensor of param.data
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
|
||||
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
|
||||
sparam.remove_torch_payload()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cuda_mem_use == 0
|
||||
|
||||
# append a grad to torch param
|
||||
param.data = sparam.sharded_data_tensor.payload
|
||||
param.grad = torch.randn(2, 3)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
|
||||
assert cuda_mem_use == 0
|
||||
|
||||
# reuse torch grad for sparam
|
||||
sparam.saved_grad = StatefulTensor(param.grad)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cuda_mem_use == 0
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_shard_param_v2(world_size):
|
||||
run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_shard_tensor(2)
|
||||
test_shard_param_v2(2)
|
117
tests/test_zero/test_sharded_optim_v2.py
Normal file
117
tests/test_zero/test_sharded_optim_v2.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.nn.optimizer import CPUAdam
|
||||
from colossalai.testing import parameterize, rerun_on_exception
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model.utils import col_model_deepcopy
|
||||
from colossalai.zero.sharded_optim import ShardedOptimizerV2
|
||||
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from common import CONFIG, check_sharded_model_params
|
||||
|
||||
|
||||
def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
|
||||
loss = loss.float()
|
||||
if isinstance(model, ShardedModelV2):
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
@parameterize("cpu_offload", [True, False])
|
||||
@parameterize("use_cpuadam", [True, False])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
|
||||
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
|
||||
shard_strategy = shard_strategy_class()
|
||||
|
||||
if use_cpuadam and cpu_offload is False:
|
||||
return
|
||||
if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
|
||||
return
|
||||
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
zero_model = model_builder(checkpoint=True)
|
||||
zero_model = ShardedModelV2(
|
||||
zero_model,
|
||||
shard_strategy,
|
||||
offload_config=dict(device='cpu') if cpu_offload else None,
|
||||
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
|
||||
reuse_fp16_shard=use_cpuadam,
|
||||
)
|
||||
|
||||
model = model_builder(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
model = model.cuda().float()
|
||||
|
||||
if use_cpuadam:
|
||||
optimizer_class = CPUAdam
|
||||
optim = optimizer_class(model.parameters(), lr=1e-3)
|
||||
sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
|
||||
sharded_optim = ShardedOptimizerV2(zero_model,
|
||||
sharded_optim,
|
||||
cpu_offload=cpu_offload,
|
||||
initial_scale=2**5,
|
||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio)
|
||||
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
|
||||
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
|
||||
if dist.get_world_size() > 1:
|
||||
apex_model = DDP(apex_model)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 5:
|
||||
break
|
||||
data, label = data.cuda(), label.cuda()
|
||||
_run_step(apex_model, apex_optimizer, data, label, criterion, False)
|
||||
_run_step(zero_model, sharded_optim, data, label, criterion, False)
|
||||
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
|
||||
for param in model.parameters():
|
||||
assert not has_inf_or_nan(param)
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
_run_test_sharded_optim_v2()
|
||||
|
||||
|
||||
# use_cpuadam = True can be used with cpu_offload = False
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_sharded_optim_v2(world_size):
|
||||
run_func = partial(_run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sharded_optim_v2(world_size=2)
|
92
tests/test_zero/test_sharded_optim_with_sync_bn.py
Normal file
92
tests/test_zero/test_sharded_optim_with_sync_bn.py
Normal file
@@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.testing import rerun_on_exception
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
from torchvision.models import resnet50
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
# this test only runs on resnet18
|
||||
# as this model has sync batch normalization
|
||||
# need to configure cudnn deterministic so that
|
||||
# randomness of convolution layers will be disabled
|
||||
zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
|
||||
colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
with ZeroInitContext(target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True):
|
||||
model = resnet50()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
engine, *args = colossalai.initialize(model, optimizer, criterion)
|
||||
|
||||
# train for dummy iterations
|
||||
engine.train()
|
||||
for _ in range(2):
|
||||
data = torch.rand(4, 3, 128, 128).cuda().half()
|
||||
label = torch.randint(0, 10, size=(4,)).cuda()
|
||||
engine.zero_grad()
|
||||
out = engine(data)
|
||||
loss = engine.criterion(out, label)
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
|
||||
# test
|
||||
# need to make sure the batch norm stats are synchronized
|
||||
# so that given the same input, the model will produce the same
|
||||
# output on different ranks
|
||||
engine.eval()
|
||||
data = torch.rand(4, 3, 128, 128).cuda().half()
|
||||
dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
# predict
|
||||
out = engine(data)
|
||||
|
||||
# test if results are equal
|
||||
tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
|
||||
tensor_list.insert(rank, out)
|
||||
dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
assert torch.all(tensor_list[0] == tensor_list[1]), \
|
||||
'expected the output from different ranks to be the same, but got different values'
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_sharded_optim_with_sync_bn():
|
||||
"""
|
||||
This test is to make sure that buffers are synchronized between ranks
|
||||
when using ZeRO. An example of module buffer is the running stats of
|
||||
BatchNormalization layer, i.e. mean and var.
|
||||
|
||||
If the buffers are not synchronized, the model will produce different
|
||||
output even though the input and parameters are the same. This is not
|
||||
wanted if we are doing predictions.
|
||||
|
||||
"""
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sharded_optim_with_sync_bn()
|
59
tests/test_zero/test_state_dict.py
Normal file
59
tests/test_zero/test_state_dict.py
Normal file
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import parameterize, rerun_on_exception
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model.utils import col_model_deepcopy
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from common import CONFIG
|
||||
|
||||
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_zero_state_dict(shard_strategy_class):
|
||||
test_models = ['repeated_computed_layers', 'resnet18']
|
||||
shard_strategy = shard_strategy_class()
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
|
||||
|
||||
with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
zero_model = model_builder(checkpoint=True)
|
||||
zero_model = ShardedModelV2(zero_model, shard_strategy)
|
||||
|
||||
model = model_builder(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
model = model.cuda()
|
||||
|
||||
zero_state_dict = zero_model.state_dict()
|
||||
for key, val in model.state_dict().items():
|
||||
assert torch.equal(val, zero_state_dict[key])
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_zero_state_dict()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_zero_state_dict(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_state_dict(2)
|
112
tests/test_zero/test_stateful_tensor_mgr.py
Normal file
112
tests/test_zero/test_stateful_tensor_mgr.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory_tracer import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||
from colossalai.zero.utils import StatefulTensorMgr
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_on_exception
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
|
||||
class Net(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# each parameter is 512 MB
|
||||
self.p0 = Parameter(torch.empty(1024, 1024, 128))
|
||||
self.p1 = Parameter(torch.empty(1024, 1024, 128))
|
||||
self.p2 = Parameter(torch.empty(1024, 1024, 128))
|
||||
|
||||
|
||||
def run_stm():
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
fraction = (1.4 * 1024**3) / cuda_capacity
|
||||
# limit max memory to 1.4GB
|
||||
# which means only 2 parameters can be on CUDA
|
||||
colo_set_process_memory_fraction(fraction)
|
||||
model = Net()
|
||||
for p in model.parameters():
|
||||
p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
|
||||
GLOBAL_MODEL_DATA_TRACER.register_model(model)
|
||||
mem_collector = MemStatsCollector()
|
||||
stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
|
||||
for p in model.parameters():
|
||||
stateful_tensor_mgr.register_stateful_param(p.colo_attr)
|
||||
|
||||
mem_collector.start_collection()
|
||||
# Compute order: 0 1 2 0 1
|
||||
# warmup
|
||||
# use naive eviction strategy
|
||||
apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
mem_collector.finish_collection()
|
||||
stateful_tensor_mgr.reset()
|
||||
|
||||
# warmup done
|
||||
# use OPT-like eviction strategy
|
||||
apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
|
||||
|
||||
def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
|
||||
stateful_tensor_mgr: StatefulTensorMgr):
|
||||
compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
||||
for p in model.parameters():
|
||||
if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD:
|
||||
p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD)
|
||||
stateful_tensor_mgr.adjust_layout()
|
||||
print_stats(model)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust]
|
||||
for n, p in model.named_parameters():
|
||||
if hash(p) in cuda_param_after_adjust:
|
||||
assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}'
|
||||
else:
|
||||
assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu')
|
||||
|
||||
|
||||
def print_stats(model: torch.nn.Module):
|
||||
msgs = []
|
||||
for n, p in model.named_parameters():
|
||||
msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})')
|
||||
print(f'[ {", ".join(msgs)} ]')
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_stm()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_stateful_tensor_manager(world_size=1):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_stateful_tensor_manager()
|
93
tests/test_zero/test_tensor_utils.py
Normal file
93
tests/test_zero/test_tensor_utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
|
||||
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
|
||||
colo_model_tensor_clone)
|
||||
from colossalai.utils import free_port
|
||||
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
def _run_colo_tensor_mem_usage():
|
||||
for i in range(1):
|
||||
if i == 1:
|
||||
t1 = StatefulTensor(torch.randn(2, 2))
|
||||
t2 = StatefulTensor(torch.randn(4, 4))
|
||||
c1, g1 = colo_tensor_mem_usage(t1)
|
||||
c2, g2 = colo_tensor_mem_usage(t2)
|
||||
assert c1 * 4 == c2
|
||||
assert g1 * 4 == g2
|
||||
else:
|
||||
t1 = torch.randn(2, 2)
|
||||
t2 = torch.randn(4, 4)
|
||||
c1, g1 = colo_tensor_mem_usage(t1)
|
||||
c2, g2 = colo_tensor_mem_usage(t2)
|
||||
assert c1 * 4 == c2
|
||||
assert g1 * 4 == g2
|
||||
|
||||
|
||||
def _run_colo_model_data_tensor_move_inline():
|
||||
for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]:
|
||||
colo_model_data_tensor_move_inline(t, get_current_device())
|
||||
assert t.device == get_current_device()
|
||||
|
||||
|
||||
def _run_colo_model_data_tensor_move():
|
||||
for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))),
|
||||
(torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]:
|
||||
cpu_t, cuda_t = t
|
||||
colo_model_data_tensor_move(cpu_t, cuda_t)
|
||||
assert cuda_t.device == get_current_device()
|
||||
|
||||
|
||||
def _run_colo_model_data_move_to_cpu():
|
||||
for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]:
|
||||
colo_model_data_move_to_cpu(t)
|
||||
assert t.device == torch.device("cpu")
|
||||
|
||||
|
||||
def _run_colo_model_tensor_clone():
|
||||
for t in [
|
||||
StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())),
|
||||
torch.randn(4, 4).cuda(torch.cuda.current_device())
|
||||
]:
|
||||
if issubclass(type(t), StatefulTensor):
|
||||
assert t.payload.device == get_current_device()
|
||||
else:
|
||||
assert t.device == get_current_device()
|
||||
p = colo_model_tensor_clone(t, get_current_device())
|
||||
assert p.device == get_current_device()
|
||||
for i in range(2):
|
||||
for j in range(2):
|
||||
if issubclass(type(t), StatefulTensor):
|
||||
assert t.payload.device == p.device
|
||||
assert t.payload[i][j] == p[i][j]
|
||||
else:
|
||||
assert t.device == p.device
|
||||
assert t[i][j] == p[i][j]
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
_run_colo_tensor_mem_usage()
|
||||
_run_colo_model_data_tensor_move_inline()
|
||||
_run_colo_model_data_tensor_move()
|
||||
_run_colo_model_data_move_to_cpu()
|
||||
_run_colo_model_tensor_clone()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [4, 5])
|
||||
def test_zero_tensor_utils(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_tensor_utils(world_size=2)
|
114
tests/test_zero/test_zero_engine.py
Normal file
114
tests/test_zero/test_zero_engine.py
Normal file
@@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.testing import rerun_on_exception
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.sharded_model.utils import col_model_deepcopy
|
||||
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, parallel_config):
|
||||
colossalai.launch(config=parallel_config,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
with ZeroInitContext(target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True):
|
||||
colo_model = model_builder(checkpoint=True)
|
||||
|
||||
colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
|
||||
engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
|
||||
optimizer=colo_optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader)
|
||||
torch_model = model_builder(checkpoint=True).half()
|
||||
col_model_deepcopy(engine.model, torch_model)
|
||||
torch_model = torch_model.cuda().float()
|
||||
|
||||
engine.train()
|
||||
torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
|
||||
|
||||
if dist.get_world_size() > 1:
|
||||
torch_model = DDP(torch_model)
|
||||
|
||||
i = 0
|
||||
for data, label in train_dataloader:
|
||||
if i > 4:
|
||||
break
|
||||
|
||||
data, label = data.cuda(), label.cuda()
|
||||
|
||||
engine.zero_grad()
|
||||
torch_optimizer.zero_grad()
|
||||
|
||||
if criterion:
|
||||
output = engine(data)
|
||||
loss = engine.criterion(output, label)
|
||||
|
||||
torch_output = torch_model(data)
|
||||
torch_loss = engine.criterion(torch_output, label)
|
||||
else:
|
||||
loss = engine(data, label)
|
||||
torch_loss = torch_model(data, label)
|
||||
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
|
||||
torch_loss.backward()
|
||||
|
||||
for param in torch_model.parameters():
|
||||
if param.grad is not None:
|
||||
assert not has_inf_or_nan(param.grad)
|
||||
|
||||
torch_optimizer.step()
|
||||
i += 1
|
||||
|
||||
if parallel_config == MP_PARALLEL_CONFIG:
|
||||
check_params(torch_model, colo_model, loose=True)
|
||||
elif parallel_config == ZERO_PARALLEL_CONFIG:
|
||||
check_sharded_model_params(torch_model, colo_model, loose=True)
|
||||
|
||||
|
||||
# FIXME: enable this test in next PR
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_mp_engine(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
def test_zero_engine(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_engine(world_size=4)
|
Reference in New Issue
Block a user