From 20da6e48c8ebd7e17f68308b0bc45e977ff80471 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 8 Jul 2022 16:33:13 +0800
Subject: [PATCH] [checkpoint] save sharded optimizer states (#1237)

---
 colossalai/tensor/process_group.py         | 21 +++++++++----------
 .../utils/checkpoint/module_checkpoint.py  |  7 ++++++-
 tests/test_utils/test_colo_checkpoint.py   | 19 ++++++++++-------
 3 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index 1482f02db..9a413ce33 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -93,20 +93,17 @@ class ProcessGroup:
             if idx // self._tp_degree == self._rank_idx // self._tp_degree:
                 self._tp_rank_list.append(rank_id)
 
-        self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
-        self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
-        self._has_cpu_groups = False
-        self._cpu_dp_process_group = None
-        self._cpu_tp_process_group = None
+        PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
+        PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
 
     def set_cpu_groups(self):
         if self.has_cpu_groups:
             return
         self.logger.info(
             f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-        self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
-        self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
+        PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+        PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
 
     @property
     def has_cpu_groups(self):
@@ -152,13 +149,15 @@ class ProcessGroup:
         return len(self._tp_rank_list)
 
     def dp_process_group(self):
-        return self._dp_process_group
+        # return self._dp_process_group
+        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
 
     def tp_process_group(self):
-        return self._tp_process_group
+        # return self._tp_process_group
+        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
 
     def cpu_dp_process_group(self):
-        return self._cpu_dp_process_group
+        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
 
     def cpu_tp_process_group(self):
-        return self._cpu_tp_process_group
+        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index c622edc99..564ccf4b8 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
         optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
-    model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
+    model_state = {'epoch': epoch, 'model': model.state_dict()}
     if dist.get_rank() == 0:
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
+
+    # TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
+    # 1. convert SHARD ColoTensor to REPLICATE
+    # only rank 0 saves the REPLICATE tensors.
     optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
+
     torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index 3aaec746a..d83db1fbc 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -126,6 +126,9 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
         model_reload = ColoDDP(model_reload, pg)
         model_ref = ColoDDP(model_ref, pg)
 
+    init_spec_func(model, pg)
+    init_spec_func(model_ref, pg)
+
     criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
     optimizer_reload = torch.optim.Adam(model_reload.parameters(),
@@ -135,23 +138,21 @@
                                         weight_decay=0)
     optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
 
+    lr_scheduler = None
     if test_scheduler == 'colossalai_cosine_warmup':
         lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
         lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
                                                       total_steps=num_epoch,
                                                       warmup_steps=warmup_epoch)
-
     elif test_scheduler == 'torch_cosine':
         lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
         lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
-
     elif test_scheduler == 'torch_lambda':
         lr_lambda = lambda epoch: 0.95
         lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
         lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
-
-    init_spec_func(model, pg)
-    init_spec_func(model_ref, pg)
+    else:
+        raise TypeError(f"{test_scheduler} is invalid")
 
     for epoch in range(0, num_epoch):
         if epoch <= test_epoch:
@@ -212,7 +213,11 @@
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, test_scheduler, pg)
+    run_checkpoint(init_1d_row_for_linear_weight_spec,
+                   use_ddp,
+                   test_epoch=test_epoch,
+                   test_scheduler=test_scheduler,
+                   pg=pg)
 
 
 @pytest.mark.skip
@@ -236,4 +241,4 @@ def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
 
 
 if __name__ == '__main__':
-    test_checkpoint(4, True, 1, 1)
+    test_checkpoint(4, True, 1, "colossalai_cosine_warmup")
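
Note: the TODO added in save_checkpoint describes the follow-up plan for tensor parallelism: gather SHARD ColoTensors in the optimizer state into REPLICATE form so that only rank 0 has to write them. The sketch below illustrates that idea with plain torch.distributed collectives rather than ColossalAI's ColoTensor API; the helper names gather_tp_shard and save_replicated_optim_state, the tp_group handle, and the assumption that shards are split along the last dimension are illustrative, not code from this PR.

import torch
import torch.distributed as dist


def gather_tp_shard(shard: torch.Tensor, tp_group, tp_world_size: int, dim: int = -1) -> torch.Tensor:
    """All-gather a tensor-parallel shard so every rank ends up with the full (replicated) tensor."""
    buffers = [torch.empty_like(shard) for _ in range(tp_world_size)]
    dist.all_gather(buffers, shard.contiguous(), group=tp_group)
    return torch.cat(buffers, dim=dim)


def save_replicated_optim_state(optimizer: torch.optim.Optimizer, path: str, tp_group, tp_world_size: int):
    """Replicate sharded per-parameter optimizer states (e.g. Adam's exp_avg), then save from rank 0 only."""
    optim_state = optimizer.state_dict()
    for param_state in optim_state['state'].values():
        for key, value in param_state.items():
            # Scalar entries such as Adam's `step` counter are left untouched.
            if torch.is_tensor(value) and value.dim() > 0:
                param_state[key] = gather_tp_shard(value, tp_group, tp_world_size)
    if dist.get_rank() == 0:
        torch.save(optim_state, path)

With this approach a single optimizer checkpoint file is written per run instead of one file per rank, at the cost of an all-gather over the tensor-parallel group before saving.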