diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index 8161076de..3ad15a436 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -13,6 +13,6 @@ from . import distspec
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
-    'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
+    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+    'ShardSpec', 'ReplicaSpec'
 ]
diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index 8e9654e6f..c59d1ecf3 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -46,28 +46,29 @@ def save_checkpoint(dire: str,
     # synchronize all the processes
     dist.barrier()
 
-    mapping = dict()
-    optim_state = optimizer.state_dict()
-    for k, v in optim_state['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = t.dist_spec
-                gather_tensor(t)
-
-    if rank == 0:
-        save_state = {'epoch': epoch, 'optim': optim_state}
-        torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
-        # recover colo tensors in rank0
-        for k, v in optimizer.state_dict()['state'].items():
+    if optimizer is not None:
+        mapping = dict()
+        optim_state = optimizer.state_dict()
+        for k, v in optim_state['state'].items():
             for n, t in v.items():
                 if isinstance(t, ColoTensor):
-                    assert hasattr(t, 'save_ready')
-                    t.set_dist_spec(mapping[(k, n)])
-                    delattr(t, 'save_ready')
+                    mapping[(k, n)] = t.dist_spec
+                    gather_tensor(t)
 
-    del optim_state
-    del mapping
-    dist.barrier()
+        if rank == 0:
+            save_state = {'epoch': epoch, 'optim': optim_state}
+            torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+            # recover colo tensors in rank0
+            for k, v in optimizer.state_dict()['state'].items():
+                for n, t in v.items():
+                    if isinstance(t, ColoTensor):
+                        assert hasattr(t, 'save_ready')
+                        t.set_dist_spec(mapping[(k, n)])
+                        delattr(t, 'save_ready')
+
+        del optim_state
+        del mapping
+        dist.barrier()
 
 
 def load_checkpoint(dire,
@@ -108,21 +109,22 @@ def load_checkpoint(dire,
                 delattr(p, 'save_ready')
     del mapping
 
-    mapping = dict()
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = t.dist_spec
-                gather_tensor(t)
+    if optimizer is not None:
+        mapping = dict()
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    mapping[(k, n)] = t.dist_spec
+                    gather_tensor(t)
 
-    if rank == 0:
-        colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
-        optimizer.load_state_dict(colo_checkpoint['optim'])
-    dist.barrier()
+        if rank == 0:
+            colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+            optimizer.load_state_dict(colo_checkpoint['optim'])
+        dist.barrier()
 
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                scatter_tensor(t, mapping[(k, n)])
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    scatter_tensor(t, mapping[(k, n)])
 
-    del mapping
+        del mapping
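The hunks above make the optimizer handling in `save_checkpoint`/`load_checkpoint` optional: the gather/save/scatter path for optimizer state now runs only when an optimizer is actually passed. Below is a minimal caller-side sketch of the new behaviour; the diff only shows the first parameter of each function, so the import path and the remaining argument names are assumptions for illustration, not the confirmed signatures.

```python
# Hypothetical usage sketch -- the import path and every argument after `dire`
# are assumed from the file location and the hunks above, not confirmed APIs.
import torch
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

model = torch.nn.Linear(8, 8)                      # placeholder model
optimizer = torch.optim.Adam(model.parameters())   # placeholder optimizer
ckpt_dir = './checkpoints'

# With optimizer=None the whole optimizer-state block is now skipped,
# so only the model checkpoint is written.
save_checkpoint(ckpt_dir, 0, model, optimizer=None)

# With an optimizer, rank 0 still writes epoch_{n}_optim.pth as before.
save_checkpoint(ckpt_dir, 1, model, optimizer=optimizer)

# Loading mirrors saving: optimizer=None skips reading epoch_{n}_optim.pth
# and the gather/scatter round-trip on the optimizer state dict.
load_checkpoint(ckpt_dir, 1, model, optimizer=optimizer)
```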