From 370f567e7d58713de8a8ae0248d5320753c85e80 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 14 Mar 2022 20:48:41 +0800
Subject: [PATCH] [zero] new interface for ShardedOptimv2 (#406)

---
 .../zero/sharded_optim/sharded_optim_v2.py  | 37 ++++++++++++++++---
 tests/components_to_test/bert.py            |  5 +--
 tests/components_to_test/nested_model.py    |  5 +--
 .../repeated_computed_layer.py              |  5 +--
 tests/components_to_test/resnet.py          |  5 +--
 tests/test_engine/test_engine.py            |  6 +--
 .../test_trainer_with_non_pipe_schedule.py  |  4 +-
 .../test_sharded_optim_v2.py                | 12 +++---
 .../test_sharded_optim_v2_with_cpu_adam.py  |  7 ++--
 9 files changed, 51 insertions(+), 35 deletions(-)

diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index b9be80fed..47c0d26b7 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -15,7 +15,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-
+from typing import Type, Any
 from ._utils import has_inf_or_nan
 
 
@@ -27,8 +27,8 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):
 
     def __init__(self,
-                 optimizer: Optimizer,
                  sharded_model: ShardedModelV2,
+                 optimizer_class: Type[Optimizer],
                  shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
@@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 **defaults: Any) -> None:
+        """
+        :param sharded_model: A sharded model initialized by class ShardedModelV2
+        :type sharded_model: ShardedModelV2
+
+        :param optimizer_class: The class of the wrapped optimizer; an instance is created internally from sharded_model.parameters()
+        :type optimizer_class: Type[Optimizer]
+
+        :param shard_strategy: The strategy used to shard the sharded_model's parameters and the optimizer's master parameters.
+        :type shard_strategy: BaseShardStrategy
+
+        :param cpu_offload: whether to offload the optimizer states to CPU.
+        :type cpu_offload: bool
+
+        :param dp_process_group: the data-parallel process group.
+        :type dp_process_group: Optional[ProcessGroup]
+        :param defaults: any trailing keyword arguments, which are forwarded to the local optimizer.
+        :type defaults: dict
+        """
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
-        super().__init__(optimizer)
+
+        self._optim_defaults = defaults
+        # initialize M and V as zero tensors and the fp32 master params from sharded_model.parameters()
+
+        self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
+
+        super().__init__(self.optimizer)
         self.shard_strategy = shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:
@@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
-        for group in optimizer.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                 is_param_sharded = p.col_attr.data.is_sharded
diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py
index cf543d88d..224ae5147 100644
--- a/tests/components_to_test/bert.py
+++ b/tests/components_to_test/bert.py
@@ -74,8 +74,5 @@ def get_training_components():
                                       sequence_length=sequence_length,
                                       is_distrbuted=True)
 
-    def get_optim(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = None
-    return bert_model_builder, trainloader, testloader, get_optim, criterion
+    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py
index edf4a1a89..26bfb8ecc 100644
--- a/tests/components_to_test/nested_model.py
+++ b/tests/components_to_test/nested_model.py
@@ -49,8 +49,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
 
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
diff --git a/tests/components_to_test/repeated_computed_layer.py b/tests/components_to_test/repeated_computed_layer.py
index bc035d4b5..f70910191 100644
--- a/tests/components_to_test/repeated_computed_layer.py
+++ b/tests/components_to_test/repeated_computed_layer.py
@@ -43,8 +43,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
 
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py
index 20a4be8e2..193832ebc 100644
--- a/tests/components_to_test/resnet.py
+++ b/tests/components_to_test/resnet.py
@@ -29,8 +29,5 @@ def get_resnet_training_components():
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)
 
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
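Note for reviewers: with this change the caller no longer constructs the optimizer itself. It hands ShardedOptimizerV2 the sharded model, the optimizer class, and any optimizer keyword arguments, which travel through **defaults. A minimal sketch of the new call site follows; the toy module, hyper-parameters, and the TensorShardStrategy import mirror the updated tests rather than this diff, so treat them as illustrative assumptions.

    import torch
    from colossalai.zero.shard_utils import TensorShardStrategy
    from colossalai.zero.sharded_model import ShardedModelV2
    from colossalai.zero.sharded_optim import ShardedOptimizerV2

    # illustrative toy module; the tests use models from tests/components_to_test
    model = torch.nn.Linear(16, 16).cuda()
    shard_strategy = TensorShardStrategy()
    zero_model = ShardedModelV2(model, shard_strategy)

    # Old interface (removed by this patch): pass a ready-made optimizer instance.
    # sharded_optim = ShardedOptimizerV2(torch.optim.Adam(zero_model.parameters(), lr=1e-3),
    #                                    zero_model, shard_strategy)

    # New interface: pass the optimizer class; trailing kwargs (here lr) are
    # forwarded via **defaults and the instance is built internally from
    # zero_model.parameters().
    sharded_optim = ShardedOptimizerV2(zero_model,
                                       torch.optim.Adam,
                                       shard_strategy,
                                       cpu_offload=False,
                                       lr=1e-3)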
diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py
index 904c3c4ea..1bcba61f3 100644
--- a/tests/test_engine/test_engine.py
+++ b/tests/test_engine/test_engine.py
@@ -19,11 +19,11 @@ def run_train():
 
     # FIXME: test bert
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
         model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                optimizer=optimizer_builder(model),
+                                                                optimizer=optimizer_class(model.parameters(), lr=1e-3),
                                                                 criterion=criterion,
                                                                 train_dataloader=train_dataloader)
 
@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):
 
 @pytest.mark.dist
 def test_engine():
-    world_size = 4
+    world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
index 9ae21cf77..d226916b5 100644
--- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
+++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
@@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
     test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
     for name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(name)
-        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model_builder()
-        optimizer = optimizer_builder(model)
+        optimizer = optimizer_class(model.parameters(), lr=1e-3)
         engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                              optimizer=optimizer,
                                                              criterion=criterion,
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 5ecfba71a..9371cf66a 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -44,19 +44,21 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
                                     shard_strategy,
                                     offload_config=dict(device='cpu') if cpu_offload else None)
         if dist.get_world_size() > 1:
             model = DDP(model)
-        optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
-                                           zero_model,
+        lr = 1e-3
+        optim = optimizer_class(model.parameters(), lr=lr)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           optimizer_class,
                                            shard_strategy,
                                            cpu_offload=cpu_offload,
-                                           initial_scale=2**5)
+                                           initial_scale=2**5,
+                                           lr=lr)
         for i, (data, label) in enumerate(train_dataloader):
             if i > 2:
                 break
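The components_to_test helpers change their contract to match: the fourth element of the returned tuple is now an optimizer class (e.g. torch.optim.Adam) rather than a builder closure, so callers instantiate it themselves or forward it to ShardedOptimizerV2. A short sketch of the consuming side, mirroring the tests above; the registry import and the lr value are taken from those tests and are illustrative.

    from tests.components_to_test.registry import non_distributed_component_funcs

    get_components_func = non_distributed_component_funcs.get_callable('nested_model')
    model_builder, trainloader, testloader, optimizer_class, criterion = get_components_func()

    model = model_builder()
    # The caller now owns optimizer construction; previously get_components_func
    # returned a builder that hard-coded lr=0.001.
    optimizer = optimizer_class(model.parameters(), lr=1e-3)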
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py b/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
index d5daaafcc..942b46723 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
@@ -59,11 +59,12 @@ def run_dist(rank, world_size, port, shard_strategy):
     if dist.get_world_size() > 1:
         model = DDP(model)
     optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(CPUAdam(zero_model.parameters(), lr=1e-3),
-                                       zero_model,
+    sharded_optim = ShardedOptimizerV2(zero_model,
+                                       CPUAdam,
                                        shard_strategy,
                                        initial_scale=2**5,
-                                       cpu_offload=True)
+                                       cpu_offload=True,
+                                       lr=1e-3)
     for i, (data, label) in enumerate(train_dataloader):
         if i > 2:
             break
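For the CPU-offload path, the updated test above applies the same pattern with CPUAdam: the class is passed in and the learning rate travels through **defaults. A condensed sketch, reusing the zero_model and shard_strategy set up earlier in that test and the imports already used there:

    from colossalai.nn.optimizer import CPUAdam
    from colossalai.zero.sharded_optim import ShardedOptimizerV2

    sharded_optim = ShardedOptimizerV2(zero_model,          # ShardedModelV2 instance
                                       CPUAdam,             # optimizer class, not an instance
                                       shard_strategy,
                                       initial_scale=2**5,
                                       cpu_offload=True,    # keep optimizer states on CPU, matching CPUAdam
                                       lr=1e-3)             # forwarded to CPUAdam via **defaults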