Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 06:30:41 +00:00)
[zero] adapt zero for unsharded parameters (Optimizer part) (#601)
@@ -11,6 +11,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from contextlib import AbstractContextManager


 def _substitute_init_recursively(cls, func):
@@ -88,6 +89,7 @@ class ZeroContextConfig(object):
     """The configuration used to control zero context initialization.

     Args:
+        target_device (torch.device): The device where param data are after exiting the context.
         replicated (bool, optional): Whether the param is replicated across data parallel group.
            Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
@@ -99,8 +101,13 @@ class ZeroContextConfig(object):
            See torchvision resnet18. Defaults to False.
     """

-    def __init__(self, replicated: bool = True, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+    def __init__(self,
+                 target_device: torch.device,
+                 replicated: bool = True,
+                 shard_param: bool = False,
+                 rm_torch_payload_on_the_fly: bool = False):
         super().__init__()
+        self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
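For orientation, below is a minimal sketch of constructing the updated config directly. The import path is an assumption (in practice `ZeroContextConfig` is built internally by `ZeroInitContext`), and a CUDA device is assumed to be available.

```python
import torch

# Assumed import path; ZeroContextConfig is normally created internally by ZeroInitContext.
from colossalai.zero.init_ctx.init_context import ZeroContextConfig

# Keep params replicated and unsharded, and place their data on the current CUDA
# device once the context exits.
config = ZeroContextConfig(target_device=torch.device('cuda', torch.cuda.current_device()),
                           replicated=True,
                           shard_param=False,
                           rm_torch_payload_on_the_fly=False)
```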
@@ -114,7 +121,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.

     Args:
-        target_device (torch.device): The device where param data after exiting the context.
+        target_device (torch.device): The device where param data are after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
@@ -136,17 +143,22 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  dp_process_group: Optional[ProcessGroup] = None):

         super().__init__()
-        self.target_device = target_device
         self.shard_strategy = shard_strategy
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

-        self.config = ZeroContextConfig(replicated=True,
+        self.config = ZeroContextConfig(target_device=target_device,
+                                        replicated=True,
                                         shard_param=shard_param,
                                         rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)

         ZeroContextMgr().current_context = self

+    @property
+    def target_device(self):
+        return self.config.target_device
+
     @property
     def is_replicated(self):
         return self.config.is_replicated
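To see how the relocated `target_device` is used in practice, here is a hedged usage sketch. It assumes the distributed environment has already been initialized via `colossalai.launch` and that `TensorShardStrategy` is available under `colossalai.zero.shard_utils`.

```python
import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext          # assumed export
from colossalai.zero.shard_utils import TensorShardStrategy   # assumed strategy class

# Parameters created inside the context are sharded on the fly and end up on the
# target device; ZeroInitContext.target_device now simply proxies config.target_device.
with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
```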
@@ -235,8 +247,9 @@ class ZeroContextMgr(metaclass=SingletonMeta):
         self.current_context.config = old_config


-def no_shard_zero_context(is_replicated: bool = True):
-    return ZeroContextMgr().hijack_context_config(replicated=is_replicated,
+def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
+    return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
+                                                  replicated=is_replicated,
                                                   shard_param=False,
                                                   rm_torch_payload_on_the_fly=False)

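The helper now also pins unsharded params to the current CUDA device. A hedged sketch of its intended use, e.g. around MoE expert construction; the export path and the expert layer below are illustrative assumptions.

```python
import torch.nn as nn

from colossalai.zero.init_ctx import no_shard_zero_context  # assumed export

# Parameters created in this block stay unsharded and are marked as not replicated,
# which matches the MoE-expert case mentioned in the ZeroContextConfig docstring.
with no_shard_zero_context(is_replicated=False):
    expert = nn.Linear(1024, 4096)  # illustrative expert layer
```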
@@ -12,13 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.zero.shard_utils.tensor_utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
+from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
+                                                      colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
-from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
-
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -69,6 +68,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
         growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
         hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
+        keep_unsharded (bool, optional): If True, the optimizer won't shard unsharded parameters.
+            In ZeRO-2, set keep_unsharded to False.
+            In ZeRO-3, set keep_unsharded to True.
         max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
         dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None.
         mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None.
@@ -89,6 +91,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
+                 keep_unsharded: bool = False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
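A hedged construction sketch showing the new `keep_unsharded` flag. It assumes `ShardedOptimizerV2` takes the wrapped model first and the underlying optimizer second, that the export paths below are correct, and that `model` was built under `ZeroInitContext` with `shard_param=False`, so its parameters carry `colo_attr` but remain unsharded.

```python
import torch

from colossalai.zero.shard_utils import TensorShardStrategy   # assumed strategy class
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2  # assumed export path

# Wrap the (unsharded-parameter) model, then hand it to the ZeRO optimizer.
sharded_model = ShardedModelV2(model, TensorShardStrategy())  # assumed constructor shape
optim = torch.optim.Adam(sharded_model.parameters(), lr=1e-3)

# keep_unsharded=True asks the optimizer not to shard parameters that were left
# unsharded at init time; per the docstring above, False is the ZeRO-2 setting and
# True the ZeRO-3 setting. It cannot be combined with hybrid fp32-shard placement.
sharded_optim = ShardedOptimizerV2(sharded_model, optim, keep_unsharded=True)
```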
@@ -122,24 +125,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")

-        # Store fp32 param shards
-        self.master_params: Dict[Parameter, StatefulTensor] = {}
+        assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
+            "Keeping unsharded parameters can't be used with hybrid OS placement right now."
+        self.keep_unshard = keep_unsharded

-        for group in self.optim.param_groups:
-            for p in group['params']:
-                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # TODO (ver217): we may not use shard / gather here
-                    # Param is no sharded, which means we use ZeRO-2 here
-                    # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = StatefulTensor(
-                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
-                if not is_param_sharded:
-                    # In this branch, there's no need to shard param
-                    # So we gather here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+        # Store fp32 param shards
+        self._register_master_weight()

         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                            ranks=[0])
@@ -283,6 +274,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def sync_grad(self):
         pass

+    def _register_master_weight(self):
+        self.master_params: Dict[Parameter, StatefulTensor] = {}
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded and not self.keep_unshard:
+                    # `keep_unsharded` controls whether unsharded parameters are sharded here
+                    # As we only store param shards, we shard the param here
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
+                if not is_param_sharded and not self.keep_unshard:
+                    # In this branch the param does not need to stay sharded,
+                    # so we gather it back here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+
     def _maybe_move_fp32_shards(self):
         if self._should_move_fp32_shards_h2d:
             self._should_move_fp32_shards_h2d = False
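`_register_master_weight` factors the old constructor loop out into a method and gates the temporary shard/gather on `keep_unshard`. The mixed-precision bookkeeping it performs can be sketched in plain PyTorch; this is illustrative only and uses no ColossalAI types.

```python
import torch

# fp16 payload used for forward/backward (stands in for sharded_data_tensor.payload).
fp16_payload = torch.randn(4, 4, dtype=torch.float16)

# Master weight: an fp32 copy that the wrapped optimizer actually updates
# (stands in for the StatefulTensor stored in self.master_params).
master_weight = fp16_payload.detach().clone().float()

# ... optimizer.step() updates master_weight in fp32 ...
master_weight.add_(0.01)

# After step(), the fp16 payload is refreshed from the fp32 master copy
# (and, when the param was temporarily sharded, gathered back afterwards).
fp16_payload.copy_(master_weight.half())
```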
@@ -328,7 +336,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
+                if not is_param_sharded and not self.keep_unshard:
                     # We use ZeRO-2 here
                     # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
@@ -342,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), torch.cuda.current_device()))

-                if not is_param_sharded:
+                if not is_param_sharded and not self.keep_unshard:
                     # We gather full fp16 param here
                     self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 p.data = p.colo_attr.sharded_data_tensor.payload
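The same guard now appears both in master-weight registration and in the post-step write-back. A tiny sketch of the condition, for reference; the function name is illustrative.

```python
def needs_temporary_shard(is_param_sharded: bool, keep_unsharded: bool) -> bool:
    # Only parameters that are unsharded AND not explicitly kept unsharded are
    # temporarily sharded before master-weight bookkeeping and gathered afterwards.
    return (not is_param_sharded) and (not keep_unsharded)
```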