diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 5d7087841..4401dd9d0 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -220,7 +220,9 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
 
     # first sync model across dp ranks
     model.to(get_current_device())
-    sync_model_param_in_dp(model)
+    use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
+    if not use_zero3:
+        sync_model_param_in_dp(model)
 
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index 8fe3dcab9..6b7619a1f 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -1,7 +1,10 @@
+import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.utils import is_no_pp_or_last_stage
+from colossalai.core import global_context as gpc
+from colossalai.context.parallel_mode import ParallelMode
 
 from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
 from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
@@ -12,11 +15,11 @@ def convert_to_zero(model: nn.Module,
     assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
-
-    if is_no_pp_or_last_stage():
-        model = NaiveAMPModel(model, output_to_fp32=True)
-    else:
-        model = NaiveAMPModel(model, output_to_fp32=False)
+    if level == 2:
+        if is_no_pp_or_last_stage():
+            model = NaiveAMPModel(model, output_to_fp32=True)
+        else:
+            model = NaiveAMPModel(model, output_to_fp32=False)
 
     if level == 2:
         optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
@@ -25,4 +28,71 @@ def convert_to_zero(model: nn.Module,
     return model, optimizer
 
-__all__ = ['convert_to_zero', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']
+def zero3_model_context(dtype=torch.half):
+    """A context to enable massive model construction for training with
+    ZeRO-3. Models are automatically partitioned (or, sharded) across the
+    system and converted to half precision. Note that the ZeRO-3 config is
+    loaded automatically from ``gpc.config``.
+
+    Args:
+        dtype (``dtype``, optional): Can be used to change the data type of the parameters.
+            Supported options are ``torch.half`` and ``torch.float``. Defaults to ``torch.half``.
+
+    This context accelerates model initialization and enables models that
+    are too large to allocate in their entirety in CPU memory. It has the
+    following effects:
+
+    #. allocates tensors to either GPU or CPU memory or NVMe
+    #. converts floating point tensors to half precision
+    #. immediately partitions tensors among the group of data-parallel devices
+    #. (*optional*) replaces ``torch.nn.functional.linear`` with a more
+       memory-efficient implementation
+
+    These modifications allow for models that exceed the size of local CPU/GPU
+    memory/NVMe, but fit within the total aggregate capacity (*i.e.*, aggregate
+    CPU or GPU memory or NVMe) across all nodes. Consider initializing a model
+    with one trillion parameters, whose weights occupy two terabytes (TB) in half
+    precision. The initial CPU allocation in full precision requires 4TB of
+    memory *per process*, and so a system with 8 GPUs per node would need 32TB of
+    CPU memory due to data-parallel redundancies. Instead, by immediately
+    partitioning tensors we remove the redundancies. The result is that
+    regardless of the number of GPUs, we still only require the original 4TB.
+    This allows for a linear increase in model size with the aggregate system memory.
+    For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
+    parameter model with 4 nodes and 32 GPUs.
+
+    Important: If the fp16 weights of the model cannot fit into the memory of a
+    single GPU, this feature must be used.
+
+    Examples
+    --------
+
+    #. Allocate a model and partition it among all processes:
+
+        .. code-block:: python
+
+            with zero3_model_context():
+                model = MyLargeModel()
+
+    """
+    assert dtype == torch.half or dtype == torch.float, f'Invalid dtype, expected torch.half or torch.float, got {dtype}'
+    import deepspeed
+    ds_config = {
+        "train_micro_batch_size_per_gpu": 1,
+        "gradient_accumulation_steps": 1,
+        "zero_optimization": {
+            "offload_param": getattr(gpc.config.zero, 'offload_param_config', None),
+            "offload_optimizer": getattr(gpc.config.zero, 'offload_optimizer_config', None),
+        },
+        "aio": getattr(gpc.config.zero, 'aio_config', None)
+    }
+    remote_device = getattr(ds_config['zero_optimization']['offload_param'], 'device', None)
+    pin_memory = getattr(ds_config['zero_optimization']['offload_param'], 'pin_memory', False)
+    return deepspeed.zero.Init(data_parallel_group=gpc.get_group(ParallelMode.DATA),
+                               remote_device=remote_device,
+                               config_dict_or_path=ds_config,
+                               pin_memory=pin_memory,
+                               dtype=dtype)
+
+
+__all__ = ['convert_to_zero', 'ZeroRedundancyOptimizer_Level_2',
+           'ZeroRedundancyOptimizer_Level_3', 'zero3_model_context']
diff --git a/colossalai/zero/zero_redundancy_optimizer_level_3.py b/colossalai/zero/zero_redundancy_optimizer_level_3.py
index cf281d6b4..32c5afba8 100644
--- a/colossalai/zero/zero_redundancy_optimizer_level_3.py
+++ b/colossalai/zero/zero_redundancy_optimizer_level_3.py
@@ -637,7 +637,8 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
                  postscale_gradients=True,
                  gradient_predivide_factor=1.0,
                  gradient_accumulation_steps=1,
-                 aio_config=None):
+                 aio_config=None,
+                 dtype=torch.half):
         # mpu = None
         # mpu is removed from the parameter list
         # tensor parallel will be automatically detected later
@@ -682,13 +683,25 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
         util_ops = UtilsBuilder().load()
         self.flatten = util_ops.flatten
         self.unflatten = util_ops.unflatten
-        self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
+        self.dtype = dtype
 
         if not all(is_zero_param(p) for p in module.parameters()):
+            ds_config = {
+                "train_micro_batch_size_per_gpu": 1,
+                "gradient_accumulation_steps": 1,
+                "zero_optimization": {
+                    "offload_param": offload_param_config,
+                    "offload_optimizer": offload_optimizer_config,
+                },
+                "aio": aio_config
+            }
+            remote_device = offload_param_config['device']
             group = None
             if gpc.is_initialized(ParallelMode.DATA):
                 group = gpc.get_group(ParallelMode.DATA)
-            Init(module=module, data_parallel_group=group, dtype=self.dtype)
+            Init(module=module, data_parallel_group=group, dtype=self.dtype,
+                 remote_device=remote_device, config_dict_or_path=ds_config,
+                 pin_memory=offload_optimizer_config[OFFLOAD_OPTIMIZER_PIN_MEMORY])
 
             for m in module.modules():
                 _init_external_params(m)
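For orientation, the helpers above pull their DeepSpeed settings out of the `zero` section of `gpc.config`. Below is a rough sketch of what such a section could look like in a ColossalAI config file; the `level`, `dynamic_loss_scale`, and `clip_grad` keys mirror the documentation diff further down, while the offload sub-dictionaries and their values are assumptions modelled on DeepSpeed's ZeRO-3 offload options, not something this patch prescribes.

```python
# Hypothetical `zero` section of a ColossalAI config file. gpc.config exposes
# nested dicts with attribute access, which is how zero3_model_context() and
# ZeroRedundancyOptimizer_Level_3 read the fields below.
zero = dict(
    level=3,
    dynamic_loss_scale=True,
    clip_grad=1.0,
    # read via getattr(gpc.config.zero, 'offload_param_config', None);
    # 'device' and 'pin_memory' are forwarded to deepspeed.zero.Init
    offload_param_config=dict(
        device='cpu',
        pin_memory=True
    ),
    # read via getattr(gpc.config.zero, 'offload_optimizer_config', None);
    # its pin_memory flag is used when re-partitioning an unsharded module
    offload_optimizer_config=dict(
        device='cpu',
        pin_memory=True
    ),
    # read via getattr(gpc.config.zero, 'aio_config', None)
    aio_config=None
)
```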
diff --git a/docs/zero.md b/docs/zero.md
index 1e2ee30d8..201223803 100644
--- a/docs/zero.md
+++ b/docs/zero.md
@@ -83,4 +83,13 @@ Note that `fp16` is automatically enabled when using ZeRO. This relies on `AMP_T
 
 ### Training
 
+Note that if your model is too large to fit in memory when using ZeRO-3, you should use `colossalai.zero.zero3_model_context` to construct your model:
+
+```python
+from colossalai.zero import zero3_model_context
+
+with zero3_model_context():
+    model = Model()
+```
+
 Once you have completed your configuration, just use `colossalai.initialize()` to initialize your training.
diff --git a/docs/zero_zh.md b/docs/zero_zh.md
index 85a0f4562..df170dcb0 100644
--- a/docs/zero_zh.md
+++ b/docs/zero_zh.md
@@ -23,7 +23,7 @@ The ZeRO optimizer can partition three kinds of model states (optimizer states, gradients, parameters)
 )
 
 zero = dict(
-    type='ZeroRedundancyOptimizer_Level_3',
+    level=3,
     dynamic_loss_scale=True,
     clip_grad=1.0
 )
@@ -78,4 +78,13 @@ The ZeRO optimizer can partition three kinds of model states (optimizer states, gradients, parameters)
 
 ### Training with the ZeRO optimizer
 
+Note that when using ZeRO-3, if your model is too large to fit in memory, you should use `colossalai.zero.zero3_model_context` to construct your model:
+
+```python
+from colossalai.zero import zero3_model_context
+
+with zero3_model_context():
+    model = Model()
+```
+
 Once you have completed the configuration above, you can run `colossalai.initialize()` to start your training.
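Putting the pieces together, a rough end-to-end sketch of the ZeRO-3 workflow described in the docs might look like the following. Only `zero3_model_context()` and the existence of `colossalai.initialize()` come from this patch; the launch call, the placeholder `MyLargeModel`, and the exact arguments and return values of `initialize()` are assumptions for illustration.

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.zero import zero3_model_context


class MyLargeModel(nn.Module):
    """Placeholder for a model too large to materialize whole on one device."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)])

    def forward(self, x):
        return self.layers(x)


# Launch the distributed environment; the config file is assumed to contain
# the `zero = dict(level=3, ...)` section shown in the docs above.
colossalai.launch_from_torch(config='./config.py')

# Build the model inside the context so every parameter is sharded across the
# data-parallel group (and cast to fp16) as soon as it is created.
with zero3_model_context():
    model = MyLargeModel()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# With ZeRO-3 configured, initialize() wraps the model and optimizer through
# convert_to_zero() and skips the data-parallel parameter broadcast, as in the
# initialize.py hunk at the top of this diff.
engine, *_ = colossalai.initialize(model, optimizer, criterion)
```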