Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 17:17:05 +00:00
[hotfix] fix no optimizer in save/load (#1363)
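The fix wraps the optimizer handling in both save_checkpoint and load_checkpoint in an `if optimizer is not None:` guard. Before this change the optimizer branch ran unconditionally, so a caller that wanted a model-only checkpoint (presumably leaving the optimizer argument as None) would fail as soon as the code touched the missing optimizer, roughly like this minimal sketch:

optimizer = None                      # checkpoint the model only, no optimizer passed
optim_state = optimizer.state_dict()  # AttributeError: 'NoneType' object has no attribute 'state_dict'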
@@ -46,28 +46,29 @@ def save_checkpoint(dire: str,
     # synchronize all the processes
     dist.barrier()

-    mapping = dict()
-    optim_state = optimizer.state_dict()
-    for k, v in optim_state['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = t.dist_spec
-                gather_tensor(t)
-
-    if rank == 0:
-        save_state = {'epoch': epoch, 'optim': optim_state}
-        torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
-        # recover colo tensors in rank0
-        for k, v in optimizer.state_dict()['state'].items():
+    if optimizer is not None:
+        mapping = dict()
+        optim_state = optimizer.state_dict()
+        for k, v in optim_state['state'].items():
             for n, t in v.items():
                 if isinstance(t, ColoTensor):
-                    assert hasattr(t, 'save_ready')
-                    t.set_dist_spec(mapping[(k, n)])
-                    delattr(t, 'save_ready')
+                    mapping[(k, n)] = t.dist_spec
+                    gather_tensor(t)

-    del optim_state
-    del mapping
-    dist.barrier()
+        if rank == 0:
+            save_state = {'epoch': epoch, 'optim': optim_state}
+            torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+            # recover colo tensors in rank0
+            for k, v in optimizer.state_dict()['state'].items():
+                for n, t in v.items():
+                    if isinstance(t, ColoTensor):
+                        assert hasattr(t, 'save_ready')
+                        t.set_dist_spec(mapping[(k, n)])
+                        delattr(t, 'save_ready')
+
+        del optim_state
+        del mapping
+        dist.barrier()


 def load_checkpoint(dire,
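In the new save path above, each ColoTensor in the optimizer state is gathered to rank 0 after its dist spec is recorded in `mapping`, rank 0 alone writes `epoch_{}_optim.pth`, the gathered tensors on rank 0 get their original dist spec back, and a barrier keeps the other ranks in step. A stand-alone sketch of that gather-save-barrier pattern, written with plain torch.distributed instead of ColossalAI's ColoTensor/gather_tensor helpers (the name `save_sharded_tensor` and the flat dim-0 sharding are assumptions for illustration only):

import torch
import torch.distributed as dist

def save_sharded_tensor(path: str, shard: torch.Tensor, epoch: int) -> None:
    # Gather every rank's shard so rank 0 holds the full tensor (the role gather_tensor plays above).
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(gathered, shard)
    if rank == 0:
        # Only rank 0 touches the filesystem, like the `if rank == 0:` branch in the hunk.
        full = torch.cat(gathered, dim=0)
        torch.save({'epoch': epoch, 'optim_tensor': full}, path)
    # All ranks wait for the write to finish before continuing.
    dist.barrier()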
@@ -108,21 +109,22 @@ def load_checkpoint(dire,
                 delattr(p, 'save_ready')
     del mapping

-    mapping = dict()
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = t.dist_spec
-                gather_tensor(t)
+    if optimizer is not None:
+        mapping = dict()
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    mapping[(k, n)] = t.dist_spec
+                    gather_tensor(t)

-    if rank == 0:
-        colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
-        optimizer.load_state_dict(colo_checkpoint['optim'])
-    dist.barrier()
+        if rank == 0:
+            colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+            optimizer.load_state_dict(colo_checkpoint['optim'])
+        dist.barrier()

-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                scatter_tensor(t, mapping[(k, n)])
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    scatter_tensor(t, mapping[(k, n)])

-    del mapping
+        del mapping
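The load path mirrors the save path: each rank records the current dist spec and gathers the optimizer-state ColoTensors, rank 0 reads the checkpoint into optimizer.load_state_dict, and scatter_tensor then redistributes the loaded values according to the recorded specs. A comparable stand-alone sketch with plain torch.distributed (again, `load_sharded_tensor`, the broadcast-then-slice step, and the equal-sized flat shards are illustrative assumptions, not the ColossalAI implementation):

import torch
import torch.distributed as dist

def load_sharded_tensor(path: str, shard: torch.Tensor) -> None:
    # Rank 0 reads the full tensor from disk; the other ranks allocate an empty buffer.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    numel = shard.numel()
    full = torch.empty(numel * world_size, dtype=shard.dtype, device=shard.device)
    if rank == 0:
        full.copy_(torch.load(path)['optim_tensor'].reshape(-1))
    # Broadcast the loaded values, then let each rank keep only its own slice,
    # playing the role of scatter_tensor in the hunk above.
    dist.broadcast(full, src=0)
    shard.copy_(full[rank * numel:(rank + 1) * numel].reshape(shard.shape))
    dist.barrier()

With the new `if optimizer is not None:` guard in place, both helpers can also be called for model-only checkpoints; the entire optimizer section is simply skipped.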