[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)

HELSON
2022-07-19 14:15:28 +08:00
committed by GitHub
parent f3ce7b8336
commit f92c100ddd
6 changed files with 209 additions and 91 deletions

View File

@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None):
def size_global(self, args: Optional[int] = None) -> torch.Size:
"""override the torch buildin size()
the shape passed in must be in a replicate placement.
Returns:

View File

@@ -141,9 +141,18 @@ class ProcessGroup:
def rank(self):
return self._rank
def ranks_in_group(self):
return self._rank_list
def world_size(self):
return self._world_size
def tp_rank_list(self):
return self._tp_rank_list
def dp_rank_list(self):
return self._dp_rank_list
def tp_local_rank(self):
return self._rank % self._tp_degree
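The accessors added above are what the new `gather_tensor` helper (later in this commit) uses to find the tensor-parallel group that contains rank 0. A minimal sketch of how they relate to the global rank, assuming a 4-process run created as `ProcessGroup(tp_degree=2, dp_degree=2)` (the constructor arguments are an assumption about the existing `ProcessGroup` API, not part of this diff):

import torch.distributed as dist
from colossalai.tensor import ProcessGroup

# assumes torch.distributed has already been initialized with 4 processes
pg = ProcessGroup(tp_degree=2, dp_degree=2)
rank = dist.get_rank()

# tp_local_rank() is this process's position inside its tensor-parallel group
assert pg.tp_local_rank() == rank % 2

# tp_rank_list() / dp_rank_list() hold the global ranks of the groups this process
# belongs to; gather_tensor checks tp_rank_list()[0] == 0 to spot the group with rank 0
print(pg.tp_rank_list(), pg.dp_rank_list())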

View File

@@ -1,8 +1,8 @@
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, DistSpecManager
from colossalai.tensor import ColoTensor
from colossalai.nn.optimizer import ColossalaiOptimizer
from copy import copy
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from typing import Optional
@@ -22,37 +22,52 @@ def save_checkpoint(dire: str,
optimizer (ColossalaiOptimizer, optional): the optimizer to save. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): the lr scheduler to save. Defaults to None.
"""
rank = dist.get_rank()
model_state = model.state_dict()
# save the dist context about the tensors in a new dict, while still maintaining the original dict.
for k, v in model_state.items():
if isinstance(v, ColoTensor):
gather_tensor(v) # gather sharded tensors to rank 0
# don't recover tensors on rank 0, since this dict is only a copy of the model's state
if rank == 0:
# sanity check
for k, v in model_state.items():
if isinstance(v, ColoTensor):
assert v.save_ready
assert v.is_replicate()
delattr(v, 'save_ready')
# model saving
save_state = {'epoch': epoch, 'model': model_state}
torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch))
# delete old dicts
del model_state
# synchronize all the processes
dist.barrier()
mapping = dict()
new_dict = dict()
# save the dist context about the tensors in a new dict, while still maintaining the original dict.
for k, v in model.state_dict().items():
if isinstance(v, ColoTensor):
mapping[k] = (v.dist_spec, v.compute_spec)
new_dict[k] = v.to_replicate().detach()
else:
new_dict[k] = v
if dist.get_rank() == 0:
for k, v in new_dict.items():
if isinstance(v, ColoTensor):
assert v.is_replicate()
model_state = {'epoch': epoch, 'model': new_dict}
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
# delete the new dict
del new_dict
optim_state_copy = copy(optimizer.state_dict())
for k, v in optim_state_copy['state'].items():
optim_state = optimizer.state_dict()
for k, v in optim_state['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
t.to_replicate_()
if dist.get_rank() == 0:
model_state = {'epoch': epoch, 'optim': optim_state_copy}
torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
del optim_state_copy
mapping[(k, n)] = t.dist_spec
gather_tensor(t)
if rank == 0:
save_state = {'epoch': epoch, 'optim': optim_state}
torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
# recover colo tensors on rank 0
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
assert hasattr(t, 'save_ready')
t.set_dist_spec(mapping[(k, n)])
delattr(t, 'save_ready')
del optim_state
del mapping
dist.barrier()
def load_checkpoint(dire,
@@ -72,39 +87,42 @@ def load_checkpoint(dire,
optimizer (ColossalaiOptimizer, optional): the optimizer to restore. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): the lr scheduler to restore. Defaults to None.
"""
rank = dist.get_rank()
mapping = dict()
for n, p in model.named_parameters():
if isinstance(p, ColoTensor):
mapping[n] = p.dist_spec
gather_tensor(p)
if rank == 0:
load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
model.load_state_dict(load_state['model'])
dist.barrier()
# scatter loaded parameters
for n, p in model.named_parameters():
if isinstance(p, ColoTensor):
scatter_tensor(p, mapping[n])
if rank == 0:
assert hasattr(p, 'save_ready')
delattr(p, 'save_ready')
del mapping
mapping = dict()
for k, v in model.state_dict().items():
if isinstance(v, ColoTensor):
mapping[k] = (v.dist_spec, v.compute_spec)
v.to_replicate_()
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = t.dist_spec
gather_tensor(t)
model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
model.load_state_dict(model_state['model'])
if rank == 0:
colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
optimizer.load_state_dict(colo_checkpoint['optim'])
dist.barrier()
# reset tensors to original dist spec.
with DistSpecManager.no_grad():
for k, v in model.state_dict().items():
if isinstance(v, ColoTensor):
v.set_tensor_spec(*mapping[k])
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
scatter_tensor(t, mapping[(k, n)])
del mapping
mapping = dict()
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = (t.dist_spec, t.compute_spec)
t.to_replicate_()
colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
optimizer.load_state_dict(colo_checkpoint['optim'])
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
# skip keys that are not in the mapping.
# For Adam, if step() has not been executed yet, exp_avg and exp_avg_sq will not exist in the optimizer state
if (k, n) not in mapping:
continue
t.set_tensor_spec(*mapping[(k, n)])
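
As a usage note, both functions above are collective: every rank must call them, because the internal `gather_tensor`/`scatter_tensor` calls and `dist.barrier()` need all processes, even though only rank 0 reads or writes the `.pth` files. A hedged sketch of how they might be driven from a training script; the import path and the surrounding `model`/`optimizer` objects are assumptions, not part of this commit:

from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

def checkpoint_round_trip(model, optimizer, epoch, ckpt_dir='./checkpoints'):
    """Save and immediately reload a model/optimizer pair; call on every rank."""
    # only rank 0 writes epoch_{}_model.pth and epoch_{}_optim.pth, but the
    # gather/scatter and dist.barrier() inside require all ranks to participate
    save_checkpoint(ckpt_dir, epoch, model, optimizer=optimizer)
    # parameters and optimizer states are gathered to rank 0, loaded there,
    # then scattered back to their original dist specs on all ranks
    load_checkpoint(ckpt_dir, epoch, model, optimizer=optimizer)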

View File

@@ -0,0 +1,50 @@
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.distspec import _DistSpec
def gather_tensor(colo_tensor: ColoTensor) -> None:
"""Make colo_tensor replicated when the rank is 0
"""
if not colo_tensor.is_replicate():
pg = colo_tensor.get_process_group()
# for the group which contains rank 0
if pg.tp_rank_list()[0] == 0:
old_dist_spec = colo_tensor.dist_spec
colo_tensor.to_replicate_()
if dist.get_rank() != 0:
colo_tensor.set_dist_spec(old_dist_spec)
# synchronize all processes to avoid unexpected problems
dist.barrier()
if dist.get_rank() == 0:
setattr(colo_tensor, 'save_ready', True) # set the saving signature
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
"""Reversal operation of `gather_tensor`.
"""
if dist_spec.placement == 'r':
dist.broadcast(colo_tensor.data, 0)
else:
global_size = colo_tensor.size_global()
if dist.get_rank() == 0:
entire_data = colo_tensor.data
else:
entire_data = torch.empty(global_size, device=colo_tensor.device)
dist.broadcast(entire_data, 0)
if dist.get_rank() == 0:
colo_tensor.set_dist_spec(dist_spec)
else:
rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
pg=colo_tensor.get_process_group(),
compute_attr=colo_tensor.compute_spec))
rep_tensor.set_dist_spec(dist_spec)
with torch.no_grad():
colo_tensor.data.copy_(rep_tensor.data)
# synchronize all processes to avoid unexpected problems
dist.barrier()
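
A minimal sketch of the round trip the updated unit test presumably exercises: shard a `ColoTensor` across the process group, gather it to rank 0, then scatter it back to its original spec. The `ProcessGroup(tp_degree=...)`, `ColoTensorSpec`, and `ShardSpec` usage here is an assumption about the surrounding ColoTensor API; it must run under an initialized `torch.distributed` job:

import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor

def run_gather_scatter_round_trip():
    world_size = dist.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)    # one tensor-parallel group over all ranks (assumed)

    # start replicated, then shard dim 0 across the group (assumed ShardSpec signature)
    t = ColoTensor(torch.randn(4 * world_size, 8), ColoTensorSpec(pg))
    t.set_dist_spec(ShardSpec([0], [world_size]))
    old_spec = t.dist_spec

    gather_tensor(t)    # rank 0 now holds the full replicated tensor and the save_ready flag
    if dist.get_rank() == 0:
        assert t.is_replicate() and t.save_ready

    scatter_tensor(t, old_spec)    # every rank goes back to its original shard
    assert not t.is_replicate()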