mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[tensor] distributed checkpointing for parameters (#1240)
This commit is contained in:
@@ -1,19 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import collections
|
||||
import inspect
|
||||
from colossalai.utils.model.colo_init_context import colo_state_dict
|
||||
|
||||
|
||||
def filter_dict(dict_to_filter, thing_with_kwargs):
|
||||
sig = inspect.signature(thing_with_kwargs)
|
||||
filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
|
||||
filter_dict = {}
|
||||
for filter_key in filter_keys:
|
||||
if filter_key in dict_to_filter:
|
||||
filter_dict[filter_key] = dict_to_filter[filter_key]
|
||||
return filter_dict
|
||||
from colossalai.tensor import ColoTensor, DistSpecManager
|
||||
|
||||
|
||||
def save_checkpoint(dire: str,
|
||||
@@ -32,21 +19,30 @@ def save_checkpoint(dire: str,
|
||||
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
|
||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
|
||||
"""
|
||||
model_state = {'epoch': epoch, 'model': model.state_dict()}
|
||||
|
||||
mapping = dict()
|
||||
new_dict = dict()
|
||||
|
||||
# save the dist context about the tensors in a new dict, while still maintain the original dict.
|
||||
for k, v in model.state_dict().items():
|
||||
if isinstance(v, ColoTensor):
|
||||
mapping[k] = (v.dist_spec, v.compute_spec)
|
||||
new_dict[k] = v.to_replicate().detach()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
for k, v in new_dict.items():
|
||||
if isinstance(v, ColoTensor):
|
||||
assert v.is_replicate()
|
||||
|
||||
model_state = {'epoch': epoch, 'model': new_dict}
|
||||
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
|
||||
|
||||
# TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
|
||||
# 1. convert SHARD ColoTensor to REPLICATE
|
||||
# only rank 0 saves the REPLICATE tensors.
|
||||
optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
|
||||
|
||||
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
|
||||
# delete the new dict
|
||||
del new_dict
|
||||
|
||||
|
||||
def load_checkpoint(dire,
|
||||
epoch: int,
|
||||
rank: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
@@ -62,19 +58,18 @@ def load_checkpoint(dire,
|
||||
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
|
||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
|
||||
"""
|
||||
|
||||
mapping = dict()
|
||||
for k, v in model.named_parameters():
|
||||
if isinstance(v, ColoTensor):
|
||||
mapping[k] = (v.dist_spec, v.compute_spec)
|
||||
v.to_replicate_()
|
||||
|
||||
model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
|
||||
model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()])
|
||||
model.load_state_dict(model_state['model'])
|
||||
optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
|
||||
optimizer.load_state_dict(optim_state['optimizer'])
|
||||
lr_scheduler_dict = optim_state['lr_scheduler']
|
||||
if 'after_scheduler_type' in lr_scheduler_dict:
|
||||
after_scheduler_type = lr_scheduler_dict.pop('after_scheduler_type')
|
||||
after_scheduler_dict = lr_scheduler_dict.pop('after_scheduler_dict')
|
||||
reload_scheduler = getattr(torch.optim.lr_scheduler, after_scheduler_type)
|
||||
filtered_dict = filter_dict(after_scheduler_dict, reload_scheduler)
|
||||
lr_scheduler_dict['after_scheduler'] = reload_scheduler(
|
||||
optimizer,
|
||||
**filtered_dict,
|
||||
)
|
||||
lr_scheduler.load_state_dict(lr_scheduler_dict)
|
||||
|
||||
# reset tensors to original dist spec.
|
||||
with DistSpecManager.no_grad():
|
||||
for k, v in model.named_parameters():
|
||||
if isinstance(v, ColoTensor):
|
||||
v.set_tensor_spec(*mapping[k])
|
||||
|
@@ -1,13 +1,10 @@
|
||||
from .utils import InsertPostInitMethodToModuleSubClasses
|
||||
import torch
|
||||
from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec
|
||||
|
||||
from colossalai.tensor import ColoTensor, ColoParameter
|
||||
from colossalai.nn.parallel.layers import register_colo_module, \
|
||||
ColoLinear, ColoEmbedding
|
||||
from copy import copy
|
||||
from torch import nn
|
||||
from typing import Iterator, Tuple, Union
|
||||
from functools import partialmethod
|
||||
# find named_params includes replica
|
||||
|
||||
|
||||
@@ -34,47 +31,6 @@ def ColoModulize(module):
|
||||
module._colo_visited = True
|
||||
|
||||
|
||||
def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None):
|
||||
# build param to spec mapping
|
||||
mapping1 = dict()
|
||||
mapping2 = dict()
|
||||
mapping3 = dict()
|
||||
# gather all params
|
||||
has_dist_parameter = False
|
||||
with torch.no_grad():
|
||||
for param in self.parameters():
|
||||
if isinstance(param, ColoParameter):
|
||||
has_dist_parameter = True
|
||||
mapping1[id(param)] = copy(param.dist_spec)
|
||||
mapping2[id(param)] = copy(param.compute_spec)
|
||||
# TODO(jiaruifang) fixme, we should elegently handle the default PG in init context
|
||||
if param.get_process_group() is None:
|
||||
param.process_group = ProcessGroup()
|
||||
param.set_dist_spec(distspec.replicate())
|
||||
mapping3[id(param)] = param.get_process_group()
|
||||
param.process_group = None
|
||||
|
||||
# TODO: fix when keep_vars = True
|
||||
# when keep_vars = False, the state_dict_func will call detach to create
|
||||
# new tensors, but when keep_vars = True, the recovery of spec will be reflected
|
||||
# in the `ret`, such that the final state dict will still contain process group,
|
||||
# raising exception as it is not serializable
|
||||
assert not (keep_vars and has_dist_parameter), 'keep_vars cannot be True when there are distributed ColoParameters.'
|
||||
|
||||
ret = state_dict_func(self, destination, prefix, keep_vars)
|
||||
|
||||
# recover
|
||||
with torch.no_grad():
|
||||
for param in self.parameters():
|
||||
param_id = id(param)
|
||||
if param_id in mapping1:
|
||||
dist_spec = mapping1[id(param)]
|
||||
compute_spec = mapping2[id(param)]
|
||||
param.process_group = mapping3[id(param)]
|
||||
param.set_tensor_spec(dist_spec, compute_spec)
|
||||
return ret
|
||||
|
||||
|
||||
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
||||
def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
|
||||
@@ -94,8 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
register_colo_module(torch.nn.Embedding, ColoEmbedding())
|
||||
|
||||
def _pre_context_exec(self):
|
||||
self.state_dict_func = nn.Module.state_dict
|
||||
nn.Module.state_dict = partialmethod(colo_state_dict, state_dict_func=self.state_dict_func)
|
||||
pass
|
||||
|
||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user