[zero] adapt zero for unsharded parameters (Optimizer part) (#601)

HELSON
2022-04-01 20:10:47 +08:00
committed by GitHub
parent 229382c844
commit 055fbf5be6
8 changed files with 208 additions and 44 deletions

View File

@@ -11,6 +11,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from contextlib import AbstractContextManager
def _substitute_init_recursively(cls, func):
@@ -88,6 +89,7 @@ class ZeroContextConfig(object):
"""The configuration used to control zero context initialization.
Args:
target_device (torch.device): The device where param data are after exiting the context.
replicated (bool, optional): Whether the param is replicated across data parallel group.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
@@ -99,8 +101,13 @@ class ZeroContextConfig(object):
See torchvision resnet18. Defaults to False.
"""
def __init__(self, replicated: bool = True, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
def __init__(self,
target_device: torch.device,
replicated: bool = True,
shard_param: bool = False,
rm_torch_payload_on_the_fly: bool = False):
super().__init__()
self.target_device = target_device
self.is_replicated: bool = replicated
self.shard_param: bool = shard_param
self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
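For reference, a minimal standalone sketch of the updated config as this hunk describes it; the class name `ZeroContextConfigSketch` and the example values are illustrative, not part of the patch:

```python
import torch

class ZeroContextConfigSketch:
    """Standalone mirror of the updated ZeroContextConfig fields."""

    def __init__(self,
                 target_device: torch.device,
                 replicated: bool = True,
                 shard_param: bool = False,
                 rm_torch_payload_on_the_fly: bool = False):
        # target_device is now a required field: it records where param data
        # should end up once the init context exits.
        self.target_device = target_device
        self.is_replicated: bool = replicated
        self.shard_param: bool = shard_param
        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly

# A config for non-replicated, unsharded parameters (e.g. MoE experts);
# CPU is used here only so the snippet runs without a GPU.
cfg = ZeroContextConfigSketch(target_device=torch.device('cpu'),
                              replicated=False,
                              shard_param=False)
assert not cfg.is_replicated and not cfg.shard_param
```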
@@ -114,7 +121,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
3. Shard the param and grad according to flags.
Args:
target_device (torch.device): The device where param data after exiting the context.
target_device (torch.device): The device where param data are after exiting the context.
shard_strategy (BaseShardStrategy): Shard strategy instance.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
@@ -136,17 +143,22 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
dp_process_group: Optional[ProcessGroup] = None):
super().__init__()
self.target_device = target_device
self.shard_strategy = shard_strategy
self.initialized_param_list = []
self.model_numel_tensor = model_numel_tensor
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
self.config = ZeroContextConfig(replicated=True,
self.config = ZeroContextConfig(target_device=target_device,
replicated=True,
shard_param=shard_param,
rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
ZeroContextMgr().current_context = self
@property
def target_device(self):
return self.config.target_device
@property
def is_replicated(self):
return self.config.is_replicated
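The two new properties turn the init context into a thin view over its config, which matters because ZeroContextMgr can swap that config in and out. A hedged sketch of the delegation pattern, with illustrative names:

```python
from types import SimpleNamespace

class InitContextSketch:
    """Illustrative stand-in for ZeroInitContext's new property delegation."""

    def __init__(self, config):
        self.config = config  # may be replaced temporarily by a context manager

    @property
    def target_device(self):
        # Reads through to the config, so hijacking the config also changes
        # the device the context reports.
        return self.config.target_device

    @property
    def is_replicated(self):
        return self.config.is_replicated

ctx = InitContextSketch(SimpleNamespace(target_device='cpu', is_replicated=True))
assert ctx.is_replicated and ctx.target_device == 'cpu'
```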
@@ -235,8 +247,9 @@ class ZeroContextMgr(metaclass=SingletonMeta):
self.current_context.config = old_config
def no_shard_zero_context(is_replicated: bool = True):
return ZeroContextMgr().hijack_context_config(replicated=is_replicated,
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
replicated=is_replicated,
shard_param=False,
rm_torch_payload_on_the_fly=False)
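`no_shard_zero_context` now also pins `target_device` to the current CUDA device when it hijacks the config. Below is a hedged, single-process sketch of the swap-and-restore pattern such a helper is built on; `hijack_config_sketch` is hypothetical, while the real helper is `ZeroContextMgr().hijack_context_config`:

```python
from contextlib import contextmanager
from types import SimpleNamespace

@contextmanager
def hijack_config_sketch(ctx, **kwargs):
    # Install a temporary config, hand control back to the caller, then
    # restore the original config even if the body raises.
    old_config = ctx.config
    ctx.config = SimpleNamespace(**kwargs)   # stand-in for a ZeroContextConfig
    try:
        yield ctx.config
    finally:
        ctx.config = old_config

# Usage mirroring no_shard_zero_context(is_replicated=False): parameters built
# inside the block are neither sharded nor replicated. 'cpu' is used here only
# so the snippet runs without a GPU.
ctx = SimpleNamespace(config=SimpleNamespace(replicated=True, shard_param=True))
with hijack_config_sketch(ctx, target_device='cpu', replicated=False,
                          shard_param=False, rm_torch_payload_on_the_fly=False):
    pass  # e.g. build MoE expert parameters here
assert ctx.config.shard_param is True   # original config restored
```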

View File

@@ -12,13 +12,12 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.shard_utils.tensor_utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@@ -69,6 +68,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
keep_unsharded (bool, optional): If True, the optimizer won't shard parameters that are unsharded.
Set keep_unsharded to False for ZeRO-2.
Set keep_unsharded to True for ZeRO-3.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None.
mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None.
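The new `keep_unsharded` flag effectively encodes the ZeRO stage. A tiny hedged helper restating the docstring's rule; the helper itself is illustrative, not part of the API:

```python
def keep_unsharded_for(zero_stage: int) -> bool:
    # Under ZeRO-2 the optimizer still shards the fp32 master copy of
    # unsharded params, so the flag stays False; under ZeRO-3 unsharded
    # params (e.g. MoE experts) are left untouched, so it becomes True.
    if zero_stage not in (2, 3):
        raise ValueError(f"unsupported ZeRO stage: {zero_stage}")
    return zero_stage == 3

assert keep_unsharded_for(2) is False
assert keep_unsharded_for(3) is True
```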
@@ -89,6 +91,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval: float = 1000,
hysteresis: float = 2,
max_scale: int = 2**32,
keep_unsharded: bool = False,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -122,24 +125,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
# Store fp32 param shards
self.master_params: Dict[Parameter, StatefulTensor] = {}
assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
"Keeping unsharded parameters can't be used with hybrid OS placement right now."
self.keep_unshard = keep_unsharded
for group in self.optim.param_groups:
for p in group['params']:
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded:
# TODO (ver217): we may not use shard / gather here
# Param is not sharded, which means we use ZeRO-2 here
# As we only store param shard, we shard it here
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
if not is_param_sharded:
# In this branch, there's no need to shard param
# So we gather here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# Store fp32 param shards
self._register_master_weight()
self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
ranks=[0])
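The constructor refactor also adds a guard: keeping parameters unsharded cannot yet be combined with moving fp32 master shards between host and device. A hedged restatement of that check as a plain function, with illustrative names:

```python
def check_keep_unsharded(keep_unsharded: bool,
                         should_move_fp32_shards_h2d: bool) -> None:
    # Mirrors the new assertion: hybrid placement periodically moves fp32
    # master shards to/from the device, which the unsharded path does not
    # support yet.
    if keep_unsharded and should_move_fp32_shards_h2d:
        raise RuntimeError("Keeping unsharded parameters can't be used with "
                           "hybrid OS placement right now.")

check_keep_unsharded(keep_unsharded=True, should_move_fp32_shards_h2d=False)  # ok
```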
@@ -283,6 +274,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def sync_grad(self):
pass
def _register_master_weight(self):
self.master_params: Dict[Parameter, StatefulTensor] = {}
for group in self.optim.param_groups:
for p in group['params']:
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded and not self.keep_unshard:
# Use keep_unsharded to control whether to shard unsharded parameters
# As we only store param shard, we shard it here
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
if not is_param_sharded and not self.keep_unshard:
# In this branch, there's no need to shard param
# So we gather here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
def _maybe_move_fp32_shards(self):
if self._should_move_fp32_shards_h2d:
self._should_move_fp32_shards_h2d = False
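`_register_master_weight` performs a "shard, copy fp32 master, gather back" dance for unsharded params when `keep_unsharded` is False. A single-process sketch of that sequence; the function name is hypothetical, and a per-"rank" chunk stands in for the real sharding done through `shard_strategy` and the dp `ProcessGroup`:

```python
import torch

def register_master_shard_sketch(fp16_param: torch.Tensor,
                                 rank: int, world_size: int) -> torch.Tensor:
    flat = fp16_param.detach().flatten()
    shard = torch.chunk(flat, world_size)[rank]   # "shard": keep this rank's slice
    master_shard = shard.float().clone()          # fp32 master copy kept by the optimizer
    # The real code then re-gathers the fp16 payload so the module still sees
    # the full parameter (or skips both steps entirely when keep_unsharded=True).
    return master_shard

p = torch.randn(16).half()
master = register_master_shard_sketch(p, rank=0, world_size=4)
assert master.numel() == 4 and master.dtype == torch.float32
```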
@@ -328,7 +336,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for group in self.optim.param_groups:
for p in group['params']:
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded:
if not is_param_sharded and not self.keep_unshard:
# We use ZeRO-2 here
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
# But we only have updated fp32 param shard here
@@ -342,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
if not is_param_sharded:
if not is_param_sharded and not self.keep_unshard:
# We gather full fp16 param here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
p.data = p.colo_attr.sharded_data_tensor.payload
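After the step, the updated fp32 master shard is cast back to fp16 and written into the param payload, and the full fp16 param is gathered only when the param is sharded and `keep_unsharded` is False. A single-process sketch of the write-back, with a hypothetical function name and no distributed calls:

```python
import torch

def write_back_sketch(fp16_param: torch.Tensor, master_shard: torch.Tensor,
                      rank: int, world_size: int) -> None:
    # Cast the updated fp32 master shard to fp16 and copy it into this rank's
    # slice of the payload; shard_strategy.gather(...) would then all-gather
    # the slices across the dp group (skipped when keep_unsharded is True).
    flat = fp16_param.detach().flatten()
    shard_size = flat.numel() // world_size        # padding omitted for brevity
    start = rank * shard_size
    flat[start:start + shard_size] = master_shard.half()

p = torch.zeros(16).half()
write_back_sketch(p, torch.ones(4), rank=1, world_size=4)
assert p.flatten()[4:8].eq(1).all()
```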