[zero] adapt zero for unsharded parameters (#561)
* support existing sharded and unsharded parameters in zero
* add unit test for moe-zero model init
* polish moe gradient handler
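In short: the init context now stamps every parameter with an `is_replicated` flag, and `ShardedModelV2` keeps sharded and unsharded parameters in separate lists so each group can be treated differently during backward. A minimal stand-in sketch of that bookkeeping in plain PyTorch (the `is_sharded` attribute below is hypothetical and only stands in for `colo_attr.param_is_sharded`):

import torch.nn as nn

# Toy model: treat the first layer as a normal (sharded, replicated) layer and
# the second one as an MoE expert (unsharded, not replicated).
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

for name, param in model.named_parameters():
    is_expert = name.startswith('1.')
    param.is_replicated = not is_expert    # the flag ZeroInitContext now sets
    param.is_sharded = not is_expert       # hypothetical stand-in for colo_attr.param_is_sharded

# ShardedModelV2 now keeps the two groups apart (see the __init__ diff below).
sharded_params = [p for p in model.parameters() if p.is_sharded]
unshard_params = [p for p in model.parameters() if not p.is_sharded]
print(len(sharded_params), len(unshard_params))    # 2 2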
@@ -88,6 +88,8 @@ class ZeroContextConfig(object):
     """The configuration used to control zero context initialization.
 
     Args:
+        replicated (bool, optional): Whether the param is replicated across data parallel group.
+            Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
@@ -97,8 +99,9 @@ class ZeroContextConfig(object):
             See torchvision resnet18. Defaults to False.
     """
 
-    def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+    def __init__(self, replicated: bool = True, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
         super().__init__()
+        self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
 
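For reference, the flags combine roughly as follows; the dataclass below is only a stand-in that mirrors the new `__init__` signature shown above, not the real class:

from dataclasses import dataclass

@dataclass
class ZeroContextConfigSketch:
    # Mirrors only the __init__ signature shown above.
    replicated: bool = True
    shard_param: bool = False
    rm_torch_payload_on_the_fly: bool = False

# Regular parameters: logically replicated across the data-parallel group, sharded by ZeRO.
dense_cfg = ZeroContextConfigSketch(replicated=True, shard_param=True)

# MoE expert parameters: private to their rank(s), so neither replicated nor sharded.
expert_cfg = ZeroContextConfigSketch(replicated=False, shard_param=False)
print(dense_cfg, expert_cfg)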
@@ -139,10 +142,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
-        self.config = ZeroContextConfig(shard_param=shard_param,
+        self.config = ZeroContextConfig(replicated=True,
+                                        shard_param=shard_param,
                                         rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
         ZeroContextMgr().current_context = self
 
+    @property
+    def is_replicated(self):
+        return self.config.is_replicated
+
     @property
     def shard_param(self):
         return self.config.shard_param
@@ -183,6 +191,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             self.model_numel_tensor += param.numel()
 
+            # mark whether the param is replicated
+            param.is_replicated = self.is_replicated
+
             # convert parameters to half
             param_half = half_fn(param)
             param.data = param_half
@@ -224,14 +235,20 @@ class ZeroContextMgr(metaclass=SingletonMeta):
             self.current_context.config = old_config


-def no_shard_zero_context():
-    return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
+def no_shard_zero_context(is_replicated: bool = True):
+    return ZeroContextMgr().hijack_context_config(replicated=is_replicated,
+                                                  shard_param=False,
+                                                  rm_torch_payload_on_the_fly=False)


-def no_shard_zero_decrator(init_func):
+def no_shard_zero_decrator(is_replicated: bool = True):

-    def _no_shard(*args, **kwargs):
-        with no_shard_zero_context():
-            init_func(*args, **kwargs)
+    def _wrapper(init_func):

-    return _no_shard
+        def _no_shard(*args, **kwargs):
+            with no_shard_zero_context(is_replicated):
+                init_func(*args, **kwargs)
+
+        return _no_shard
+
+    return _wrapper
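`no_shard_zero_decrator` changes from taking `init_func` directly to a decorator factory parameterized by `is_replicated`, so call sites move from `@no_shard_zero_decrator` to `@no_shard_zero_decrator(is_replicated=False)`. A self-contained sketch of the same pattern, with a dummy context manager standing in for the real `no_shard_zero_context`:

from contextlib import contextmanager

@contextmanager
def no_shard_zero_context(is_replicated: bool = True):
    # Stand-in for the real helper above, which hijacks the ZeroContextMgr config;
    # here we only record the flags it would set.
    print(f'init with shard_param=False, replicated={is_replicated}')
    yield

def no_shard_zero_decrator(is_replicated: bool = True):

    def _wrapper(init_func):

        def _no_shard(*args, **kwargs):
            with no_shard_zero_context(is_replicated):
                init_func(*args, **kwargs)

        return _no_shard

    return _wrapper

class Expert:

    # Expert parameters stay on their own rank(s): neither sharded nor replicated.
    @no_shard_zero_decrator(is_replicated=False)
    def __init__(self, dim: int):
        self.dim = dim

Expert(16)    # prints: init with shard_param=False, replicated=False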
@@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
+from colossalai.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
@@ -67,17 +68,27 @@ class ShardedModelV2(nn.Module):
         self.logger = get_dist_logger()
 
         # We force users to use ZeroInitContext
-        sharded = []
-        unsharded = []
-        for param in module.parameters():
-            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.colo_attr.param_is_sharded)
-            unsharded.append(not param.colo_attr.param_is_sharded)
-        assert all(sharded) or all(
-            unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
-        self.shard_param = all(sharded)
-        self.module = module
+        for submodule in module.modules():
+            sharded_cnt = 0
+            unshard_cnt = 0
+            for param in submodule.parameters(recurse=False):
+                assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
+                if param.colo_attr.param_is_sharded:
+                    sharded_cnt += 1
+                else:
+                    unshard_cnt += 1
+            assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param'
+            submodule.param_is_sharded = (sharded_cnt > 0)
+
+        self.sharded_params = []
+        self.unshard_params = []
+        for param in module.parameters():
+            if param.colo_attr.param_is_sharded:
+                self.sharded_params.append(param)
+            else:
+                self.unshard_params.append(param)
 
+        self.module = module
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
         self.world_size = dist.get_world_size(self.process_group)
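The constructor no longer insists that the whole model be uniformly sharded; it checks each submodule's own parameters (`recurse=False`) and only forbids mixing within a single `nn.Module`. A standalone sketch of that check, using a hypothetical `is_sharded` attribute in place of `colo_attr.param_is_sharded`:

import torch.nn as nn

def tag_submodules(module: nn.Module) -> None:
    """Mark each submodule with param_is_sharded, refusing mixed submodules."""
    for submodule in module.modules():
        sharded_cnt, unshard_cnt = 0, 0
        for param in submodule.parameters(recurse=False):
            if param.is_sharded:    # stand-in for param.colo_attr.param_is_sharded
                sharded_cnt += 1
            else:
                unshard_cnt += 1
        assert (not sharded_cnt) or (not unshard_cnt), \
            'nn.Module can not both have shard param and unshard param'
        submodule.param_is_sharded = (sharded_cnt > 0)

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
for name, param in model.named_parameters():
    param.is_sharded = name.startswith('0.')    # pretend only the first layer is sharded
tag_submodules(model)
print([m.param_is_sharded for m in model.modules()])    # [False, True, False]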
@@ -95,8 +106,8 @@ class ShardedModelV2(nn.Module):
 
         # Register hooks
         self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
-        register_ophooks_recursively(self.module, self._ophook_list)
-        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
+        register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
+        self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
 
         self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -185,7 +196,6 @@ class ShardedModelV2(nn.Module):
 
     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
-
         self._post_backward_operations()
         for ophook in self._ophook_list:
             ophook.post_iter()
@@ -224,17 +234,21 @@ class ShardedModelV2(nn.Module):
         # Wait for the non-blocking GPU -> CPU grad transfers to finish.
         torch.cuda.current_stream().synchronize()
         self.reducer.free()
-        # 3. shard tensors not dealed in the zero hook
-        if self.shard_param:
-            tensor_list = []
-            for p in self.module.parameters():
-                if not p.colo_attr.param_is_sharded:
-                    tensor_list.append(p.colo_attr.sharded_data_tensor)
-                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.colo_attr.remove_torch_payload()
-            self.shard_strategy.shard(tensor_list, self.process_group)
 
-        # 4. move sharded param grad payload to param.grad
+        # all reduce gradients for unsharded parameters
+        reduce_list = [p for p in self.unshard_params if p.is_replicated]
+        bucket_allreduce(reduce_list, self.process_group)
+
+        # 3. shard tensors not dealed in the zero hook
+        tensor_list = []
+        for p in self.sharded_params:
+            if not p.colo_attr.param_is_sharded:
+                tensor_list.append(p.colo_attr.sharded_data_tensor)
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                p.colo_attr.remove_torch_payload()
+        self.shard_strategy.shard(tensor_list, self.process_group)
+
+        # 4. set all parameters' grad to None
         for p in self.module.parameters():
             if not p.requires_grad:
                 continue
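`bucket_allreduce` (imported above from `colossalai.engine.gradient_handler.utils`) all-reduces the gradients of unsharded, replicated parameters across the data-parallel group. Roughly, a bucketed all-reduce flattens the gradients into one buffer, reduces it with a single collective, and copies the results back; the sketch below assumes an initialized `torch.distributed` process group and may differ from the real utility in details such as bucketing by dtype:

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def bucket_allreduce_sketch(params, group=None):
    # Collect the gradients that actually exist.
    grads = [p.grad.data for p in params if p.grad is not None]
    if not grads:
        return
    # Flatten into one contiguous buffer so a single collective covers all of them.
    flat = _flatten_dense_tensors(grads)
    dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=group)
    flat.div_(dist.get_world_size(group))    # average; the real helper may differ here
    # Copy the reduced values back into the original gradient tensors.
    for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(synced)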
@@ -245,6 +259,16 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
 
+            # move unsharded param grad to saved_grad
+            if not p.colo_attr.param_is_sharded:
+                if p.colo_attr.offload_grad:
+                    colo_model_data_move_to_cpu(p.grad)
+                if p.colo_attr.saved_grad.is_null():
+                    p.colo_attr.saved_grad.reset_payload(p.grad.data)
+                else:
+                    p.colo_attr.saved_grad.payload.add_(p.grad.data)
+
             p.grad = None
 
     @torch.no_grad()
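For unsharded parameters the gradient is parked in `colo_attr.saved_grad` rather than `param.grad`: the first backward initializes the buffer, later ones accumulate into it with `add_`. The same initialize-or-accumulate pattern in plain PyTorch, with an ordinary dict standing in for the per-parameter `saved_grad` buffers:

import torch

saved_grad = {}    # stand-in for the per-parameter colo_attr.saved_grad buffers

def stash_grad(param: torch.nn.Parameter) -> None:
    if param.grad is None:
        return
    if param not in saved_grad:              # the saved_grad.is_null() branch
        saved_grad[param] = param.grad.detach().clone()
    else:                                    # accumulate later backward passes
        saved_grad[param].add_(param.grad)
    param.grad = None                        # param.grad is always handed back empty

p = torch.nn.Parameter(torch.ones(3))
(p * 2).sum().backward()
stash_grad(p)
(p * 3).sum().backward()
stash_grad(p)
print(saved_grad[p])    # tensor([5., 5., 5.])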
@@ -320,16 +344,14 @@ class ShardedModelV2(nn.Module):
         param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
-                                   self.process_group)
+        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
         prev_params = {}
-        for p in self.module.parameters():
+        for p in self.sharded_params:
             prev_params[p] = p.data
             p.data = p.colo_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
-                                  self.process_group)
-        for p in self.module.parameters():
+        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
+        for p in self.sharded_params:
             p.data = prev_params[p]
         return gathered_state_dict
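`state_dict` now gathers only the sharded tensors, temporarily points each `param.data` at the gathered payload, takes the ordinary `nn.Module.state_dict()`, then re-shards and restores the original payloads; unsharded parameters are already whole and need no special handling. A standalone sketch of that swap-and-restore pattern (a pre-built full tensor stands in for the gather/shard round trip):

import torch
import torch.nn as nn

def state_dict_with_swapped_payloads(module: nn.Module, full_payloads: dict):
    """Swap full tensors into param.data, snapshot the state dict, then restore."""
    prev_params = {}
    for p in module.parameters():
        if p in full_payloads:                # only the "sharded" params get swapped
            prev_params[p] = p.data
            p.data = full_payloads[p]         # stand-in for the gathered payload
    gathered_state_dict = module.state_dict()
    for p, old_data in prev_params.items():   # put the sharded payloads back
        p.data = old_data
    return gathered_state_dict

m = nn.Linear(2, 2, bias=False)
full = {m.weight: torch.zeros(2, 2)}          # pretend this was gathered from all ranks
sd = state_dict_with_swapped_payloads(m, full)
print(sd['weight'])                           # the swapped-in full tensor
print(m.weight.data)                          # original payload is restored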