From 9e1daa63d2b539a1037246c0723668602916c23f Mon Sep 17 00:00:00 2001
From: ver217
Date: Fri, 24 Jun 2022 18:05:16 +0800
Subject: [PATCH] [zero] sharded optim supports loading local state dict
 (#1170)

* sharded optim supports loading local state dict

* polish code

* add unit test
---
 .../zero/sharded_optim/sharded_optim_v2.py    | 14 ++-
 .../test_sharded_optim_state_dict.py          | 93 +++++++++++++++++++
 2 files changed, 106 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_zero/test_sharded_optim_state_dict.py

diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 95ab70708..63545f11e 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -199,7 +199,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self._logger.debug(
                 f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
                 ranks=[0])
-
         ret = self.optim.step(*args, **kwargs)
 
         if self._verbose:
@@ -289,6 +288,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
                     p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem
+                    state = self.optim.state[p]
+                    for k, v in state.items():
+                        if isinstance(v, Tensor):
+                            state[k] = v.cuda()
 
     def _prepare_grads(self):
         for group in self.optim.param_groups:
@@ -353,3 +356,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
            self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
 
        self.master_params[p].trans_state(TensorState.HOLD)
+
+    def load_state_dict(self, *args, **kwargs):
+        super().load_state_dict(*args, **kwargs)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                state = self.optim.state[p]
+                for k, v in state.items():
+                    if isinstance(v, Tensor):
+                        state[k] = v.to(dtype=self.master_params[p].dtype, device=self.master_params[p].device)
diff --git a/tests/test_zero/test_sharded_optim_state_dict.py b/tests/test_zero/test_sharded_optim_state_dict.py
new file mode 100644
index 000000000..859a6ae41
--- /dev/null
+++ b/tests/test_zero/test_sharded_optim_state_dict.py
@@ -0,0 +1,93 @@
+import pytest
+import colossalai
+import torch
+from colossalai.context.parallel_mode import ParallelMode
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from functools import partial
+from tests.test_tensor._utils import set_seed
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.testing import parameterize
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_optim import ShardedOptimizerV2
+
+
+def init_zero(model_builder, placement_policy):
+    device = get_current_device() if placement_policy == 'cuda' else torch.device('cpu')
+    shard_strategy = TensorShardStrategy()
+    with ZeroInitContext(target_device=device, shard_strategy=shard_strategy, shard_param=True):
+        model = model_builder()
+    model = ShardedModelV2(
+        model,
+        shard_strategy,
+        tensor_placement_policy=placement_policy,
+        reuse_fp16_shard=True,
+    )
+    optim = HybridAdam(model.parameters(), lr=1e-3)
+    optim = ShardedOptimizerV2(model, optim, initial_scale=32)
+    return model, optim
+
+
+def run_step(model, optim, criterion, data, label):
+    optim.zero_grad()
+    logits = model(data)
+    loss = criterion(logits, label)
+    optim.backward(loss)
+    optim.step()
+
+
+def check_state_dict_eq(state_dict, other):
+    for p, state in state_dict['state'].items():
+        other_state = other['state'][p]
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}'
+            else:
+                assert v == other_state[k]
+
+
+@parameterize('placement_policy', ['cuda', 'cpu'])
+def run_nested_model(placement_policy):
+    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    set_seed(42)
+    model, optim = init_zero(model_builder, placement_policy)
+    set_seed(42)
+    model_copy, optim_copy = init_zero(model_builder, placement_policy)
+
+    model.train()
+    model_copy.train()
+    set_seed(gpc.get_local_rank(ParallelMode.DATA))
+    data_iter = iter(train_dataloader)
+
+    data, label = map(lambda x: x.cuda(), next(data_iter))
+    run_step(model, optim, criterion, data, label)
+    optim_copy.load_state_dict(optim.state_dict())
+    check_state_dict_eq(optim.state_dict(), optim_copy.state_dict())
+
+    data, label = map(lambda x: x.cuda(), next(data_iter))
+    run_step(model_copy, optim_copy, criterion, data, label)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_nested_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_sharded_optim_state_dist(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_sharded_optim_state_dist(2)
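
For context, a minimal sketch (not part of the patch) of the local save/reload flow that the new `load_state_dict` override is meant to support. It assumes `optim` and `new_optim` are `ShardedOptimizerV2` instances built as in `init_zero` above on an already-launched process group; the file naming is illustrative only.

```python
import torch
import torch.distributed as dist

# Each rank holds only its own shard of the optimizer state, so the
# state dict produced here is a *local* (per-rank) one.
rank = dist.get_rank()
torch.save(optim.state_dict(), f'sharded_optim_rank{rank}.pt')

# Later, on the same rank and with an identically constructed model/optimizer:
local_state = torch.load(f'sharded_optim_rank{rank}.pt')
new_optim.load_state_dict(local_state)
# The override added in this patch then moves every loaded state tensor to the
# dtype and device of the corresponding master parameter shard.
```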