mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 14:12:02 +00:00
Feature/zero (#279)
* add zero1 (#209) * add zero1 * add test zero1 * update zero stage 1 develop (#212) * Implement naive zero3 (#240) * naive zero3 works well * add zero3 param manager * add TODOs in comments * add gather full param ctx * fix sub module streams * add offload * fix bugs of hook and add unit tests * fix bugs of hook and add unit tests (#252) * add gather full param ctx * fix sub module streams * add offload * fix bugs of hook and add unit tests * polish code and add state dict hook * fix bug * update unit test * refactor reconstructed zero code * clip_grad support zero3 and add unit test * add unit test for Zero3ParameterManager * [WIP] initialize the shard param class * [WIP] Yet another sharded model implementation (#274) * [WIP] initialize the shard param class * [WIP] Yes another implementation of shardModel. Using a better hook method. * torch.concat -> torch.cat * fix test_zero_level_1.py::test_zero_level_1 unitest * remove deepspeed implementation and refactor for the reconstructed zero module * polish zero dp unittests Co-authored-by: ver217 <lhx0217@gmail.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com>
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
from distutils.command.config import config
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.utils import is_no_pp_or_last_stage
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
||||
from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
|
||||
from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
|
||||
from colossalai.core import global_context as gpc
|
||||
from torch.optim import Optimizer
|
||||
from .sharded_model import ShardedModel
|
||||
from .sharded_optim import ShardedOptimizer
|
||||
|
||||
|
||||
def convert_to_zero(model: nn.Module,
|
||||
@@ -29,82 +28,14 @@ def convert_to_zero(model: nn.Module,
|
||||
:return: (model, optimizer)
|
||||
:rtype: Tuple
|
||||
"""
|
||||
import deepspeed
|
||||
assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
|
||||
model = NaiveAMPModel(model, output_to_fp32=False)
|
||||
|
||||
if level == 2:
|
||||
optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
|
||||
assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
|
||||
if level in [1, 2]:
|
||||
if level == 2:
|
||||
assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
|
||||
model = NaiveAMPModel(model, output_to_fp32=True)
|
||||
optimizer = ShardedOptimizer(model.parameters(), *zero_config)
|
||||
else:
|
||||
optimizer = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer, module=model, **zero_config)
|
||||
model = ShardedModel(module=model, **zero_config)
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def zero3_model_context(dtype=torch.half):
|
||||
"""A context to enable massive model construction for training with
|
||||
ZeRO-3. Models are automatically partitioned (or, sharded) across the
|
||||
system and converted to half precision. Note that the config of ZeRO-3 will be loaded automatically from `gpc.config`.
|
||||
|
||||
Args:
|
||||
dtype (``dtype``, optional): Can be used to change the data type of the parameters.
|
||||
Supported options are ``torch.half`` and ``torch.float``. Defaults to ``torch.half``
|
||||
|
||||
This context accelerates model initialization and enables models that
|
||||
are too large to allocate in their entirety in CPU memory. It has the
|
||||
following effects:
|
||||
|
||||
#. allocates tensors to either GPU or CPU memory or NVMe
|
||||
#. converts floating point tensors to half precision
|
||||
#. immediately partitions tensors among the group of data-parallel devices
|
||||
#. (*optional*) replaces ``torch.nn.functional.linear`` with a more
|
||||
memory-efficient implementation
|
||||
|
||||
These modifications allow for models that exceed the size of local CPU/GPU
|
||||
memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
|
||||
or GPU memory or NVMe) across all nodes. Consider initializing a model with one
|
||||
trillion parameters, whose weights occupy two terabytes (TB) in half
|
||||
precision. The initial CPU allocation in full precision requires 4TB of
|
||||
memory *per process*, and so a system with 8 GPUs per node would need 32TB of
|
||||
CPU memory due to data-parallel redundancies. Instead, by immediately
|
||||
partitioning tensors we remove the redundancies. The result is that
|
||||
regardless of the number of GPUs, we still only require the original 4TB. This
|
||||
allows for a linear increase in model size with the aggregate system memory.
|
||||
For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
|
||||
parameter model with 4 nodes and 32 GPUs.
|
||||
|
||||
Important: If the fp16 weights of the model can't fit onto a single GPU memory
|
||||
this feature must be used.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
#. Allocate a model and partition it among all processes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with zero3_model_context():
|
||||
model = MyLargeModel()
|
||||
|
||||
"""
|
||||
assert dtype == torch.half or dtype == torch.float, f'Invalid dtype, except torch.half or torch.float, got {dtype}'
|
||||
import deepspeed
|
||||
ds_config = {
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"zero_optimization": {
|
||||
"offload_param": getattr(gpc.config.zero, 'offload_param_config', None),
|
||||
"offload_optimizer": getattr(gpc.config.zero, 'offload_optimizer_config'),
|
||||
},
|
||||
"aio": getattr(gpc.config.zero, 'aio_config', None)
|
||||
}
|
||||
remote_device = getattr(ds_config['zero_optimization']['offload_param'], 'device', None)
|
||||
pin_memory = getattr(ds_config['zero_optimization']['offload_param'], 'pin_memory', False)
|
||||
return deepspeed.zero.Init(data_parallel_group=gpc.get_group(ParallelMode.DATA),
|
||||
remote_device=remote_device,
|
||||
config_dict_or_path=ds_config,
|
||||
pin_memory=pin_memory,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
__all__ = ['convert_to_zero', 'ZeroRedundancyOptimizer_Level_2',
|
||||
'ZeroRedundancyOptimizer_Level_3', 'zero3_model_context']
|
||||
__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
|
||||
|
Reference in New Issue
Block a user