Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-17 15:11:20 +00:00
[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)
@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
         replicated_t = self.redistribute(dist_spec=ReplicaSpec())
         return replicated_t.view(*args)

-    def size_global(self, args: Optional[int] = None):
+    def size_global(self, args: Optional[int] = None) -> torch.Size:
         """override the torch builtin size()
         the shape passed in must be in a replicate placement.
         Returns:
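For illustration (a hypothetical sketch, not part of the commit): size_global() reports the full, unsharded shape, which is why the new annotation is torch.Size. A ColoTensor whose 4x4 payload is row-sharded across two ranks would behave roughly like:

# assumed setup: 2-rank group, shard spec on dim 0, global shape 4 x 4
# t.size()          -> torch.Size([2, 4])   local shard view on each rank
# t.size_global()   -> torch.Size([4, 4])   redistributed to replicate first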
@@ -141,9 +141,18 @@ class ProcessGroup:
     def rank(self):
         return self._rank

     def ranks_in_group(self):
         return self._rank_list

     def world_size(self):
         return self._world_size

+    def tp_rank_list(self):
+        return self._tp_rank_list
+
+    def dp_rank_list(self):
+        return self._dp_rank_list
+
+    def tp_local_rank(self):
+        return self._rank % self._tp_degree
+
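A minimal, runnable sketch (plain Python, assumed layout) of the arithmetic behind the new tp_local_rank() helper: with a world size of 4 and tp_degree = 2, a rank's position inside its tensor-parallel group is rank % tp_degree:

# assumed layout: world size 4, tp_degree 2, dp_degree 2
tp_degree = 2
for rank in range(4):
    print(rank, '->', rank % tp_degree)   # 0->0, 1->1, 2->0, 3->1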
@@ -1,8 +1,8 @@
 import torch
 import torch.distributed as dist
-from colossalai.tensor import ColoTensor, DistSpecManager
+from colossalai.tensor import ColoTensor
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from copy import copy
+from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
 from typing import Optional

@@ -22,37 +22,52 @@ def save_checkpoint(dire: str,
         optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
+    rank = dist.get_rank()
+    model_state = model.state_dict()
+    # save the dist context about the tensors in a new dict, while still maintaining the original dict.
+    for k, v in model_state.items():
+        if isinstance(v, ColoTensor):
+            gather_tensor(v)    # gather shared tensors to rank0
+            # don't recover tensors in rank0, since the dict is only a copy of model
+
+    if rank == 0:
+        # sanity check
+        for k, v in model_state.items():
+            if isinstance(v, ColoTensor):
+                assert v.save_ready
+                assert v.is_replicate()
+                delattr(v, 'save_ready')
+        # model saving
+        save_state = {'epoch': epoch, 'model': model_state}
+        torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch))
+
+    # delete old dicts
+    del model_state
+    # synchronize all the processes
+    dist.barrier()
+
     mapping = dict()
-    new_dict = dict()
-
-    # save the dist context about the tensors in a new dict, while still maintaining the original dict.
-    for k, v in model.state_dict().items():
-        if isinstance(v, ColoTensor):
-            mapping[k] = (v.dist_spec, v.compute_spec)
-            new_dict[k] = v.to_replicate().detach()
-        else:
-            new_dict[k] = v
-    if dist.get_rank() == 0:
-        for k, v in new_dict.items():
-            if isinstance(v, ColoTensor):
-                assert v.is_replicate()
-
-        model_state = {'epoch': epoch, 'model': new_dict}
-        torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
-
-    # delete the new dict
-    del new_dict

-    optim_state_copy = copy(optimizer.state_dict())
-    for k, v in optim_state_copy['state'].items():
+    optim_state = optimizer.state_dict()
+    for k, v in optim_state['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
-                t.to_replicate_()
-    if dist.get_rank() == 0:
-        model_state = {'epoch': epoch, 'optim': optim_state_copy}
-        torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
-    del optim_state_copy
+                mapping[(k, n)] = t.dist_spec
+                gather_tensor(t)
+
+    if rank == 0:
+        save_state = {'epoch': epoch, 'optim': optim_state}
+        torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+        # recover colo tensors in rank0
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    assert hasattr(t, 'save_ready')
+                    t.set_dist_spec(mapping[(k, n)])
+                    delattr(t, 'save_ready')
+
+    del optim_state
+    del mapping
+    dist.barrier()
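The new save path above follows the usual rank-0-writes pattern: every rank joins the collective gather, only rank 0 serializes, and a barrier keeps the other ranks from racing ahead while the file is written. A generic sketch of that pattern in plain torch.distributed (an illustration under assumed setup, not code from this commit):

import torch
import torch.distributed as dist

def save_on_rank0(shard: torch.Tensor, path: str) -> None:
    # every rank participates in the collective gather ...
    gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, shard)
    # ... but only rank 0 touches the disk
    if dist.get_rank() == 0:
        torch.save(torch.cat(gathered), path)
    # keep all ranks in sync until the file is fully written
    dist.barrier()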
@@ -72,39 +87,42 @@ def load_checkpoint(dire,
         optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
     """
+    rank = dist.get_rank()
     mapping = dict()
+    for n, p in model.named_parameters():
+        if isinstance(p, ColoTensor):
+            mapping[n] = p.dist_spec
+            gather_tensor(p)
+
+    if rank == 0:
+        load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
+        model.load_state_dict(load_state['model'])
+    dist.barrier()
+
+    # scatter loaded parameters
+    for n, p in model.named_parameters():
+        if isinstance(p, ColoTensor):
+            scatter_tensor(p, mapping[n])
+            if rank == 0:
+                assert hasattr(p, 'save_ready')
+                delattr(p, 'save_ready')
+    del mapping
+
+    mapping = dict()
-    for k, v in model.state_dict().items():
-        if isinstance(v, ColoTensor):
-            mapping[k] = (v.dist_spec, v.compute_spec)
-            v.to_replicate_()
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                mapping[(k, n)] = t.dist_spec
+                gather_tensor(t)

-    model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
-    model.load_state_dict(model_state['model'])
+    if rank == 0:
+        colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+        optimizer.load_state_dict(colo_checkpoint['optim'])
+    dist.barrier()

-    # reset tensors to original dist spec.
-    with DistSpecManager.no_grad():
-        for k, v in model.state_dict().items():
-            if isinstance(v, ColoTensor):
-                v.set_tensor_spec(*mapping[k])
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                scatter_tensor(t, mapping[(k, n)])

-    del mapping
-    mapping = dict()
-
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = (t.dist_spec, t.compute_spec)
-                t.to_replicate_()
-
-    colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
-    optimizer.load_state_dict(colo_checkpoint['optim'])
-
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                # skip key not in mapping.
-                # For Adam, if it does not execute step() once, there will be no exp_avg and exp_avg_sq in optimizer
-                if (k, n) not in mapping:
-                    continue
-                t.set_tensor_spec(*mapping[(k, n)])
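Taken together, save_checkpoint and load_checkpoint form a round trip in which every rank calls both entry points and only rank 0 performs file I/O. A hedged usage sketch (the import path, directory name, and epoch value are assumptions for illustration; the model is expected to be built under ColoInitContext with a ColossalaiOptimizer, per the docstrings above):

# hypothetical driver, executed identically on every rank
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

save_checkpoint('./ckpt', epoch=0, model=model, optimizer=optimizer)
# ... later, in a run with the same parallel layout ...
load_checkpoint('./ckpt', epoch=0, model=model, optimizer=optimizer)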
colossalai/utils/checkpoint/utils.py (new file, 50 additions)
@@ -0,0 +1,50 @@
+import torch
+import torch.distributed as dist
+from colossalai.tensor import ColoTensor, ColoTensorSpec
+from colossalai.tensor.distspec import _DistSpec
+
+
+def gather_tensor(colo_tensor: ColoTensor) -> None:
+    """Make colo_tensor replicated when the rank is 0
+    """
+    if not colo_tensor.is_replicate():
+        pg = colo_tensor.get_process_group()
+        # for the group which contains rank 0
+        if pg.tp_rank_list()[0] == 0:
+            old_dist_spec = colo_tensor.dist_spec
+            colo_tensor.to_replicate_()
+            if dist.get_rank() != 0:
+                colo_tensor.set_dist_spec(old_dist_spec)
+
+        # synchronize all processes for unexpected problems
+        dist.barrier()
+
+    if dist.get_rank() == 0:
+        setattr(colo_tensor, 'save_ready', True)    # set saving signature
+
+
+def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
+    """Reversal operation of `gather_tensor`.
+    """
+    if dist_spec.placement == 'r':
+        dist.broadcast(colo_tensor.data, 0)
+    else:
+        global_size = colo_tensor.size_global()
+
+        if dist.get_rank() == 0:
+            entire_data = colo_tensor.data
+        else:
+            entire_data = torch.empty(global_size, device=colo_tensor.device)
+        dist.broadcast(entire_data, 0)
+
+        if dist.get_rank() == 0:
+            colo_tensor.set_dist_spec(dist_spec)
+        else:
+            rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
+                pg=colo_tensor.get_process_group(),
+                compute_attr=colo_tensor.compute_spec))
+            rep_tensor.set_dist_spec(dist_spec)
+            with torch.no_grad():
+                colo_tensor.data.copy_(rep_tensor.data)
+    # synchronize all processes for unexpected problems
+    dist.barrier()
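gather_tensor and scatter_tensor are meant to bracket rank-0 I/O: the former makes the tensor replicated on rank 0 and tags it with a save_ready attribute, the latter restores the original distribution. A minimal calling-convention sketch (hypothetical setup; both helpers perform collective communication, so every rank must execute them):

# t is a sharded ColoTensor; run identically on every rank
old_spec = t.dist_spec            # remember the original distribution
gather_tensor(t)                  # collective: replicated on rank 0 afterwards
if dist.get_rank() == 0:
    torch.save(t.detach().cpu(), 'tensor.pth')   # rank-0-only I/O (illustration)
    delattr(t, 'save_ready')      # clear the flag set by gather_tensor
scatter_tensor(t, old_spec)       # collective: restore the original placement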