mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[zero] Update initialize for ZeRO (#458)
* polish code * shard strategy receive pg in shard() / gather() * update zero engine * polish code
This commit is contained in:
@@ -16,7 +16,7 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
|
||||
offload_config=None,
|
||||
gradient_predivide_factor=1.0,
|
||||
use_memory_tracer=False,
|
||||
shard_strategy=TensorShardStrategy)
|
||||
shard_strategy=TensorShardStrategy())
|
||||
|
||||
_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
|
||||
initial_scale=2**5,
|
||||
@@ -25,8 +25,7 @@ _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=1000,
|
||||
hysteresis=2,
|
||||
max_scale=2**32,
|
||||
lr=1e-3)
|
||||
max_scale=2**32)
|
||||
|
||||
ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
|
||||
zero=dict(
|
||||
|
@@ -7,26 +7,27 @@ import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from common import CONFIG
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
|
||||
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_model_test(init_device, shard_strategy):
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_model_test(init_device, shard_strategy_class):
|
||||
for get_components_func in non_distributed_component_funcs:
|
||||
model_builder, _, _, _, _ = get_components_func()
|
||||
model_numel_tensor = torch.zeros(1, dtype=torch.int)
|
||||
with ZeroInitContext(convert_fp16=True,
|
||||
target_device=init_device,
|
||||
shard_strategy=shard_strategy(),
|
||||
shard_strategy=shard_strategy_class(),
|
||||
shard_param=True,
|
||||
model_numel_tensor=model_numel_tensor):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
@@ -9,22 +9,22 @@ import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_zero_data_parallel.common import CONFIG, allclose
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_shard_tensor_with_strategy(shard_strategy, world_size):
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_shard_tensor_with_strategy(shard_strategy_class, world_size):
|
||||
t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
|
||||
assert list(t.origin_shape) == [world_size * 2, 3]
|
||||
assert list(t.shape) == [world_size * 2, 3]
|
||||
|
||||
shard_strategy = shard_strategy(process_group=None)
|
||||
shard_strategy = shard_strategy_class()
|
||||
|
||||
# test shard strategy
|
||||
shard_strategy.shard([t])
|
||||
|
@@ -11,6 +11,8 @@ import torch.multiprocessing as mp
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
from torchvision.models import resnet50
|
||||
|
||||
|
||||
@@ -19,7 +21,7 @@ def run_dist(rank, world_size, port):
|
||||
# as this model has sync batch normalization
|
||||
# need to configure cudnn deterministic so that
|
||||
# randomness of convolution layers will be disabled
|
||||
zero_config = dict(optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3))
|
||||
zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
|
||||
colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
@@ -27,7 +29,11 @@ def run_dist(rank, world_size, port):
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
model = resnet50()
|
||||
with ZeroInitContext(convert_fp16=True,
|
||||
target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True):
|
||||
model = resnet50()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
@@ -64,10 +70,6 @@ def run_dist(rank, world_size, port):
|
||||
'expected the output from different ranks to be the same, but got different values'
|
||||
|
||||
|
||||
# FIXME: enable this test in next PR
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.dist
|
||||
def test_sharded_optim_with_sync_bn():
|
||||
"""
|
||||
|
@@ -8,7 +8,6 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
@@ -17,8 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params,
|
||||
check_sharded_params_padding)
|
||||
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, parallel_config):
|
||||
@@ -35,18 +33,19 @@ def run_dist(rank, world_size, port, parallel_config):
|
||||
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
|
||||
target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shared_strategy(
|
||||
gpc.get_group(ParallelMode.DATA)),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True):
|
||||
colo_model = model_builder(checkpoint=True)
|
||||
|
||||
torch_model = model_builder(checkpoint=True).half()
|
||||
col_model_deepcopy(colo_model, torch_model)
|
||||
torch_model = torch_model.cuda().float()
|
||||
colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
|
||||
engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
|
||||
optimizer=optimizer_class,
|
||||
optimizer=colo_optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader)
|
||||
torch_model = model_builder(checkpoint=True).half()
|
||||
col_model_deepcopy(engine.model, torch_model)
|
||||
torch_model = torch_model.cuda().float()
|
||||
|
||||
engine.train()
|
||||
torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
|
||||
|
||||
@@ -102,7 +101,6 @@ def test_mp_engine(world_size):
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
def test_zero_engine(world_size):
|
||||
|
Reference in New Issue
Block a user