[hotfix] fix no optimizer in save/load (#1363)

Author: HELSON
Date: 2022-07-26 10:53:53 +08:00
Committed by: GitHub
Parent: cd063ac37f
Commit: 943a96323e

2 changed files with 38 additions and 36 deletions


@@ -46,28 +46,29 @@ def save_checkpoint(dire: str,
     # synchronize all the processes
     dist.barrier()
-    mapping = dict()
-    optim_state = optimizer.state_dict()
-    for k, v in optim_state['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = t.dist_spec
-                gather_tensor(t)
-    if rank == 0:
-        save_state = {'epoch': epoch, 'optim': optim_state}
-        torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
-    # recover colo tensors in rank0
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                assert hasattr(t, 'save_ready')
-                t.set_dist_spec(mapping[(k, n)])
-                delattr(t, 'save_ready')
-    del optim_state
-    del mapping
+    if optimizer is not None:
+        mapping = dict()
+        optim_state = optimizer.state_dict()
+        for k, v in optim_state['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    mapping[(k, n)] = t.dist_spec
+                    gather_tensor(t)
+        if rank == 0:
+            save_state = {'epoch': epoch, 'optim': optim_state}
+            torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+        # recover colo tensors in rank0
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    assert hasattr(t, 'save_ready')
+                    t.set_dist_spec(mapping[(k, n)])
+                    delattr(t, 'save_ready')
+        del optim_state
+        del mapping
     dist.barrier()

 def load_checkpoint(dire,
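The save-side change is a single guard: the gather/save/recover block now runs only when an optimizer was actually passed, so saving a model-only checkpoint no longer crashes on `optimizer.state_dict()`. A minimal, self-contained sketch of the guarded pattern with a plain PyTorch optimizer (`save_optim_checkpoint` is a hypothetical stand-in; the ColoTensor gather/dist-spec machinery is deliberately omitted):

    import torch

    def save_optim_checkpoint(dire, epoch, optimizer, rank=0):
        # Guard added by this commit: with optimizer=None, the function
        # simply skips the optimizer-state section instead of raising.
        if optimizer is None:
            return
        optim_state = optimizer.state_dict()
        if rank == 0:
            torch.save({'epoch': epoch, 'optim': optim_state},
                       '{}/epoch_{}_optim.pth'.format(dire, epoch))

    model = torch.nn.Linear(4, 4)
    save_optim_checkpoint('.', 0, None)                                  # no-op now
    save_optim_checkpoint('.', 0, torch.optim.Adam(model.parameters()))  # saves state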
@@ -108,21 +109,22 @@ def load_checkpoint(dire,
             delattr(p, 'save_ready')
     del mapping
-    mapping = dict()
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = t.dist_spec
-                gather_tensor(t)
-    if rank == 0:
-        colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
-        optimizer.load_state_dict(colo_checkpoint['optim'])
-    dist.barrier()
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                scatter_tensor(t, mapping[(k, n)])
-    del mapping
+    if optimizer is not None:
+        mapping = dict()
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    mapping[(k, n)] = t.dist_spec
+                    gather_tensor(t)
+        if rank == 0:
+            colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+            optimizer.load_state_dict(colo_checkpoint['optim'])
+        dist.barrier()
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    scatter_tensor(t, mapping[(k, n)])
+        del mapping
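load_checkpoint gets the mirror-image guard: recording each ColoTensor's dist spec, gathering so rank 0 can `torch.load` the checkpoint and call `load_state_dict`, then scattering the state back is all skipped when no optimizer is supplied. A self-contained sketch of the guarded load path (again with a hypothetical `load_optim_checkpoint` helper and no sharding machinery):

    import torch

    def load_optim_checkpoint(dire, epoch, optimizer, rank=0):
        # Same guard as on the save side: optimizer=None means model-only resume.
        if optimizer is None:
            return
        checkpoint = torch.load('{}/epoch_{}_optim.pth'.format(dire, epoch))
        optimizer.load_state_dict(checkpoint['optim'])

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.Adam(model.parameters())
    model(torch.randn(2, 4)).sum().backward()
    opt.step()                              # populate optimizer state to save
    torch.save({'epoch': 0, 'optim': opt.state_dict()}, './epoch_0_optim.pth')

    load_optim_checkpoint('.', 0, None)     # evaluation-only resume: now a no-op
    load_optim_checkpoint('.', 0, opt)      # full resume: restores optimizer state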