[refactor] remove old zero code (#517)

Jiarui Fang 2022-03-25 14:54:39 +08:00 committed by GitHub
parent 6a3f9fda83
commit 4d322b79da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 33 additions and 2978 deletions

View File

@@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ShardedModel, ShardedOptimizer
from colossalai.zero.sharded_model import ShardedModelV2
from ._base_schedule import BaseSchedule
@@ -92,8 +91,6 @@ class PipelineSchedule(BaseSchedule):
def pre_processing(self, engine):
# TODO: remove this after testing new zero with pipeline parallelism
if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel):
raise TypeError("Pipeline schedule is currently not compatible with ZeRO")
model = engine.model
if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
self.dtype = torch.half

View File

@@ -2,14 +2,9 @@ from typing import Tuple
import torch
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.logging import get_dist_logger
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
from torch.optim import Optimizer
from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
@@ -40,35 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
return zero_model, zero_optimizer
def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
:param model: Your model object
:type model: :class:`torch.nn.Module`
:param optimizer: Your optimizer object
:type optimizer: :class:`torch.optim.Optimizer`
:param level: Optimizer level, can be 2 or 3
:type level: int
:param zero_config: Configuration for zero
:type zero_config: dict
:return: (model, optimizer)
:rtype: Tuple
"""
assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
if level in [1, 2]:
if level == 2:
if 'partition_grad' in zero_config:
assert zero_config['partition_grad'], \
'Sharded Optimizer requires partition_grad to be True'
else:
zero_config['partition_grad'] = True
model = NaiveAMPModel(model, output_to_fp32=True)
optimizer = ShardedOptimizer(optimizer, **zero_config)
else:
model = ShardedModel(module=model, **zero_config)
return model, optimizer
__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']
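For context, a minimal sketch of how the removed level-based helper was used (illustrative only; it assumes colossalai.launch has already set up the data-parallel context, and the import follows the old __all__ export above):
import torch
import torch.nn as nn
from colossalai.zero import convert_to_zero  # as exported before this commit

model = nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Level 2 wrapped the model in NaiveAMPModel and the optimizer in ShardedOptimizer;
# level 3 wrapped the model in ShardedModel instead (see the deleted body above).
model, optimizer = convert_to_zero(model, optimizer, level=2,
                                   zero_config={'partition_grad': True})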

View File

@@ -8,7 +8,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from colossalai.logging import get_dist_logger, disable_existing_loggers

View File

@@ -0,0 +1,20 @@
import torch
import torch.nn.functional as F
from typing import Tuple
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(world_size))
while len(chunks) < world_size:
chunks.append(chunks[0].new_empty(0))
# Determine number of padding elements.
num_to_pad = chunks[0].numel() - chunks[rank].numel()
assert num_to_pad >= 0, num_to_pad
shard = chunks[rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard, num_to_pad
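A quick, single-process check of the relocated helper (illustrative only; it exercises just the chunk-and-pad logic, no process group is involved, and uses the new module path shown in the import change below):
import torch
from colossalai.zero.shard_utils.commons import get_shard

full = torch.arange(10, dtype=torch.float32)
world_size = 4
for rank in range(world_size):
    shard, pad = get_shard(full, rank, world_size)
    # Every shard has chunks[0].numel() elements (3 here); the last rank carries 2 padding zeros.
    print(rank, shard.tolist(), pad)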

View File

@@ -4,7 +4,7 @@ import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

View File

@@ -1,4 +1,3 @@
from .sharded_model import ShardedModel
from .sharded_model_v2 import ShardedModelV2
__all__ = ['ShardedModel', 'ShardedModelV2']
__all__ = ['ShardedModelV2']

View File

@@ -1,5 +1,4 @@
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, List, Tuple
import torch
import torch.nn.functional as F
@@ -12,23 +11,6 @@ def get_gradient_predivide_factor(world_size: int) -> float:
return float(factor)
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(world_size))
while len(chunks) < world_size:
chunks.append(chunks[0].new_empty(0))
# Determine number of padding elements.
num_to_pad = chunks[0].numel() - chunks[rank].numel()
assert num_to_pad >= 0, num_to_pad
shard = chunks[rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard, num_to_pad
def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
@@ -86,31 +68,3 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
if len(chunks) < num_chunks:
chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
return chunks
def assert_in_engine(cond: Any, s: Any) -> None:
"""Used in backward context to make sure error is printed."""
if not cond:
print(s)
raise AssertionError
def replace_state_dict_prefix(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
old_prefix: str, new_prefix: str) -> None:
"""
Replace all keys that match a given old_prefix with a new_prefix (in-place).
Usage::
state_dict = {"layer.xyz": torch.tensor(1)}
replace_state_dict_prefix(state_dict, "layer.", "module.layer.")
assert state_dict == {"module.layer.xyz": torch.tensor(1)}
"""
if old_prefix == new_prefix:
raise ValueError("old_prefix and new_prefix must be distinct")
for key in list(state_dict.keys()):
if not key.startswith(old_prefix):
continue
new_key = new_prefix + key[len(old_prefix):]
state_dict[new_key] = state_dict[key]
del state_dict[key]
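A standalone sketch of the storage-release pattern these utilities rely on (the resize_-based bodies are assumptions consistent with how free_storage and alloc_storage are used in the deleted parameter manager further down):
import torch

def free_storage_sketch(data: torch.Tensor) -> None:
    # Keep the tensor object alive but release its backing memory.
    if data.storage().size() > 0:
        data.storage().resize_(0)

def alloc_storage_sketch(data: torch.Tensor, size: torch.Size) -> None:
    # Re-materialize the storage before the tensor is written to again.
    if data.storage().size() == 0:
        data.storage().resize_(size.numel())

buf = torch.zeros(1024)
free_storage_sketch(buf)                       # memory released, shape metadata kept
alloc_storage_sketch(buf, torch.Size([1024]))  # ready to be filled, e.g. by an all-gather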

View File

@@ -1,385 +0,0 @@
import os
from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import alloc_storage, free_storage, get_shard
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
enable_nccl_base_collectives = False
else:
enable_nccl_base_collectives = True
# TODO: add flatten params
class Zero3ParameterManager:
def __init__(self,
module: nn.Module,
process_group: Optional[ProcessGroup],
mixed_precision: bool = False,
flatten_parameters: bool = True,
compute_dtype: Optional[torch.dtype] = None,
compute_device: Optional[torch.device] = None,
offload_config: Optional[dict] = None
) -> None:
"""Manage parameter shards. We manage several attributes on each Parameter instance:
``zero_is_sharded``: ``True`` if the Parameter is sharded or ``False``
if the Parameter is intentionally not sharded (in which case we
will all-reduce grads for this param).
``zero_orig_size``: the size of the original Parameter (before sharding)
``zero_shard_padding``: the padding size. All paddings are right padding.
``zero_fp32_shard``: a single shard of the parameters in full precision
(typically FP32, but this is dependent on the dtype of the model
as it's passed in by the user). This can be on CPU or GPU
depending on the value of *``offload_config``*.
``zero_fp16_shard``: This will be a single shard of the parameters in FP16, used for all-gather.
This can be in FP16 or FP32 depending on the value of *``compute_dtype``* and
if params are offloaded to CPU.
``zero_full_param_padded``: the full weight (padded to be evenly
divisible by ``world_size``), used for computation in the
forward and backward pass. This will be resized in place and
only materialized (via all-gather) as needed.
``zero_cpu_grad``: the gradient saved on CPU. It's set only when using CPU offload.
:param module: original module
:type module: nn.Module
:param process_group: typically data parallel process group, defaults to None
:type process_group: Optional[ProcessGroup], optional
:param mixed_precision: whether to use mixed precision mode, defaults to False
:type mixed_precision: bool, optional
:param flatten_parameters: whether to flatten parameters (currently unused), defaults to True
:type flatten_parameters: bool, optional
:param compute_dtype: the dtype of parameters when computing, defaults to None
:type compute_dtype: Optional[torch.dtype], optional
:param compute_device: the device of parameters when computing, defaults to None
:type compute_device: Optional[torch.device], optional
:param offload_config: offload config, defaults to None
:type offload_config: Optional[dict], optional
"""
self.process_group = process_group
self.shard_idx = process_group.rank()
self.num_shards = process_group.size()
self.mixed_precision = mixed_precision
self.compute_dtype = compute_dtype
self.compute_device = compute_device
self.offload_config = offload_config
self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
self.params: List[Parameter] = []
for param in module.parameters():
if not hasattr(param, 'zero_is_sharded'):
self.params.append(param)
self._has_params = len(self.params) > 0
self._has_sharded_params = False
# Flag to indicate if the full params are gathered.
self.has_full_params: bool = False
self._shard_params()
# Possibly unnecessary; kept to guard against regressions
# self.delete_fp32_shards()
self._streams: Dict[str, torch.cuda.Stream] = {}
def _shard_params(self) -> None:
for p in self.params:
assert not hasattr(p, "zero_is_sharded")
assert p.is_floating_point()
if self.mixed_precision:
assert p.dtype == torch.float32
# If world_size is 1, then we all-reduce grads instead of sharding.
p.zero_is_sharded = self.num_shards > 1
p.zero_orig_size = p.data.size()
if not p.zero_is_sharded:
p.zero_shard_padding = 0
continue
# Replace p.data with the relevant shard.
orig_data = p.data
p.data, p.zero_shard_padding = get_shard(p.data, self.shard_idx, self.num_shards)
free_storage(orig_data)
@torch.no_grad()
def reset_param_attr(self, p: Parameter, training: bool) -> None:
"""This should be called by ``ZeroRedundancyLevel3Model._lazy_init()``
"""
assert hasattr(p, 'zero_is_sharded') and hasattr(p, 'zero_orig_size')
if hasattr(p, 'zero_fp32_shard'):
return
# A single shard of the parameters in full precision.
p.zero_fp32_shard = p.data
if self.mixed_precision:
assert p.zero_fp32_shard.dtype == torch.float32
if self._cpu_offload:
assert p.zero_fp32_shard.device == torch.device('cpu')
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
p.zero_fp32_shard = p.zero_fp32_shard.pin_memory()
p.data = p.zero_fp32_shard
if self.mixed_precision or self._cpu_offload:
# In mixed precision mode, we maintain a reduced precision
# (typically FP16) parameter shard on compute_device for performing
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed. If offloading params to CPU, the
# dtype of the fp16 shard will depend on the *`compute_dtype`*.
p.zero_fp16_shard = torch.zeros_like(
p.zero_fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
free_storage(p.zero_fp16_shard)
if self.mixed_precision:
assert p.zero_fp32_shard.dtype == torch.float32
if not self.mixed_precision and not self._cpu_offload:
# use _fp32_shard if you are not using mixed precision or
# offloading params and grads to CPU.
p.zero_fp16_shard = None
# We also maintain a full-sized parameter of type self.compute_dtype
# (FP16 for mixed_precision or FP32 otherwise). We resize the
# storage to size 0 at init (here) and only materialize as needed. The
# storage may contain padding elements so that it is evenly divisible by
# world_size, although these padding elements will be removed before the
# relevant computation.
if p.zero_is_sharded:
p.zero_full_param_padded = torch.zeros(
p.data.numel() * self.num_shards, device=self.compute_device, dtype=self.compute_dtype
)
free_storage(p.zero_full_param_padded)
if self._cpu_offload and training:
p.zero_cpu_grad = torch.zeros_like(p.data, device='cpu').pin_memory()
def setup_streams(self, streams):
self._streams = streams
@torch.no_grad()
def rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
"""
Gather all shards of params.
Note, this is idempotent if full params are already gathered. Callers
assume the idempotency. So please keep it that way.
Args:
force_full_precision (bool, Optional): by default params will be gathered
in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
``True``, in which case they will be gathered in full precision
(e.g., FP32), possibly in fresh storage. The parameter that's being
rebuilt will end up in full precision as well.
Returns:
A list of tuples, where the first element is the full-sized param
and the second element is a bool indicating if it's safe for the
caller to free the full-sized param. This will be ``None`` if
``force_full_precision=False`` and the full params are already gathered.
"""
# Store tensor and free flag
output_tensors: List[Tuple[torch.Tensor, bool]] = []
def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
"""
Helper function to update p.data pointer.
Args:
custom_output_tensor (torch.Tensor, Optional): if not None, this
tensor contains the data we just gathered.
"""
if custom_output_tensor is not None:
assert p.zero_is_sharded
p.data = custom_output_tensor
output_tensors.append((p.data, True))
elif not p.zero_is_sharded:
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
assert p.zero_fp16_shard is not None
p.data = p.zero_fp16_shard
output_tensors.append((p.data, True))
else:
# Here p.data == p._fp32_shard, so it's not safe to free.
output_tensors.append((p.data, False))
else:
p.data = p.zero_full_param_padded
output_tensors.append((p.data, True))
# Trim any padding and reshape to match original size.
p.data = p.data[: p.zero_orig_size.numel()].view(p.zero_orig_size)
if self._has_sharded_params:
# self.has_full_params flag can be out of sync if a shared param is
# sharded by another ZeroRedundancyLevel3Model instance. An example is that in eval case
# with reshard_after_forward=False but the sharing instance has
# reshard_after_forward=True. Then, on the second forward, the
# other instance can shard the shared param, but this instance
# can mistakenly think the full param is already gathered from the
# has_full_params flag.
#
# Therefore, we update the flag accordingly here.
self.has_full_params = not any(p.zero_full_param_padded.storage().size() == 0 for p in self.params)
# Early exit if we already have full params and don't need full precision.
if self.has_full_params and not force_full_precision:
for p in self.params:
update_p_data()
return output_tensors
self.has_full_params = True
with torch.cuda.stream(self._streams["all_gather"]):
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
self.use_fp16_shards()
if self._cpu_offload and force_full_precision:
# If the compute_dtype and storage dtype are the same,
# use pinned memory. Otherwise move p.data to the compute
# device.
if self.params[0].dtype == self.compute_dtype:
self.use_fp16_shards()
else:
for p in self.params:
p.data = p.data.to(self.compute_device)
for p in self.params:
if not p.zero_is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# Skip if already built. Only shared param can be rebuilt multiple times.
# A corner case is p.zero_orig_size = (1,), which means the shape equality is
# not a perfect check. But we assume we don't share a param with shape (1,).
# if p.data.shape == p.zero_orig_size and hasattr(p, "zero_is_shared") and p.zero_is_shared:
# continue
# If self._cpu_offload and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p.zero_full_param_padded.device, non_blocking=True)
p_size = p.zero_full_param_padded.size()
assert p_size.numel() % self.num_shards == 0
if self.mixed_precision and force_full_precision:
# Allocate fresh tensor in full precision since we are in
# mixed precision and full precision rebuild is asked.
output_tensor = p_data.new_zeros(p_size)
else:
if p.zero_full_param_padded.storage().size() != p_size.numel():
# Allocate based on full size from all shards.
alloc_storage(p.zero_full_param_padded, size=p_size)
output_tensor = p.zero_full_param_padded
# Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p_data, group=self.process_group)
else:
chunks = list(output_tensor.chunk(self.num_shards))
dist.all_gather(chunks, p_data, group=self.process_group)
# Set p.data = output_tensor (with padding trimmed)
update_p_data(output_tensor)
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
self.free_fp16_shards([p])
if self._cpu_offload and (self.params[0].dtype == self.compute_dtype):
self.free_fp16_shards([p])
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors
@torch.no_grad()
def use_full_params(self) -> None:
"""
Switch p.data pointers to use the full params.
Note: this assumes full params are already gathered.
Note: this might be called after full_params is already in use. So please
make sure it is idempotent in that case.
"""
assert self.has_full_params
for p in self.params:
if not p.zero_is_sharded:
if self.mixed_precision or self._cpu_offload:
assert p.zero_fp16_shard is not None
assert p.zero_fp16_shard.storage().size() != 0
p.data = p.zero_fp16_shard
else:
assert p.zero_full_param_padded.storage().size() != 0, f"{p.zero_orig_size} {id(self)}"
p.data = p.zero_full_param_padded[: p.zero_orig_size.numel()].view(p.zero_orig_size)
@torch.no_grad()
def use_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
"""Cast FP32 param shard to FP16 for a list of params."""
if params is None:
params = self.params
with torch.cuda.stream(self._streams["fp32_to_fp16"]):
for p in params:
assert p.zero_fp16_shard is not None
alloc_storage(p.zero_fp16_shard, size=p.zero_fp32_shard.size())
p.zero_fp16_shard.copy_(
# If _cpu_offload is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p.zero_fp32_shard.to(p.zero_fp16_shard.device, non_blocking=True)
)
p.data = p.zero_fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
@torch.no_grad()
def use_fp32_shards(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
if params is None:
params = self.params
for p in params:
p.data = p.zero_fp32_shard
@torch.no_grad()
def free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
"""Free up storage for full parameters."""
if params is None:
params = self.params
self.has_full_params = False
current_stream = torch.cuda.current_stream()
for p in params:
if not p.zero_is_sharded: # e.g., world_size == 1
if self.mixed_precision or self._cpu_offload:
self.free_fp16_shards([p])
continue
# Don't let PyTorch reuse this memory until all work in the current
# stream is complete.
p.zero_full_param_padded.record_stream(current_stream)
# There may be external references to the Tensor Storage that we
# can't modify, such as references that are created by
# ctx.save_for_backward in the forward pass. Thus when we
# unshard parameters, we should reuse the original Tensor
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
free_storage(p.zero_full_param_padded)
@torch.no_grad()
def free_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
"""Free storage for FP16 shards for a list of params."""
if params is None:
params = self.params
current_stream = torch.cuda.current_stream()
for p in params:
if p.zero_fp16_shard is not None:
# zero_fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
# free it until the work in the current stream completes.
p.zero_fp16_shard.record_stream(current_stream)
free_storage(p.zero_fp16_shard)
def delete_fp32_shards(self) -> None:
for p in self.params:
if hasattr(p, 'zero_fp32_shard'):
del p.zero_fp32_shard # reset _init_param_attr
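A process-group-free sketch of the shard / rebuild cycle this class drove through zero_fp32_shard, zero_full_param_padded and all-gather (a plain concatenation stands in for the collective; tensor sizes are illustrative):
import torch
from colossalai.zero.shard_utils.commons import get_shard

world_size = 4
full = torch.randn(10)
# Each "rank" keeps only its padded shard.
shards = [get_shard(full, rank, world_size)[0] for rank in range(world_size)]
# "All-gather" the equally sized shards into the padded full parameter...
padded_full = torch.cat(shards)
# ...then trim the right padding and restore the original shape, as rebuild_full_params does.
rebuilt = padded_full[:full.numel()].view(full.size())
assert torch.equal(rebuilt, full)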

View File

@@ -1,85 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
class ShardedGradient:
def __init__(self,
param: Parameter,
sharded_module: nn.Module,
offload_config: Optional[dict] = None
) -> None:
assert hasattr(
param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with sharded parameter'
self.param = param
self.sharded_module = sharded_module
self.offload_config = offload_config
self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
# _gpu_grad is either sharded or not
# all saved grads are fp32
self._gpu_grad: Optional[torch.Tensor] = None
self._cpu_grad: Optional[torch.Tensor] = None
if self._cpu_offload:
# this buffer will be held and reused every iteration
self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()
@torch.no_grad()
def setup(self) -> None:
"""This function will be called pre-backward. Save the local accumulated gradient to _gpu_grad.
When no_sync() is enabled (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad
:raises AssertionError: Raise if grad shape is wrong
"""
if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
if self.param.grad.device != self.param.data.device:
# TODO: offload?
raise RuntimeError(
f'grad and param are on different devices, grad {self.param.grad.device} vs. param {self.param.data.device}')
else:
self._gpu_grad = self.param.grad.data
self.param.grad = None
def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
"""This function will be called in post-backward hook, so we cannot modify param.grad directly
:param reduced_grad: the reduced grad
:type reduced_grad: torch.Tensor
"""
# Make sure we store fp32 grad
if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
reduced_grad.data = reduced_grad.data.to(torch.float)
if self._gpu_grad is None:
self._gpu_grad = reduced_grad.data
else:
self._gpu_grad += reduced_grad.data
# Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
# backwards pass completes, we will set `.grad` to the CPU copy.
if self._cpu_offload:
self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
# Don't let this memory get reused until after the transfer.
self._gpu_grad.data.record_stream(torch.cuda.current_stream())
@torch.no_grad()
def write_back(self) -> None:
"""This function will be called in final backward hook
"""
if self._cpu_grad is not None:
assert self.param.device == torch.device(
'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
self.param.grad.data = self._cpu_grad
elif self._gpu_grad is not None:
assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
self.param.grad.data = self._gpu_grad
else:
raise RuntimeError('No grad to write back')
# If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
# They should be released here
self._gpu_grad = None
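A minimal sketch of the CPU-offload gradient flow this class implemented: the reduced GPU gradient is copied asynchronously into a pinned CPU buffer that is reused every iteration, and the optimizer later reads the CPU copy (guarded so it only runs where CUDA is available; buffer sizes are illustrative):
import torch

if torch.cuda.is_available():
    gpu_grad = torch.randn(1024, device='cuda')   # stands in for _gpu_grad after reduction
    cpu_grad = torch.zeros(1024).pin_memory()     # reused pinned buffer, like _cpu_grad
    cpu_grad.copy_(gpu_grad, non_blocking=True)   # async because the destination is pinned
    # Don't let the CUDA memory be reused until the copy issued on this stream finishes.
    gpu_grad.record_stream(torch.cuda.current_stream())
    torch.cuda.current_stream().synchronize()
    # write_back(): param.grad would now point at cpu_grad for a CPU optimizer step.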

File diff suppressed because it is too large

View File

@@ -18,7 +18,7 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)

View File

@@ -1,4 +1,3 @@
from .sharded_optim import ShardedOptimizer
from .sharded_optim_v2 import ShardedOptimizerV2
__all__ = ['ShardedOptimizer', 'ShardedOptimizerV2']
__all__ = ['ShardedOptimizerV2']

View File

@@ -1,6 +0,0 @@
from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .bucket_store import BucketStore
from .tensor_bucket import TensorBucket
__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']

View File

@@ -1,17 +0,0 @@
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
class BaseStore:
def __init__(self, dp_parallel_mode=ParallelMode.DATA):
self._world_size = gpc.get_world_size(dp_parallel_mode)
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
@property
def world_size(self):
return self._world_size
@property
def local_rank(self):
return self._local_rank

View File

@@ -1,43 +0,0 @@
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .base_store import BaseStore
class BucketStore(BaseStore):
def __init__(self, dp_parallel_mode):
super().__init__(dp_parallel_mode)
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()
self.reset()
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
def add_grad(self, tensor, reduce_rank: int = None):
self._grads[reduce_rank].append(tensor)
def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)
def reset(self):
keys = [None] + list(range(self._world_size))
self._grads = {rank: [] for rank in keys}
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}
def reset_by_rank(self, reduce_rank=None):
self._grads[reduce_rank] = []
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0
def get_grad(self, reduce_rank: int = None):
return self._grads[reduce_rank]
def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]

View File

@@ -1,66 +0,0 @@
from typing import List
from torch import Tensor
from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args):
super().__init__(*args)
# bookkeeping data structures
self._averaged_gradients = dict()
# for backward reduction hooks
self._grad_acc_objs = []
def add_accumulate_grad_object(self, obj):
"""
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
be attached successfully.
:param obj: An object of :class:`AccumulateGrad` class
:type obj: :class:`AccumulateGrad`
"""
self._grad_acc_objs.append(obj)
def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
"""
Return average gradients of a parameter group
:param group_id: The index of parameter group
:type group_id: int
:return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
:rtype: List[torch.Tensor]
"""
return self._averaged_gradients[group_id]
def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
"""
Append an average gradient to the list of averaged gradients of a parameter group
:param group_id: The index of a parameter group
:param tensor: A :class:`torch.Tensor` object
:type group_id: int
:type tensor: torch.Tensor
"""
if group_id in self._averaged_gradients:
self._averaged_gradients[group_id].append(tensor)
else:
self._averaged_gradients[group_id] = [tensor]
def reset_average_gradients_by_group(self, group_id: int) -> None:
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
:param group_id: The index of a parameter group
:type group_id: int
"""
self._averaged_gradients[group_id] = []

View File

@@ -1,96 +0,0 @@
from .base_store import BaseStore
from torch import Tensor
from typing import List
class ParameterStore(BaseStore):
def __init__(self, dp_parallel_mode):
super().__init__(dp_parallel_mode)
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()
self._rank_group_id_to_flat_fp16_param = dict()
# param reduction data structures
self._is_param_reduced = dict()
self._reduced_param = []
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
"""
Set the mapping between parameter to rank, each parameter should be owned by a rank.
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:param rank: The rank of the process responsible for updating the parameter
:type rank: int
"""
self._fp16_param_to_rank[tensor] = rank
def get_param_rank(self, tensor: Tensor) -> int:
"""
Return the rank to which the parameter belongs
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
"""
return self._fp16_param_to_rank[tensor]
def belongs_to_current_rank(self, tensor) -> bool:
"""
Check whether a parameter is supposed to be updated by the process of the current rank
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:return: True if the parameter should be updated by the current rank. Otherwise false.
:rtype: bool
"""
tensor_rank = self._fp16_param_to_rank[tensor]
return tensor_rank == self._local_rank
def add_fp16_param_list_by_rank_group(self, rank, group_id,
tensor_list) -> None:
if rank not in self._rank_groupid_to_fp16_param_list:
self._rank_groupid_to_fp16_param_list[rank] = dict()
if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
self._rank_groupid_to_fp16_param_list[rank][group_id] = []
self._rank_groupid_to_fp16_param_list[rank][group_id].extend(
tensor_list)
def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
return self._rank_groupid_to_fp16_param_list[rank][group_id]
def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
if rank not in self._rank_group_id_to_flat_fp16_param:
self._rank_group_id_to_flat_fp16_param[rank] = dict()
self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
return self._rank_group_id_to_flat_fp16_param[rank][group_id]
def is_param_reduced(self, tensor):
return self._is_param_reduced[tensor]
def set_param_reduction_state(self, tensor, state):
self._is_param_reduced[tensor] = state
def get_param_reduction_states(self):
return self._is_param_reduced
def reset_previous_reduced_params(self):
self._reduced_param = []
def add_previous_reduced_param(self, tensor):
self._reduced_param.append(tensor)
def clear_grads_of_previous_reduced_params(self):
if len(self._reduced_param) > 0:
for param in self._reduced_param:
param.grad = None
self.reset_previous_reduced_params()

View File

@@ -1,54 +0,0 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class TensorBucket:
def __init__(self, size):
self._max_size = size
self._current_size = 0
self._bucket = []
@property
def max_size(self):
return self._max_size
@property
def current_size(self):
return self._current_size
def is_full_or_oversized(self):
return self._current_size >= self._max_size
def is_empty(self):
return len(self._bucket) == 0
def add_to_bucket(self, tensor, allow_oversize=False):
tensor_size = tensor.numel()
if not allow_oversize and self.will_exceed_max_size(tensor_size):
msg = f"The param bucket max size {self._max_size} is exceeded" \
+ f"by tensor (size {tensor_size})"
raise RuntimeError(msg)
self._bucket.append(tensor)
self._current_size += tensor_size
def will_exceed_max_size(self, tensor_size):
expected_size = self._current_size + tensor_size
return expected_size > self._max_size
def get_bucket(self):
return self._bucket
def empty(self):
self._bucket = []
self._current_size = 0
def flatten(self):
return _flatten_dense_tensors(self._bucket)
def unflatten_and_copy(self, flat_tensor):
unflattened_tensor_list = _unflatten_dense_tensors(
flat_tensor, self._bucket)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
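A short round-trip of the flatten / unflatten path that backs bucketed reduction (uses the torch._utils helpers imported above; the gradient tensors and the in-place scaling that stands in for a reduce are illustrative):
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2, 2), 2.0)]
flat = _flatten_dense_tensors(grads)     # what TensorBucket.flatten() returns
flat.mul_(0.5)                           # stands in for the reduced / averaged result
for old, new in zip(grads, _unflatten_dense_tensors(flat, grads)):
    old.copy_(new)                       # unflatten_and_copy() writes the result back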

View File

@@ -1,563 +0,0 @@
from colossalai.utils.cuda import get_current_device
import torch
import torch.distributed as dist
from colossalai.logging import get_dist_logger
from torch.optim import Optimizer
from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
from functools import partial
class ShardedOptimizer(ColossalaiOptimizer):
def __init__(self,
optimizer: Optimizer,
initial_scale=2**32,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=1000,
hysteresis=2,
max_scale: int = 2**32,
clip_grad_norm=2.0,
verbose=False,
reduce_bucket_size=500000000,
communication_dtype=torch.float16,
overlap_communication=False,
partition_grad=False,
dp_parallel_mode=ParallelMode.DATA,
mp_parallel_mode=ParallelMode.MODEL,
cpu_offload=False,
cpu_fp16_param=False,
cpu_fp16_grad=False):
# TODO: add support for
# 1. fp16 master weights
# 2. contiguous gradients
# 3. cpu offload
# 4. support when some parameters requires_grad = False
self._optimizer = optimizer
self._dtype = self._optimizer.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
self._verbose = verbose
# stage 2
self._partition_grads = partition_grad
# cpu_offload
self._cpu_offload = cpu_offload
self._cpu_fp16_param = cpu_fp16_param
self._cpu_fp16_grad = cpu_fp16_grad
# get process groups
self._dp_parallel_mode = dp_parallel_mode
self._mp_parallel_mode = mp_parallel_mode
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
self._world_size = gpc.get_world_size(dp_parallel_mode)
self._dp_group = gpc.get_group(dp_parallel_mode)
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
self._mp_group = gpc.get_group(mp_parallel_mode)
else:
self._mp_group = None
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
self._fp32_flat_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
# gradient scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
verbose=verbose)
self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
# gradient clipping
self._clip_grad_norm = clip_grad_norm
# check argument conflict
self._sanity_checks()
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(self._dp_parallel_mode)
self._grad_store = GradientStore(self._dp_parallel_mode)
self._bucket_store = BucketStore(self._dp_parallel_mode)
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self._optimizer.param_groups):
params = param_group['params']
# add the fp16 params to fp16_param_groups for bookkeeping
self._fp16_param_groups[group_id] = params
# assign parameters to ranks
# the params in the list are sorted
params_per_rank = self._partition_param_list(params)
# store the mapping between param to rank
# each param should belong to only one rank
for rank, params in enumerate(params_per_rank):
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
for param in params:
self._param_store.set_param_to_rank(param, rank)
# move to cpu to make room to create the flat tensor
move_tensor(params, device='cpu')
# flatten the reordered tensors
for rank in range(self._world_size):
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
flat_tensor = flatten(tensor_list)
flat_tensor = flat_tensor.cuda()
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
# sync parameters
for rank in range(self._world_size):
flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id)
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
# create a copy of fp32 weights of the parameters for which this rank is responsible
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
# when using cpu offload, our cpu adam supports fp16 parameters
if self._cpu_fp16_param:
fp32_flat_current_rank = fp16_flat_current_rank.detach()
else:
fp32_flat_current_rank = fp16_flat_current_rank.detach().float()
device = 'cpu' if self._cpu_offload else get_current_device()
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
fp32_flat_current_rank.requires_grad = True
self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group['params'] = [fp32_flat_current_rank]
# set reduction state
for param in self._fp16_param_groups[group_id]:
self._param_store.set_param_reduction_state(param, False)
# initialize communication stream for
# communication-computation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
# reduction hook is only used if overlapping communication
# or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_communication or self._partition_grads:
self._attach_reduction_hook()
self._initialize_optimizer_states()
@property
def loss_scale(self):
return self.grad_scaler.scale
@property
def num_param_groups(self):
return len(self._fp16_param_groups)
def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
# partition the parameters in a greedy fashion
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
for param in sorted_params:
# allocate this parameter to the rank with
# the smallest numel for load balancing purpose
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
numel_per_rank[rank_to_go] += param.numel()
if self._verbose:
self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
ranks=[0],
parallel_mode=self._dp_parallel_mode)
return params_per_rank
def _initialize_optimizer_states(self):
# create a dummy zero tensor which has the same shape as that of the param
# set this dummy zero tensor as grad
for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp32_partition_grad = torch.zeros_like(fp32_partition_param)
fp32_partition_param.grad = fp32_partition_grad
# update the parameter with zero gradients for initialization of optimizer states
self._optimizer.step()
# remove the grad of the parameter to save memory
for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
fp32_flat_tensor.grad = None
def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
assert self._dtype == torch.float16, \
f'Parameters are expected to be of type torch.float16, but got {self._dtype}'
###########################################################
# Backward Reduction Hook
###########################################################
def _attach_reduction_hook(self):
# we iterate over the fp16 params
# on each param, we register a hook to its AccumulateGrad object
for group_id in range(self.num_param_groups):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.requires_grad:
# determines the reduction destination rank
# this is only valid for stage 2
# dst_rank = None means using all-reduce
# else using reduce
if self._partition_grads:
reduce_rank = self._param_store.get_param_rank(param)
else:
reduce_rank = None
def _define_and_attach(param, reduce_rank):
# get the AccumulateGrad object of the param itself
accum_grad_obj = get_grad_accumulate_object(param)
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
param=param,
reduce_rank=reduce_rank)
# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# the args here are not grads, but allow_unreachable and accumulate_grad
def reduce_grad_hook(*args):
reduction_func()
accum_grad_obj.register_hook(reduce_grad_hook)
_define_and_attach(param, reduce_rank)
def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
param_size = param.numel()
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_in_bucket(reduce_rank)
# the param must not have been reduced already, to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_grad(param.grad, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
def _reduce_grads_in_bucket(self, reduce_rank=None):
# reduce grads
self._reduce_grads_by_rank(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
# use communication stream if overlapping
# communication with computation
if self._overlap_communication:
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
# this param is not reduced before calling self._reduce_grads_by_rank
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
# update the flag
self._param_store.set_param_reduction_state(param, True)
# if partition grads = True
# we do not keep the gradient after reduction
if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
if self._overlap_communication:
# we need to keep this gradient for now as the reduction may
# not be completed yet, since it is using a different cuda stream
self._param_store.add_previous_reduced_param(param)
else:
param.grad = None
self._bucket_store.reset_by_rank(reduce_rank)
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)
for tensor_list in grad_buckets_by_dtype:
self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
##############################
# Reduction Utility Function #
##############################
def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
if not param_bucket.is_empty():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(tensor=flat,
dtype=self._communication_dtype,
dst_rank=reduce_rank,
parallel_mode=self._dp_parallel_mode)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
################################
# torch.optim.Optimizer methods
################################
def backward(self, loss, retain_graph=True):
loss = self.loss_scale * loss
loss.backward(retain_graph=retain_graph)
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradients
will be set to None to save memory.
:param set_to_none: Whether to set the gradients to None. Default value is True.
:type set_to_none: bool
"""
for group_id, param_group in self._fp16_param_groups.items():
for param in param_group:
if set_to_none:
param.grad = None
else:
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()
####################
# Update Parameter #
####################
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
# check for overflow
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
norm_groups = []
for group_id in range(self.num_param_groups):
# compute norm
norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
dp_group=self._dp_group,
mp_group=self._mp_group)
norm_groups.append(norm_group)
# create flat gradient for the flat fp32 params
fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
flat_fp16_avg_grads = flatten(fp16_avg_grads)
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
assert param_shape == flat_fp32_avg_grads.shape, \
f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}'
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
self._grad_store._averaged_gradients[group_id] = []
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
# update the parameters
self._optimizer.step()
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device)
fp16_param.data.copy_(fp32_param)
# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
for rank in range(self._world_size):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
##################
# FP16 Utilities #
##################
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(0.0)
# check for overflow
for group_id in range(len(self._fp16_param_groups)):
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
break
# all-reduce across dp group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
# all-reduce over model parallel group
if self._mp_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
if self._found_overflow.item() > 0:
return True
else:
return False
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self._clip_grad_norm > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
grad.data.mul_(1. / combined_scale)
############################
# Gradient Synchronization #
############################
def sync_grad(self):
if not self._partition_grads:
self._reduce_grad_stage1()
else:
# TODO: support async comm in reduce
self._reduce_grad_stage2()
# update param already reduced flag
reduction_states = self._param_store.get_param_reduction_states()
for tensor, state in reduction_states.items():
reduction_states[tensor] = False
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# accumulate gradient
avg_gradients = self._grad_store._averaged_gradients
for group_id in range(self.num_param_groups):
param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
if group_id not in avg_gradients:
avg_gradients[group_id] = []
param_idx = 0
for param in param_group:
if param.grad is not None:
if len(avg_gradients[group_id]) == param_idx:
avg_gradients[group_id].append(param.grad)
else:
avg_gradients[group_id][param_idx].add_(param.grad)
param_idx += 1
# the gradients needed are stored in the avg_gradients buffer
# thus, can clear this
self.zero_grad()
def _reduce_grad_stage1(self):
# if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients
if not self._overlap_communication:
for group_id in range(len(self._fp16_param_groups)):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._reduce_and_remove_grads_by_bucket(param)
# we need to reduce the gradients
# left in the communication bucket
self._reduce_grads_in_bucket()
def _reduce_grad_stage2(self):
# when partition_grads is True, reduction hooks
# are attached in the __init__ function, so we
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
self._reduce_grads_in_bucket(reduce_rank)
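A self-contained sketch of the greedy, numel-balanced partitioning performed by _partition_param_list above (plain integers stand in for parameter tensors; the example sizes are illustrative):
def greedy_partition(numels, world_size):
    buckets = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Largest tensors first; each goes to whichever rank currently holds the fewest elements.
    for n in sorted(numels, reverse=True):
        rank = loads.index(min(loads))
        buckets[rank].append(n)
        loads[rank] += n
    return buckets, loads

print(greedy_partition([100, 80, 60, 10, 10], world_size=2))
# ([[100, 10, 10], [80, 60]], [120, 140])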

View File

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

View File

@@ -7,9 +7,7 @@ import colossalai
import torch
from functools import partial
import torch.multiprocessing as mp
import pytest
def run_tensor_move(rank):

View File

@@ -2,11 +2,9 @@
# -*- encoding: utf-8 -*-
import copy
import operator as op
from functools import partial, reduce
from typing import List
import colossalai
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
import pytest
import torch
import torch.distributed as dist
@@ -14,10 +12,11 @@ import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from colossalai.testing import parameterize
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from functools import partial
def checkpoint_wrapper(module, enable=True):
@@ -97,41 +96,9 @@ def check_params(model, zero_model, loose=False):
assert allclose(p, zero_p, loose=loose)
@parameterize('checkpoint', [False, True])
@parameterize('fp16', [False, True])
@parameterize('offload', [False, True])
@parameterize('norm_type', [1.0, 2.0, float('inf')])
def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
model = Net(checkpoint=checkpoint).cuda()
zero_model = copy.deepcopy(model)
ddp_model = DDP(model)
offload_config = {}
if offload:
offload_config['device'] = 'cpu'
zero_model = zero_model.cpu()
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type)
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type)
check_grads(ddp_model, zero_model)
check_params(ddp_model, zero_model)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type)
run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type)
check_grads(ddp_model, zero_model, loose=True)
check_params(ddp_model, zero_model, loose=True)
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_config()
@pytest.mark.dist

View File

@@ -12,7 +12,7 @@ from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

View File

@@ -1,168 +0,0 @@
import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
from colossalai.zero import ShardedOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.utils import free_port
from functools import partial
from common import allclose
from tests.components_to_test.registry import non_distributed_component_funcs
def check_completely_equal(a, b):
"""
This function checks if two tensors are completely equal
"""
assert torch.all(a == b), f'a = {a}, b = {b}'
def check_sharded_param_consistency():
"""
In this test, we want to test whether zero stage 1 and 2
deliver the same numerical results despite different communication
patterns
we use these prefixes to differentiate the zero stage
oss: partition optimizer states
pg: partition gradients and optimizer states
"""
    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']

    for name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(name)
        model_builder, train_dataloader, *_ = get_components_func()

        # create models
        oss_model = model_builder(checkpoint=True).cuda().half()
        pg_model = copy.deepcopy(oss_model)

        # create optimizers
        oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
        pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
        oss_optimizer = ShardedOptimizer(oss_optimizer, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
        pg_optimizer = ShardedOptimizer(pg_optimizer,
                                        overlap_communication=True,
                                        partition_grad=True,
                                        initial_scale=1,
                                        clip_grad_norm=0.0)

        # create input data
        data, label = next(iter(train_dataloader))
        input_data = data.cuda().half()

        # forward
        oss_output = oss_model(input_data)
        pg_output = pg_model(input_data)
        check_completely_equal(oss_output, pg_output)

        # backward
        oss_optimizer.backward(oss_output.mean().float())
        pg_optimizer.backward(pg_output.mean().float())

        # check grads
        # as these parameters are small, the backward reduction will not be fired,
        # so the local (unreduced) gradients should still match exactly
        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
            check_completely_equal(oss_param.grad, pg_param.grad)

        # sync grads across ranks
        oss_optimizer.sync_grad()
        pg_optimizer.sync_grad()

        # step
        oss_optimizer.step()
        pg_optimizer.step()

        # check updated params
        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
            check_completely_equal(oss_param, pg_param)
def check_sharded_optim_against_torch_ddp():
    """
    In this test, two pairs of models and optimizers are created:
    1. zero: use the sharded optimizer and fp16 parameters
    2. torch: use torch DDP and fp32 parameters

    We feed both sets of models the same input and check whether the
    differences in model output and updated parameters are within tolerance.
    """
    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']

    for name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(name)
        model_builder, train_dataloader, *_ = get_components_func()

        # create models
        zero_model = model_builder(checkpoint=True).cuda()
        torch_model = copy.deepcopy(zero_model)
        zero_model = zero_model.half()
        torch_model = DDP(torch_model.cuda())

        # create optimizers
        zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)

        # we only test stage 1 here;
        # `check_sharded_param_consistency` above checks that
        # stage 1 and stage 2 produce exactly the same results
        zero_optimizer = ShardedOptimizer(zero_optimizer,
                                          overlap_communication=True,
                                          initial_scale=1,
                                          clip_grad_norm=0.0)
        torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

        # create input data
        input_data, _ = next(iter(train_dataloader))
        input_data = input_data.cuda()

        # zero-dp forward
        zero_output = zero_model(input_data.half())

        # torch-ddp forward
        torch_output = torch_model(input_data)
        allclose(zero_output, torch_output.half())

        # zero-dp backward
        zero_optimizer.backward(zero_output.mean().float())

        # torch-ddp backward
        torch_output.mean().backward()

        # check grads
        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
            allclose(oss_param.grad, torch_param.grad.half())

        # zero-dp step
        zero_optimizer.sync_grad()
        zero_optimizer.step()

        # torch-ddp step
        torch_optimizer.step()

        # check updated params
        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
            allclose(oss_param, torch_param.half())
def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_sharded_optim_against_torch_ddp()
    check_sharded_param_consistency()
@pytest.mark.dist
def test_sharded_optim():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim()

View File

@@ -1,39 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import free_port
from common import CONFIG
def run_shard_shape_check(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    model = torch.nn.Linear(2, 4 * world_size)
    gpc.init_parallel_groups()
    Zero3ParameterManager(module=model,
                          process_group=gpc.get_group(ParallelMode.DATA),
                          offload_config=CONFIG.get('offload_param_config'))
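    # Editorial note: the expected shapes below appear to come from sharding each
    # parameter across the DATA group, so every rank keeps numel / world_size
    # elements: weight 2 * (4 * world_size) / world_size = 8 and
    # bias (4 * world_size) / world_size = 4.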
    assert (model.weight.numel() == 4 * 2)
    assert (model.bias.numel() == 4)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_run_shard_shape(world_size):
    run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_run_shard_shape(2)

View File

@@ -1,19 +0,0 @@
import sys
from pathlib import Path
repo_path = Path(__file__).absolute().parents[2]
sys.path.append(str(repo_path))
try:
    import model_zoo.vit.vision_transformer_from_config
except ImportError:
    raise ImportError("model_zoo is not found, please check your path")
BATCH_SIZE = 8
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
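# Editorial note: these hyperparameters are consumed by the ViT ZeRO tests below
# via `from components import *`; IMG_SIZE and BATCH_SIZE appear to be the ones
# referenced directly in the dataloader setup.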

View File

@@ -1,99 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.autograd
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.utils import free_port, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
from components import *
CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=4, mode='2d'),
    ),
    fp16=dict(mode=None),
    zero=dict(level=2),
)
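# Editorial note: `zero=dict(level=2)` uses the old config-driven ZeRO entry point
# that this commit removes, which is presumably why the test below is skipped
# until it is refactored for the reconstructed zero.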
def train_epoch(engine, train_dataloader):
    engine.train()
    accumulated_loss = 0
    num_steps = len(train_dataloader)
    data_iter = iter(train_dataloader)
    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()
    avg_loss = accumulated_loss / num_steps
    return avg_loss
def run_2d_parallel_vision_transformer_level_2(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
    model = vit_lite_depth7_patch4_32()

    # build dataloader
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                            ]))
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)
    logger = get_dist_logger()
    logger.info('start training')
    engine.train()

    # run a single training step only
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        out = engine(img)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()
        break

    gpc.destroy()
    torch.cuda.empty_cache()
@pytest.mark.dist
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
def test_2d_vit_zero_level_2():
    world_size = 8
    run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_2d_vit_zero_level_2()

View File

@@ -1,99 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.autograd
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.utils import free_port, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
from components import *
CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=4, mode='2d'),
    ),
    fp16=dict(mode=None),
    zero=dict(level=3),
)
def train_epoch(engine, train_dataloader):
    engine.train()
    accumulated_loss = 0
    num_steps = len(train_dataloader)
    data_iter = iter(train_dataloader)
    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()
    avg_loss = accumulated_loss / num_steps
    return avg_loss
def run_2d_parallel_vision_transformer_level_3(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
    model = vit_lite_depth7_patch4_32()

    # build dataloader
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                            ]))
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)
    logger = get_dist_logger()
    logger.info('start training')
    engine.train()

    # run a single training step only
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        out = engine(img)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()
        break

    gpc.destroy()
    torch.cuda.empty_cache()
@pytest.mark.dist
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
def test_3d_vit_zero_level_3():
    world_size = 8
    run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_3d_vit_zero_level_3()