Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-09 11:58:06 +00:00
[hotfix] fix no optimizer in save/load (#1363)
This commit is contained in:
parent cd063ac37f
commit 943a96323e
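In short, the hotfix wraps all optimizer handling in save_checkpoint and load_checkpoint in an "if optimizer is not None:" guard, so both helpers can now be used for model-only checkpoints instead of failing on optimizer.state_dict(). A minimal sketch of the call pattern this enables; everything except the dire, epoch and optimizer arguments is an assumption, since the full signatures are not shown in this diff:

    # Conceptual sketch; imports and ColossalAI launch/setup are elided.
    # The `model` argument and the argument order are assumptions for illustration.
    save_checkpoint('./ckpt', epoch=0, model=model, optimizer=None)  # optimizer state is now skipped
    load_checkpoint('./ckpt', epoch=0, model=model, optimizer=None)  # optimizer restore is now skipped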
@@ -13,6 +13,6 @@ from . import distspec
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
-    'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
+    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+    'ShardSpec', 'ReplicaSpec'
 ]
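The hunk above trims the public exports: ChunkManager, TensorState and TensorSpec are dropped from __all__. For reference, a sketch of importing the names that remain exported, assuming this __all__ belongs to the colossalai.tensor package as the symbols suggest (the file path is not shown in this diff):

    # Assumed package path; adjust if the __all__ above lives elsewhere.
    from colossalai.tensor import (ColoTensor, ColoParameter, convert_parameter,
                                   ComputePattern, ComputeSpec, DistSpecManager,
                                   ProcessGroup, ColoTensorSpec, ShardSpec, ReplicaSpec)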
@@ -46,28 +46,29 @@ def save_checkpoint(dire: str,
     # synchronize all the processes
     dist.barrier()
 
-    mapping = dict()
-    optim_state = optimizer.state_dict()
-    for k, v in optim_state['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = t.dist_spec
-                gather_tensor(t)
-
-    if rank == 0:
-        save_state = {'epoch': epoch, 'optim': optim_state}
-        torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
-        # recover colo tensors in rank0
-        for k, v in optimizer.state_dict()['state'].items():
-            for n, t in v.items():
-                if isinstance(t, ColoTensor):
-                    assert hasattr(t, 'save_ready')
-                    t.set_dist_spec(mapping[(k, n)])
-                    delattr(t, 'save_ready')
-
-    del optim_state
-    del mapping
-    dist.barrier()
+    if optimizer is not None:
+        mapping = dict()
+        optim_state = optimizer.state_dict()
+        for k, v in optim_state['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    mapping[(k, n)] = t.dist_spec
+                    gather_tensor(t)
+
+        if rank == 0:
+            save_state = {'epoch': epoch, 'optim': optim_state}
+            torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+            # recover colo tensors in rank0
+            for k, v in optimizer.state_dict()['state'].items():
+                for n, t in v.items():
+                    if isinstance(t, ColoTensor):
+                        assert hasattr(t, 'save_ready')
+                        t.set_dist_spec(mapping[(k, n)])
+                        delattr(t, 'save_ready')
+
+        del optim_state
+        del mapping
+        dist.barrier()
 
 
 def load_checkpoint(dire,
@@ -108,21 +109,22 @@ def load_checkpoint(dire,
             delattr(p, 'save_ready')
     del mapping
 
-    mapping = dict()
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = t.dist_spec
-                gather_tensor(t)
-
-    if rank == 0:
-        colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
-        optimizer.load_state_dict(colo_checkpoint['optim'])
-    dist.barrier()
-
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                scatter_tensor(t, mapping[(k, n)])
-
-    del mapping
+    if optimizer is not None:
+        mapping = dict()
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    mapping[(k, n)] = t.dist_spec
+                    gather_tensor(t)
+
+        if rank == 0:
+            colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+            optimizer.load_state_dict(colo_checkpoint['optim'])
+        dist.barrier()
+
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    scatter_tensor(t, mapping[(k, n)])
+
+        del mapping
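As a rough way to exercise the new no-optimizer path, the sketch below initializes a single-process gloo group (the helpers call dist.barrier(), so a default process group must exist) and round-trips a model-only checkpoint. This is a hypothetical smoke test, not a test from the repository; in practice these helpers are expected to run under a proper ColossalAI launch, and the import path and argument order beyond dire are assumptions:

    # Hypothetical smoke test for the optimizer=None path (not part of this commit).
    import os
    import torch
    import torch.distributed as dist
    from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint  # assumed import path

    def main():
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '29500')
        dist.init_process_group('gloo', rank=0, world_size=1)
        os.makedirs('./ckpt', exist_ok=True)

        model = torch.nn.Linear(4, 4)
        save_checkpoint('./ckpt', 0, model, optimizer=None)  # should no longer raise
        load_checkpoint('./ckpt', 0, model, optimizer=None)  # should no longer raise

        dist.destroy_process_group()

    if __name__ == '__main__':
        main()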