From 4d322b79da72690d1cf46f88bf9d59f986e62040 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Fri, 25 Mar 2022 14:54:39 +0800 Subject: [PATCH] [refactor] remove old zero code (#517) --- .../engine/schedule/_pipeline_schedule.py | 3 - colossalai/zero/__init__.py | 38 +- colossalai/zero/init_ctx/init_context.py | 2 +- colossalai/zero/shard_utils/commons.py | 20 + .../zero/shard_utils/tensor_shard_strategy.py | 2 +- colossalai/zero/sharded_model/__init__.py | 3 +- .../{_zero3_utils.py => _utils.py} | 48 +- .../zero/sharded_model/param_manager.py | 385 ------ colossalai/zero/sharded_model/sharded_grad.py | 85 -- .../zero/sharded_model/sharded_model.py | 1104 ----------------- .../zero/sharded_model/sharded_model_v2.py | 4 +- colossalai/zero/sharded_optim/__init__.py | 3 +- .../sharded_optim/bookkeeping/__init__.py | 6 - .../sharded_optim/bookkeeping/base_store.py | 17 - .../sharded_optim/bookkeeping/bucket_store.py | 43 - .../bookkeeping/gradient_store.py | 66 - .../bookkeeping/parameter_store.py | 96 -- .../bookkeeping/tensor_bucket.py | 54 - .../zero/sharded_optim/sharded_optim.py | 563 --------- .../zero/sharded_optim/sharded_optim_v2.py | 2 +- tests/test_utils/test_commons.py | 2 - .../test_zero_gradient_clippling.py | 39 +- .../test_shard_model_v2.py | 2 +- .../test_sharded_optim.py | 168 --- .../test_zero_param_mgr.py | 39 - tests/test_zero_tensor_parallel/components.py | 19 - .../test_vit_2d_level_2.py | 99 -- .../test_vit_2d_level_3.py | 99 -- 28 files changed, 33 insertions(+), 2978 deletions(-) create mode 100644 colossalai/zero/shard_utils/commons.py rename colossalai/zero/sharded_model/{_zero3_utils.py => _utils.py} (58%) delete mode 100644 colossalai/zero/sharded_model/param_manager.py delete mode 100644 colossalai/zero/sharded_model/sharded_grad.py delete mode 100644 colossalai/zero/sharded_model/sharded_model.py delete mode 100644 colossalai/zero/sharded_optim/bookkeeping/__init__.py delete mode 100644 colossalai/zero/sharded_optim/bookkeeping/base_store.py delete mode 100644 colossalai/zero/sharded_optim/bookkeeping/bucket_store.py delete mode 100644 colossalai/zero/sharded_optim/bookkeeping/gradient_store.py delete mode 100644 colossalai/zero/sharded_optim/bookkeeping/parameter_store.py delete mode 100644 colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py delete mode 100644 colossalai/zero/sharded_optim/sharded_optim.py delete mode 100644 tests/test_zero_data_parallel/test_sharded_optim.py delete mode 100644 tests/test_zero_data_parallel/test_zero_param_mgr.py delete mode 100644 tests/test_zero_tensor_parallel/components.py delete mode 100644 tests/test_zero_tensor_parallel/test_vit_2d_level_2.py delete mode 100644 tests/test_zero_tensor_parallel/test_vit_2d_level_3.py diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index a65ec3275..82d99d6ff 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.utils import switch_virtual_pipeline_parallel_rank from colossalai.utils.cuda import get_current_device -from colossalai.zero import ShardedModel, ShardedOptimizer from colossalai.zero.sharded_model import ShardedModelV2 from ._base_schedule import BaseSchedule @@ -92,8 +91,6 @@ class PipelineSchedule(BaseSchedule): def pre_processing(self, engine): # TODO: remove this after testing new zero with pipeline parallelism - if 
isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel): - raise TypeError("Pipeline schedule is currently not compatible with ZeRO") model = engine.model if isinstance(model, (NaiveAMPModel, ShardedModelV2)): self.dtype = torch.half diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index b94bb370c..714474ea5 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -2,14 +2,9 @@ from typing import Tuple import torch import torch.nn as nn -from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.logging import get_dist_logger from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2 -from torch.optim import Optimizer - -from .sharded_model import ShardedModel -from .sharded_optim import ShardedOptimizer def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, @@ -40,35 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model return zero_model, zero_optimizer -def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict): - """ - A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading - - :param model: Your model object - :type model: :class:`torch.nn.Module` - :param optimizer: Your optimizer object - :type optimizer: :class:`torch.optim.Optimizer` - :param level: Optimizer level, can be 2 or 3 - :type level: int - :param zero_config: Configuration for zero - :type zero_config: dict - - :return: (model, optimizer) - :rtype: Tuple - """ - assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided' - if level in [1, 2]: - if level == 2: - if 'partition_grad' in zero_config: - assert zero_config['partition_grad'], \ - 'Sharded Optimizer requires partition_grad to be True' - else: - zero_config['partiton_grad'] = True - model = NaiveAMPModel(model, output_to_fp32=True) - optimizer = ShardedOptimizer(optimizer, **zero_config) - else: - model = ShardedModel(module=model, **zero_config) - return model, optimizer - - -__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer'] +__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2'] diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index bd765f1a6..9ff4a81c5 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -8,7 +8,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \ GLOBAL_MODEL_DATA_TRACER from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16 +from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 from colossalai.zero.sharded_param import ShardedParamV2 from torch.distributed import ProcessGroup from colossalai.logging import get_dist_logger, disable_existing_loggers diff --git a/colossalai/zero/shard_utils/commons.py b/colossalai/zero/shard_utils/commons.py new file mode 100644 index 000000000..f24559644 --- /dev/null +++ b/colossalai/zero/shard_utils/commons.py @@ -0,0 +1,20 @@ +import torch +import torch.nn.functional as F +from typing import Tuple + + +def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: + """Return the local shard of a full tensor.""" + # Shard using torch.chunk to 
match all-gather/reduce-scatter. + chunks = list(torch.flatten(tensor).chunk(world_size)) + while len(chunks) < world_size: + chunks.append(chunks[0].new_empty(0)) + + # Determine number of padding elements. + num_to_pad = chunks[0].numel() - chunks[rank].numel() + assert num_to_pad >= 0, num_to_pad + + shard = chunks[rank].clone() + if num_to_pad > 0: + shard = F.pad(shard, [0, num_to_pad]) + return shard, num_to_pad diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py index 7f2d2684e..31210a190 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py @@ -4,7 +4,7 @@ import torch import torch.distributed as dist from colossalai.utils import get_current_device from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.sharded_model._zero3_utils import get_shard +from colossalai.zero.shard_utils.commons import get_shard from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor diff --git a/colossalai/zero/sharded_model/__init__.py b/colossalai/zero/sharded_model/__init__.py index dffd7f21a..725179295 100644 --- a/colossalai/zero/sharded_model/__init__.py +++ b/colossalai/zero/sharded_model/__init__.py @@ -1,4 +1,3 @@ -from .sharded_model import ShardedModel from .sharded_model_v2 import ShardedModelV2 -__all__ = ['ShardedModel', 'ShardedModelV2'] \ No newline at end of file +__all__ = ['ShardedModelV2'] \ No newline at end of file diff --git a/colossalai/zero/sharded_model/_zero3_utils.py b/colossalai/zero/sharded_model/_utils.py similarity index 58% rename from colossalai/zero/sharded_model/_zero3_utils.py rename to colossalai/zero/sharded_model/_utils.py index ab0b71b56..682a4ff1e 100644 --- a/colossalai/zero/sharded_model/_zero3_utils.py +++ b/colossalai/zero/sharded_model/_utils.py @@ -1,5 +1,4 @@ -from collections import OrderedDict -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, List, Tuple import torch import torch.nn.functional as F @@ -12,23 +11,6 @@ def get_gradient_predivide_factor(world_size: int) -> float: return float(factor) -def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: - """Return the local shard of a full tensor.""" - # Shard using torch.chunk to match all-gather/reduce-scatter. - chunks = list(torch.flatten(tensor).chunk(world_size)) - while len(chunks) < world_size: - chunks.append(chunks[0].new_empty(0)) - - # Determine number of padding elements. - num_to_pad = chunks[0].numel() - chunks[rank].numel() - assert num_to_pad >= 0, num_to_pad - - shard = chunks[rank].clone() - if num_to_pad > 0: - shard = F.pad(shard, [0, num_to_pad]) - return shard, num_to_pad - - def free_storage(data: torch.Tensor) -> None: """Free underlying storage of a Tensor.""" if data.storage().size() > 0: @@ -86,31 +68,3 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]: if len(chunks) < num_chunks: chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))]) return chunks - - -def assert_in_engine(cond: Any, s: Any) -> None: - """Used in backward context to make sure error is printed.""" - if not cond: - print(s) - raise AssertionError - - -def replace_state_dict_prefix(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], - old_prefix: str, new_prefix: str) -> None: - """ - Replace all keys that match a given old_prefix with a new_prefix (in-place). 
- - Usage:: - - state_dict = {"layer.xyz": torch.tensor(1)} - replace_state_dict_prefix(state_dict, "layer.", "module.layer.") - assert state_dict == {"module.layer.xyz": torch.tensor(1)} - """ - if old_prefix == new_prefix: - raise ValueError("old_prefix and new_prefix must be distinct") - for key in list(state_dict.keys()): - if not key.startswith(old_prefix): - continue - new_key = new_prefix + key[len(old_prefix):] - state_dict[new_key] = state_dict[key] - del state_dict[key] diff --git a/colossalai/zero/sharded_model/param_manager.py b/colossalai/zero/sharded_model/param_manager.py deleted file mode 100644 index 7670ab3a4..000000000 --- a/colossalai/zero/sharded_model/param_manager.py +++ /dev/null @@ -1,385 +0,0 @@ -import os -from typing import Dict, List, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from ._zero3_utils import alloc_storage, free_storage, get_shard - -# TODO: Remove the toggle-enable_nccl_base_collectives in the future -if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0": - enable_nccl_base_collectives = False -else: - enable_nccl_base_collectives = True - -# TODO: add flatten params - - -class Zero3ParameterManager: - def __init__(self, - module: nn.Module, - process_group: Optional[ProcessGroup], - mixed_precision: bool = False, - flatten_parameters: bool = True, - compute_dtype: Optional[torch.dtype] = None, - compute_device: Optional[torch.device] = None, - offload_config: Optional[dict] = None - ) -> None: - """Manage parameter shards. We manage several attributes on each Parameter instance: - ``zero_is_sharded``: ``True`` if the Parameter is sharded or ``False`` - if the Parameter is intentionally not sharded (in which case we - will all-reduce grads for this param). - ``zero_orig_size``: the size of the original Parameter (before sharding) - ``zero_shard_padding``: the padding size. All paddings are right padding. - ``zero_fp32_shard``: a single shard of the parameters in full precision - (typically FP32, but this is dependent on the dtype of the model - as it's passed in by the user). This can be on CPU or GPU - depending on the value of *``offload_config``*. - ``zero_fp16_shard``: This will be a single shard of the parameters in FP16, used for all-gather. - This can be in FP16 or FP32 depending on the value of *``compute_dtype``* and - if params are offloaded to CPU. - ``zero_full_param_padded``: the full weight (padded to be evenly - divisible by ``world_size``), used for computation in the - forward and backward pass. This will be resized in place and - only materialized (via all-gather) as needed. - ``zero_cpu_grad``: the gradient saved on CPU. It's set only when using CPU offload. 
- - :param module: original module - :type module: nn.Module - :param process_group: typically data parallel process group, defaults to None - :type process_group: Optional[ProcessGroup], optional - :param mixed_precision: whether to use mixed precision mode, defaults to False - :type mixed_precision: bool, optional - :param flatten_parameters: whether to flatten parameters, useless now, defaults to True - :type flatten_parameters: bool, optional - :param compute_dtype: the dtype of parameters when computing, defaults to None - :type compute_dtype: Optional[torch.dtype], optional - :param compute_device: the device of parameters when computing, defaults to None - :type compute_device: Optional[torch.device], optional - :param offload_config: offload config, defaults to None - :type offload_config: Optional[dict], optional - """ - self.process_group = process_group - self.shard_idx = process_group.rank() - self.num_shards = process_group.size() - self.mixed_precision = mixed_precision - self.compute_dtype = compute_dtype - self.compute_device = compute_device - self.offload_config = offload_config - - self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False - - self.params: List[Parameter] = [] - for param in module.parameters(): - if not hasattr(param, 'zero_is_sharded'): - self.params.append(param) - - self._has_params = len(self.params) > 0 - self._has_sharded_params = False - # Flag to indicate if the full params are gathered. - self.has_full_params: bool = False - - self._shard_params() - # Maybe no need, reserve to prevent bugs - # self.delete_fp32_shards() - - self._streams: Dict[str, torch.cuda.Stream] = {} - - def _shard_params(self) -> None: - for p in self.params: - assert not hasattr(p, "zero_is_sharded") - assert p.is_floating_point() - if self.mixed_precision: - assert p.dtype == torch.float32 - - # If world_size is 1, then we all-reduce grads instead of sharding. - p.zero_is_sharded = self.num_shards > 1 - p.zero_orig_size = p.data.size() - - if not p.zero_is_sharded: - p.zero_shard_padding = 0 - continue - - # Replace p.data with the relevant shard. - orig_data = p.data - p.data, p.zero_shard_padding = get_shard(p.data, self.shard_idx, self.num_shards) - free_storage(orig_data) - - @torch.no_grad() - def reset_param_attr(self, p: Parameter, training: bool) -> None: - """This should be called by ``ZeroRedundancyLevel3Model._lazy_init()`` - """ - assert hasattr(p, 'zero_is_sharded') and hasattr(p, 'zero_orig_size') - if hasattr(p, 'zero_fp32_shard'): - return - - # A single shard of the parameters in full precision. - p.zero_fp32_shard = p.data - - if self.mixed_precision: - assert p.zero_fp32_shard.dtype == torch.float32 - - if self._cpu_offload: - assert p.zero_fp32_shard.device == torch.device('cpu') - # If we plan to keep the FP32 parameters on CPU, then pinning - # memory allows us to later use non-blocking transfers when moving - # the FP32 param shard to compute_device. - p.zero_fp32_shard = p.zero_fp32_shard.pin_memory() - p.data = p.zero_fp32_shard - - if self.mixed_precision or self._cpu_offload: - - # In mixed precision mode, we maintain a reduced precision - # (typically FP16) parameter shard on compute_device for performing - # the computation in the forward/backward pass. We resize the - # storage to size 0 at init (here) and re-materialize (by copying - # from _fp32_shard) as needed. If offloading params to CPU, the - # dtype of the fp16 shard will depend on the *`compute_dtype`*. 
- p.zero_fp16_shard = torch.zeros_like( - p.zero_fp32_shard, device=self.compute_device, dtype=self.compute_dtype) - free_storage(p.zero_fp16_shard) - - if self.mixed_precision: - assert p.zero_fp32_shard.dtype == torch.float32 - - if not self.mixed_precision and not self._cpu_offload: - # use _fp32_shard if you are not in using mixed precision or - # offloading params and grads to CPU. - p.zero_fp16_shard = None - - # We also maintain a full-sized parameter of type self.compute_dtype - # (FP16 for mixed_precision or FP32 otherwise). We resize the - # storage to size 0 at init (here) and only materialize as needed. The - # storage may contain padding elements so that it is evenly divisible by - # world_size, although these padding elements will be removed before the - # relevant computation. - if p.zero_is_sharded: - p.zero_full_param_padded = torch.zeros( - p.data.numel() * self.num_shards, device=self.compute_device, dtype=self.compute_dtype - ) - free_storage(p.zero_full_param_padded) - - if self._cpu_offload and training: - p.zero_cpu_grad = torch.zeros_like(p.data, device='cpu').pin_memory() - - def setup_streams(self, streams): - self._streams = streams - - @torch.no_grad() - def rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]: - """ - Gather all shards of params. - - Note, this is idempotent if full params are already gathered. Callers - assume the idempotency. So please keep it that way. - - Args: - force_full_precision (bool, Optional): by default params will be gathered - in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is - ``True``, in which case they will be gathered in full precision - (e.g., FP32), possibly in fresh storage. The parameter that's being - rebuilt will end up in full precision as well. - - Returns: - A list of tuples, where the first element is the full-sized param - and the second element is a bool indicating if it's safe for the - caller to free the full-sized param. This will be ``None`` if - ``force_full_precision=False`` and the full params are already gathered. - """ - # Store tensor and free flag - output_tensors: List[Tuple[torch.Tensor, bool]] = [] - - def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None: - """ - Helper function to update p.data pointer. - - Args: - custom_output_tensor (torch.Tensor, Optional): if not None, this - tensor contains the data we just gathered. - """ - if custom_output_tensor is not None: - assert p.zero_is_sharded - p.data = custom_output_tensor - output_tensors.append((p.data, True)) - elif not p.zero_is_sharded: - if (self.mixed_precision or self._cpu_offload) and not force_full_precision: - assert p.zero_fp16_shard is not None - p.data = p.zero_fp16_shard - output_tensors.append((p.data, True)) - else: - # Here p.data == p._fp32_shard, so it's not safe to free. - output_tensors.append((p.data, False)) - else: - p.data = p.zero_full_param_padded - output_tensors.append((p.data, True)) - # Trim any padding and reshape to match original size. - p.data = p.data[: p.zero_orig_size.numel()].view(p.zero_orig_size) - - if self._has_sharded_params: - # self.has_full_params flag can be out of sync if a shared param is - # sharded by another ZeroRedundancyLevel3Model instance. An example is that in eval case - # with reshard_after_forward=False but the sharing instance has - # reshard_after_forward=True. 
Then, on the second forward, the - # other instance can shard the shared param and but this instance - # can mistakenly think the full param is already gathered from the - # has_full_params flag. - # - # Therefore, we update the flag accordingly here. - self.has_full_params = not any(p.zero_full_param_padded.storage().size() == 0 for p in self.params) - - # Early exit if we already have full params and don't need full precision. - if self.has_full_params and not force_full_precision: - for p in self.params: - update_p_data() - return output_tensors - - self.has_full_params = True - - with torch.cuda.stream(self._streams["all_gather"]): - if (self.mixed_precision or self._cpu_offload) and not force_full_precision: - self.use_fp16_shards() - - if self._cpu_offload and force_full_precision: - # If the compute_dtype and storage dtype are the same, - # use pinned memory. Otherwise move p.data to the compute - # device. - if self.params[0].dtype == self.compute_dtype: - self.use_fp16_shards() - else: - for p in self.params: - p.data = p.data.to(self.compute_device) - - for p in self.params: - if not p.zero_is_sharded: # e.g., when world_size == 1 - update_p_data() - else: - # Skip if already built. Only shared param can be rebuilt multiple times. - # A corner case is p.zero_orig_size = (1,), which means the shape equality is - # not a perfect check. But we assume we don't share a param with shape (1,). - # if p.data.shape == p.zero_orig_size and hasattr(p, "zero_is_shared") and p.zero_is_shared: - # continue - # If self._cpu_offload and force_full_precision, we need to cast - # the FP32 CPU param to CUDA for the all-gather. - p_data = p.data.to(p.zero_full_param_padded.device, non_blocking=True) - - p_size = p.zero_full_param_padded.size() - assert p_size.numel() % self.num_shards == 0 - if self.mixed_precision and force_full_precision: - # Allocate fresh tensor in full precision since we are in - # mixed precision and full precision rebuild is asked. - output_tensor = p_data.new_zeros(p_size) - else: - if p.zero_full_param_padded.storage().size() != p_size.numel(): - # Allocate based on full size from all shards. - alloc_storage(p.zero_full_param_padded, size=p_size) - output_tensor = p.zero_full_param_padded - - # Fill output_tensor with (p.data for each shard in self.world_size) - if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives: - # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather. - dist._all_gather_base(output_tensor, p_data, group=self.process_group) - else: - chunks = list(output_tensor.chunk(self.num_shards)) - dist.all_gather(chunks, p_data, group=self.process_group) - - # Set p.data = output_tensor (with padding trimmed) - update_p_data(output_tensor) - - if (self.mixed_precision or self._cpu_offload) and not force_full_precision: - self.free_fp16_shards([p]) - - if self._cpu_offload and (self.params[0].dtype == self.compute_dtype): - self.free_fp16_shards([p]) - - torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) - return output_tensors - - @torch.no_grad() - def use_full_params(self) -> None: - """ - Switch p.data pointers to use the full params. - - Note: this assumes full params are already gathered. - - Note: this might be called after full_params is already in used. So please - make sure it is idempotent in that case. 
- """ - assert self.has_full_params - for p in self.params: - if not p.zero_is_sharded: - if self.mixed_precision or self._cpu_offload: - assert p.zero_fp16_shard is not None - assert p.zero_fp16_shard.storage().size() != 0 - p.data = p.zero_fp16_shard - else: - assert p.zero_full_param_padded.storage().size() != 0, f"{p.zero_orig_size} {id(self)}" - p.data = p.zero_full_param_padded[: p.zero_orig_size.numel()].view(p.zero_orig_size) - - @torch.no_grad() - def use_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None: - """Cast FP32 param shard to FP16 for a list of params.""" - if params is None: - params = self.params - with torch.cuda.stream(self._streams["fp32_to_fp16"]): - for p in params: - assert p.zero_fp16_shard is not None - alloc_storage(p.zero_fp16_shard, size=p.zero_fp32_shard.size()) - p.zero_fp16_shard.copy_( - # If _cpu_offload is True, this will be non-blocking - # because _fp32_shard is pinned, otherwise it's a no-op. - p.zero_fp32_shard.to(p.zero_fp16_shard.device, non_blocking=True) - ) - p.data = p.zero_fp16_shard - torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"]) - - @torch.no_grad() - def use_fp32_shards(self, params: Optional[List[Parameter]] = None) -> None: - """Use FP32 shard for a list of params.""" - if params is None: - params = self.params - for p in params: - p.data = p.zero_fp32_shard - - @torch.no_grad() - def free_full_params(self, params: Optional[List[Parameter]] = None) -> None: - """Free up storage for full parameters.""" - if params is None: - params = self.params - self.has_full_params = False - current_stream = torch.cuda.current_stream() - for p in params: - if not p.zero_is_sharded: # e.g., world_size == 1 - if self.mixed_precision or self._cpu_offload: - self.free_fp16_shards([p]) - continue - # Don't let PyTorch reuse this memory until all work in the current - # stream is complete. - p.zero_full_param_padded.record_stream(current_stream) - # There may be external references to the Tensor Storage that we - # can't modify, such as references that are created by - # ctx.save_for_backward in the forward pass. Thus when we - # unshard parameters, we should reuse the original Tensor - # Storage object and unshard it in-place. For now, just resize - # the Storage to 0 to save memory. - free_storage(p.zero_full_param_padded) - - @torch.no_grad() - def free_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None: - """Free storage for FP16 shards for a list of params.""" - if params is None: - params = self.params - current_stream = torch.cuda.current_stream() - for p in params: - if p.zero_fp16_shard is not None: - # zero_fp16_shard is allocated in "fp32_to_fp16" stream, so we can't - # free it until the work in the current stream completes. 
- p.zero_fp16_shard.record_stream(current_stream) - free_storage(p.zero_fp16_shard) - - def delete_fp32_shards(self) -> None: - for p in self.params: - if hasattr(p, 'zero_fp32_shard'): - del p.zero_fp32_shard # reset _init_param_attr diff --git a/colossalai/zero/sharded_model/sharded_grad.py b/colossalai/zero/sharded_model/sharded_grad.py deleted file mode 100644 index 7c8667f1b..000000000 --- a/colossalai/zero/sharded_model/sharded_grad.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn -from torch.nn.parameter import Parameter - - -class ShardedGradient: - def __init__(self, - param: Parameter, - sharded_module: nn.Module, - offload_config: Optional[dict] = None - ) -> None: - assert hasattr( - param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with sharded parameter' - - self.param = param - self.sharded_module = sharded_module - self.offload_config = offload_config - - self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False - - # _gpu_grad is either sharded or not - # all saved grads are fp32 - self._gpu_grad: Optional[torch.Tensor] = None - self._cpu_grad: Optional[torch.Tensor] = None - - if self._cpu_offload: - # this buffer will be held and reused every iteration - self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory() - - @torch.no_grad() - def setup(self) -> None: - """This function will be called pre-backward. Save the local accumulated gradient to _gpu_grad. - When no_sync() is enable (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad - - :raises AssertionError: Raise if grad shape is wrong - """ - if self.sharded_module._require_backward_grad_sync and self.param.grad is not None: - if self.param.grad.device != self.param.data.device: - # TODO: offload? - raise RuntimeError( - 'grad and param are on different device, grad {self.param.grad.device} vs. param {self.param.data.device}') - else: - self._gpu_grad = self.param.grad.data - self.param.grad = None - - def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None: - """This function will be called in post-backward hook, so we cannot modify param.grad directly - - :param reduced_grad: the reduced grad - :type reduced_grad: torch.Tensor - """ - # Make sure we store fp32 grad - if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float: - reduced_grad.data = reduced_grad.data.to(torch.float) - - if self._gpu_grad is None: - self._gpu_grad = reduced_grad.data - else: - self._gpu_grad += reduced_grad.data - - # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full - # backwards pass completes, we will set `.grad` to the CPU copy. - if self._cpu_offload: - self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True) - # Don't let this memory get reused until after the transfer. 
- self._gpu_grad.data.record_stream(torch.cuda.current_stream()) - - @torch.no_grad() - def write_back(self) -> None: - """This function will be called in final backward hook - """ - if self._cpu_grad is not None: - assert self.param.device == torch.device( - 'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}' - self.param.grad.data = self._cpu_grad - elif self._gpu_grad is not None: - assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}' - self.param.grad.data = self._gpu_grad - else: - raise RuntimeError('No grad to write back') - # If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad - # They should be released here - self._gpu_grad = None diff --git a/colossalai/zero/sharded_model/sharded_model.py b/colossalai/zero/sharded_model/sharded_model.py deleted file mode 100644 index 9d3d77331..000000000 --- a/colossalai/zero/sharded_model/sharded_model.py +++ /dev/null @@ -1,1104 +0,0 @@ -import contextlib -import copy -import functools -import os -import traceback -from collections import OrderedDict -from enum import Enum, auto -from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Union) - -import torch -import torch.distributed as dist -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device -from torch.autograd import Variable -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from ._zero3_utils import (apply_to_tensors, assert_in_engine, cast_float_arguments, cast_tensor_to_fp16, - cast_tensor_to_fp32, chunk_and_pad, free_storage, get_gradient_predivide_factor, get_shard, - replace_state_dict_prefix) -from .param_manager import Zero3ParameterManager -from .reduce_scatter import ReduceScatterBucketer - -# TODO: Remove the toggle-enable_nccl_base_collectives in the future -if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0": - enable_nccl_base_collectives = False -else: - enable_nccl_base_collectives = True - - -class TrainingState(Enum): - IDLE = auto() - FORWARD = auto() - PRE_BACKWARD = auto() - POST_BACKWARD = auto() - GATHER_FULL_PARAMS = auto() - - -# TODO: Add clip_grad_norm_ -# TODO: Add gather_full_optim_state_dict and get_shard_from_optim_state_dict - - -class ShardedModel(nn.Module): - - def __init__(self, - module: nn.Module, - process_group: Optional[ProcessGroup] = None, - reduce_scatter_process_group: Optional[ProcessGroup] = None, - reshard_after_forward: bool = True, - disable_reshard_on_root: bool = True, - mixed_precision: bool = False, - fp32_reduce_scatter: bool = False, - flatten_parameters: bool = True, - compute_dtype: Optional[torch.dtype] = None, - buffer_dtype: Optional[torch.dtype] = None, - reduce_scatter_bucket_size_mb: int = 25, - compute_device: Optional[torch.device] = None, - no_broadcast_optim_state: Optional[bool] = False, - state_dict_device: Optional[torch.device] = None, - clear_autocast_cache: bool = False, - force_input_to_fp32: bool = False, - verbose: bool = False, - offload_config: Optional[dict] = None, - state_dict_on_rank_0_only: bool = False, - gradient_predivide_factor: Optional[float] = 1.0) -> None: - super().__init__() - self.logger = get_dist_logger() - - self.process_group = process_group or gpc.get_group(ParallelMode.DATA) - self.reduce_scatter_process_group 
= reduce_scatter_process_group or self.process_group - self.world_size = dist.get_world_size(self.process_group) - self.rank = dist.get_rank(self.process_group) - - self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward - self.disable_reshard_on_root = disable_reshard_on_root - self.mixed_precision = mixed_precision - self.fp32_reduce_scatter = fp32_reduce_scatter - self.offload_config = offload_config - self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32) - self.buffer_dtype = buffer_dtype or self.compute_dtype - self.reduce_scatter_bucket_size_mb = reduce_scatter_bucket_size_mb - self.compute_device = compute_device or torch.device(f'cuda:{get_current_device()}') - self.uncollected_opt_state: Dict[int, Dict] = {} - self.no_broadcast_optim_state = no_broadcast_optim_state - self.state_dict_device = state_dict_device or self.compute_device - self.clear_autocast_cache = clear_autocast_cache - self.force_input_to_fp32 = force_input_to_fp32 - self.verbose = verbose - self.state_dict_on_rank_0_only = state_dict_on_rank_0_only - - self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False - - # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem - # So we use 1.0 as the default gradient_predivide_factor - # However, if you set gradient_predivide_factor to None - # we will set gradient_predivide_factor to a value >= 1.0 automatically - self.gradient_predivide_factor: float = gradient_predivide_factor if \ - gradient_predivide_factor is not None else \ - get_gradient_predivide_factor(self.world_size) - self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor - - self._check_sanity() - - self.params: List[Parameter] = [] - - for name, param in module.named_parameters(): - if not hasattr(param, 'zero_is_sharded'): - self.params.append(param) - - self.module = module - - self.param_manager = Zero3ParameterManager(module, - process_group=self.process_group, - mixed_precision=self.mixed_precision, - flatten_parameters=flatten_parameters, - compute_dtype=self.compute_dtype, - compute_device=self.compute_device, - offload_config=offload_config) - - self._reset_lazy_init_info() - - # Flag to indicate if we require gradient reduction in the backward - # pass. This will be False when inside the no_sync context manager. - self._require_backward_grad_sync: bool = True - - # Enum to indicate if we're in the forward/backward pass, idle, etc. - self.training_state = TrainingState.IDLE - - # Register hook after state_dict() to remove the "_zero3_module." - # prefix and before load_state_dict() to add it back. - self._register_state_dict_hook(functools.partial(_post_state_dict_hook, self.state_dict_on_rank_0_only)) - self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook) - - # Flag to indicate whether state_dict() should automatically gather the full params. - self._return_full_state_dict = True - - # Flag to guard against preparing gradients multiple times per iteration. - # This is reset at the end of the backward pass. - self._pre_backward_hook_has_run = False - - def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: - self._lazy_init() - - # Start of a forward pass. - self.training_state = TrainingState.FORWARD - - # For root and mixed precision, we convert the input to FP16 (no_grad is needed for - # the conversion). 
- if self._is_root and self.mixed_precision: - args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) - - # If enabled, convert the input to FP32 if we are in full precision. - # no_grad is not used because the input might be for a non-root instance, - # which mean autograd needs to go through the conversion. - if self.force_input_to_fp32 and not self.mixed_precision: - args, kwargs = cast_float_arguments(cast_tensor_to_fp32, *args, **kwargs) - - # All-gather full parameters. This will also transfer FP32 parameters to - # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``). - self.param_manager.rebuild_full_params() - - # Register backward hooks to reshard params and reduce-scatter grads. - # These need to be re-registered every forward pass. - self._register_post_backward_hooks() - - outputs = self.module(*args, **kwargs) - - if self.reshard_after_forward: - self.param_manager.free_full_params() - if self.mixed_precision or self._cpu_offload: - self.param_manager.free_fp16_shards() - - # Switch to main FP32 param shard. We maintain this invariant throughout - # the code, i.e., ``p.data == p.zero_fp32_shard`` after each function. This - # also ensures that after the first forward, the optimizer state will be - # initialized with the correct dtype and (sharded) size, since optimizer - # state is typically initialized lazily in ``optim.step()``. - self.param_manager.use_fp32_shards() - - # Register pre-backward hooks to all-gather the params for the backward - # pass (if output's grad was needed). This won't register anything if - # we are in eval mode. - # - # Some model does forward pass multiple times, we need to register the - # pre-backward hook on every output since the last output's hook has to - # fire first to setup for backward. However, we use ``self._pre_backward_hook_has_run`` - # to prevent repeated overhead from multiple hook callbacks. - outputs = self._register_pre_backward_hooks(outputs) - - # Done with a forward pass. - self.training_state = TrainingState.IDLE - - # Only need to clear cache during forward. During backward, the cache is not used. - if self.clear_autocast_cache: - torch.clear_autocast_cache() - - return outputs - - def _check_sanity(self) -> None: - if self.fp32_reduce_scatter and not self.mixed_precision: - raise ValueError("fp32_reduce_scatter requires mixed_precision=True") - if self.compute_device.type == 'cuda': - input_tensor = torch.ones(1).to(self.compute_device) - output = list(torch.zeros(self.world_size).to(self.compute_device).chunk(self.world_size)) - dist.all_gather(output, input_tensor, group=self.process_group) - assert torch.cat(output).sum() == float( - self.world_size), (f"found {torch.cat(output).sum()} devices in process group but " - f"world_size={self.world_size}. Check torch.cuda.set_device is called properly") - - def _reset_lazy_init_info(self) -> None: - self._is_root: Optional[bool] = None - self._streams: Dict[str, torch.cuda.Stream] = {} - self._reducer: Optional[ReduceScatterBucketer] = None - self.param_manager.delete_fp32_shards() - self._output_pre_backward_hook_registered: Optional[List] = None - self.reshard_after_forward = self._orig_reshard_after_forward - - def _lazy_init(self): - # Initialize param attributes lazily, in case the param's dtype or - # device changes after __init__. - for p in self.params: - self.param_manager.reset_param_attr(p, self.training) - - # Initialize _is_root and setup streams. 
These steps would ideally - # happen in __init__, but _is_root can only be determined after the - # entire model hierarchy is setup, thus we run it lazily. - if self._is_root is None: - self._set_is_root() - self._setup_streams() - self._setup_output_hook_list() - - if self._is_root: - # Buffers stay on GPU, and don't get sharded. Since _cast_buffers - # applies recursively, we only call this from the root instance. - self._cast_buffers() - - if self.disable_reshard_on_root: - # Don't free the full params for the outer-most (root) instance, - # since those params will be needed immediately after for the - # backward pass. - self.reshard_after_forward = False - - # Due to the use of streams, we need to make sure the previous - # ``optim.step()`` is done before we all-gather parameters. - self._wait_for_previous_optim_step() - - def _set_is_root(self) -> None: - """If ``True``, implies that no other :class:`ShardedModel` - instance wraps this one. Called once by :func:`_lazy_init`. - Also sets self.children_share_process_group = True if all child - instances share the same process group. If some child instances use a - different process group, self.clip_grad_norm_ will raise an error. - """ - if self._is_root is not None: - return - # No Zero3Model instance wraps this, else _is_root would be set to False. - self._is_root = True - # If final backward callback is never been queued, state should be IDLE. - # If final backward callback is queued, the callback should be finished - # and the state was reset to be IDLE. - # This should be asserted at the beginning of forward pass in the root instance only. - # For children instances, if they are checkpointed, state will not be reset to - # IDLE after each inner forward/backward. - self._assert_state(TrainingState.IDLE) - # As the root, we now set all children instances to False and - # give them a closure to try to queue a wait_for_post_backward. - self.children_share_process_group = True - for n, m in self.named_modules(): - # `n != ""` excludes self. - if n != '' and isinstance(m, ShardedModel): - # We relax the assert for non-root instance, when the nested inialized module is wrapped - # again in ShardedModel later, for example after training to run inference. - assert m._is_root is None or not m._is_root - if m._is_root is None: - m._is_root = False - if m.process_group != self.process_group: - self.children_share_process_group = False - - # if child instance in its own (smaller) world, that was probably an attempt to avoid OOM. - # Therefore gathering this child's optim state will probably cause OOM, so we won't do it. - m.no_broadcast_optim_state = m.no_broadcast_optim_state or \ - ((m.world_size == 1) - and (m.world_size < self.world_size) - and (m.process_group != self.process_group)) - - def _setup_streams(self) -> None: - """Create streams to overlap data transfer and computation.""" - if len(self._streams) > 0 or not self._is_root: - return - - if torch.cuda.is_available(): - # Stream to move main FP32 params (may be on CPU) to FP16 for forward. - self._streams['fp32_to_fp16'] = torch.cuda.Stream() - # Stream for all-gathering parameters. - self._streams['all_gather'] = torch.cuda.Stream() - # Stream for overlapping grad reduction with the backward pass. - self._streams['post_backward'] = torch.cuda.Stream() - - self.param_manager.setup_streams(self._streams) - # Helper for bucketing reduce-scatter ops. This is also shared with - # children instances to improve bucket utilization. 
- self._reducer = ReduceScatterBucketer(self.reduce_scatter_bucket_size_mb) - # We share streams with all children instances, which allows them to - # overlap transfers across the forward pass without synchronizing with - # the default stream. - for n, m in self.named_modules(): - if n != "" and isinstance(m, ShardedModel): - m._streams = self._streams - m._reducer = self._reducer - m.param_manager.setup_streams(self._streams) - - def _setup_output_hook_list(self) -> None: - """set up a list to avoid registering pre-backward hooks - incorrectly. - """ - assert self._is_root, "This should only be called on the root" - self._output_pre_backward_hook_registered = [] - for n, m in self.named_modules(): - if n != "" and isinstance(m, ShardedModel): - m._output_pre_backward_hook_registered = self._output_pre_backward_hook_registered - - def _wait_for_previous_optim_step(self) -> None: - """ - The outer-most :class:`ShardedModel` instance (i.e., the root - instance) needs to synchronize with the default stream to ensure the - previous optimizer step is done. - """ - if not torch.cuda.is_available(): - return - if self.mixed_precision or self._cpu_offload: - self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream()) - else: - self._streams["all_gather"].wait_stream(torch.cuda.current_stream()) - - def _cast_buffers(self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - memo: Optional[Set] = None) -> None: - """Move all buffers to the given *device* and *dtype*. - - If *device* or *dtype* are not given, then they will default to - ``self.compute_device`` and ``self.buffer_dtype``, respectively. In the - case of nested ShardedModel instances, we will respect the child instance's - ``compute_device`` and ``buffer_dtype`` configuration. - - Args: - device (torch.device, Optional): - device to cast buffers to (defaults to compute_device) - dtype (torch.dtype, Optional): - dtype to cast buffers to (defaults to buffer_dtype) - memo (Set, Optional): - set of modules that have already been processed - """ - if memo is None: - memo = set() - for module in self.modules(): - if module is not self and isinstance(module, ShardedModel): - # Allow any child Zero3Model instances to handle their own buffers. - module._cast_buffers(device=device, dtype=dtype, memo=memo) - elif module not in memo: - memo.add(module) - for name, buf in module.named_buffers(recurse=False): - if buf is None: - continue - buf = buf.to(device=device or self.compute_device) - if torch.is_floating_point(buf): - buf = buf.to(dtype=dtype or self.buffer_dtype) - setattr(module, name, buf) - - @torch.no_grad() - def _prep_grads_for_backward(self) -> None: - """Make sure p.grad is correctly prepared for the backward with - right shape, device, accumulated values, etc. - """ - for p in self.params: - if p.grad is not None: - if p.grad.device != p.data.device: - p.grad = None - elif p.grad.size() == p.zero_orig_size: - if not p.zero_is_sharded: - p.zero_saved_grad = p.grad.data - p.grad = None - else: - # This is gradient accumulation with no_sync context. - pass - elif p.grad.size() == p.zero_fp32_shard.shape: - # This is gradient accumulation without no_sync context. - # We save the grad shard and set p.grad to None for this backward pass. - # We will accumulate after this pass's grad is generated and reduced and - # sharded. 
- p.zero_saved_grad_shard = p.grad.data - p.grad = None - else: - raise AssertionError(f"unexpected grad shape: {p.grad.size()}") - - def _register_pre_backward_hooks(self, outputs: Any) -> Any: - """Register pre-backward hook to run before the wrapped module's - backward. Hooks should be attached to all outputs from the forward. - - Returns: - outputs: new outputs with hooks registered if they requires gradient. - """ - if not torch.is_grad_enabled(): - return outputs # don't register hooks if grad isn't enabled - - if self._is_root: - # This actually means that only root instance has - # _post_backward_callback_queued defined. Accidentally accessing this field - # will assert on all other instances, giving us a nice bug checker. - self._post_backward_callback_queued = False - - def _pre_backward_hook(*unused: Any) -> None: - # try to queue final backward callback only once for root, so - # that final backward callback is attached to the outer most - # backward graph task and called after all the backward - # calls are completed. - if self._is_root: - self._register_final_backward_hook() - - # All-gather full parameters or switching to the full params. - # - # This needs to be done on every pre_backward hook, even within the same - # iteration (i.e. for checkpointed, multiple forward pass modules). This is - # because after the forward pass (i.e. in checkpoint inner graph), we always - # switch to fp32_shard in the ``forward`` function. - # - # We used to do this only after the ``self._pre_backward_hook_has_run`` - # boolean guard below, which is incorrect. It worked in pytorch < 1.9 for - # some unknown reason, but pytorch 1.10 nightly exposed this bug. - # - # Note, both ``self.param_manager.rebuild_full_params`` and ``self.param_manager.use_full_params`` are - # idempotent. So in case they are called unnecessarily, they don't incur much - # overhead. - if self.reshard_after_forward: - self.param_manager.rebuild_full_params() - else: - self.param_manager.use_full_params() - - # Only run the ``self._prep_grads_for_backward`` once per iteration (i.e. in case - # it is multiple outputs or multiple forward passes). - if not self._pre_backward_hook_has_run: - self._pre_backward_hook_has_run = True - # Start of a backward pass for the first time in an iteration. - self._assert_state([TrainingState.IDLE, TrainingState.PRE_BACKWARD]) - # Prepare p.grad so that it is in the right shape, device, accumulated values, etc. - self._prep_grads_for_backward() - - # Transition to PRE_BACKWARD state if currently IDLE. We can transition from POST_BACKWARD - # to IDLE when ShardedModel is within activation checkpointing and called multiple times, due to the - # extra forward pass for re-computation. - if self.training_state == TrainingState.IDLE: - self.training_state = TrainingState.PRE_BACKWARD - self._assert_state([TrainingState.PRE_BACKWARD, TrainingState.POST_BACKWARD]) - - _registered = 0 - - def _register_hook(t: torch.Tensor) -> torch.Tensor: - # We don't register the pre_backward hook on the same tensor that has been - # returned from an inner ShardedModel, unless it is the first one. This does - # not cover all problematic cases though. 
A tensor not from an inner - # ShardedModel can cause problems too: - # ``` - # x = layer1(input) - # state = [x] # better change to x.detach(), not fixed by the following if-condition - # x = inner_zero3_module_layer2(x) - # state.append(x) # better change to x.detach(), but fixed by the following if-condition - # x = layer3(x) - # return x, state - # ``` - # The tensors in `state`, if not detached, can be registered with - # backward hooks (in addition to the `x` on the last line). In that case, - # pre-backward hook can fire multiple times in the order that causes - # the outer ShardedModel to crash. - # - # The best practice is for modules to be wrapped by ShardedModel to return 1 and only - # 1 tensor to be used for backward. All other tensors returned should be - # detached. - nonlocal _registered - assert self._output_pre_backward_hook_registered is not None - if t.requires_grad and (_registered == 0 or id(t) not in self._output_pre_backward_hook_registered): - t.register_hook(_pre_backward_hook) - self._output_pre_backward_hook_registered.append(id(t)) - _registered += 1 - return t - - # Attach hooks to Tensor outputs. - outputs = apply_to_tensors(outputs, _register_hook) - - return outputs - - def _register_post_backward_hooks(self) -> None: - """ - Register backward hooks to reshard params and reduce-scatter grads. - - This is called during forward pass. The goal is to attach a hook - on each of the parameter's gradient generating function (``grad_acc`` - below) so that the hook is called *after* all gradients for that - param are computed. - - Goals: - - 1. We want the hook to fire once and only once *after* all gradients - are accumulated for a param. - 2. If it fires more than once, we end up incorrectly shard the grad - multiple times. (could lead to dimension too small) - 3. If it fires once but too early or doesn't fire, we leave gradients - unsharded. (could lead to dimension too large) - - Due to multiple-pass forward, this function can be called on - the same parameter multiple times in a single forward pass. If we register - the hook multiple time, we end up getting called multiple times. We - could try to get a new hook every time and delete the previous one - registered. However, due to *unknown reason* (I have debugged it for - a long time!), in mixed precision mode, we get two different ``grad_acc`` - objects below during different calls of this function (in the same - forward pass). If we keep the last one, the hook end up firing too - early. In full precision mode, we luckily get the *same* ``grad_acc`` - object, so deleting and re-registering still ensured the hook fire - once after all gradients are generated. However, we find if we use activation - checkpoint in mixed precision mode, hook on ``grad_acc`` object won't be - fire for *unknown reason*. So we finally register hook on parameter directly. - - Empirically, keep the first hook register per forward pass seems to - work the best. We do need to remove the hook at the end of the - backward pass. Otherwise, the next forward pass will not register - a new hook, which is needed for a new forward pass. 
- """ - if not torch.is_grad_enabled(): - return # don't register grad hooks if grad isn't enabled - for p in self.params: - if p.requires_grad: - if hasattr(p, "zero_shard_bwd_hook"): - continue - # For mixed precision with activation checkpoint, hooks on GradAccumulation won't be fired normally - # Instead we register hook on parameter - # In this way, we can't modify param.grad and param.data directly, which leads to more memory usage - # Register a hook on the first call, empirically, autograd - # fires it at the end for this param, which makes sense. - # p_tmp = p.expand_as(p) # Get a grad_fn on p_tmp. - # assert p_tmp.grad_fn is not None - # grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object. - # handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p)) - # p.zero_shard_bwd_hook = (grad_acc, handle) - handle = p.register_hook(functools.partial(self._post_backward_hook, p)) - p.zero_shard_bwd_hook = handle - - @torch.no_grad() - def _post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]: - """ - At the start of :func:`_post_backward_hook`, ``param.grad`` contains the - full gradient for the local batch. The reduce-scatter op will replace - ``param.grad`` with a single shard of the summed gradient across all - GPUs. This shard will align with the current GPU rank. For example:: - - before reduce_scatter: - param.grad (GPU #0): [1, 2, 3, 4] - param.grad (GPU #1): [5, 6, 7, 8] - - after reduce_scatter: - param.grad (GPU #0): [6, 8] # 1+5, 2+6 - param.grad (GPU #1): [10, 12] # 3+7, 4+8 - - The local GPU's ``optim.step`` is responsible for updating a single - shard of params, also corresponding to the current GPU's rank. This - alignment is created by `param_manager`, which ensures that - the local optimizer only sees the relevant parameter shard. - """ - # First hook callback will see PRE state. If we have multiple params, - # then subsequent hook callbacks will see POST state. - self._assert_state([TrainingState.PRE_BACKWARD, TrainingState.POST_BACKWARD]) - self.training_state = TrainingState.POST_BACKWARD - if grad is None: - return - - assert grad is not None, param.shape - if grad.requires_grad: - raise RuntimeError("ShardedModel only works with gradients that don't require gradients") - - if self._require_backward_grad_sync or self.reshard_after_forward: - # Free full params. As a special case, we don't free the full params - # when in a ``no_sync`` context (as inversely indicated by - # ``self._require_backward_grad_sync``), since the params will not - # get updated before the next forward. This saves networking - # bandwidth but uses more GPU memory. - self.param_manager.free_full_params([param]) - - if self.mixed_precision: - # This is a no-op if reshard_after_forward is True, since we already - # free the param shard when rebuilding the full params in the - # pre_backward_hook. - self.param_manager.free_fp16_shards([param]) - - # Switch to FP32 shard after backward. - # Cannot modify param.data, so we switch to FP32 in final backward hook - # self.param_manager.use_fp32_shards([param]) - - if not self._require_backward_grad_sync: - return - - # Wait for all work in the current stream to finish, then start the - # reductions in post_backward stream. - self._streams["post_backward"].wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self._streams["post_backward"]): - new_grad = grad.clone() - - if self.mixed_precision and self.fp32_reduce_scatter: - # Cast grad to FP32. 
- new_grad.data = new_grad.data.to(param.dtype) - - if self.gradient_predivide_factor > 1: - # Average grad by world_size for consistency with PyTorch DDP. - new_grad.data.div_(self.gradient_predivide_factor) - - orig_grad_data = new_grad.data - if param.zero_is_sharded: - assert self._reducer is not None - # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into - # param.zero_saved_grad_shard. If this ShardedModel module was called multiple times - # it's possible that multiple - # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't - # matter, neglecting rounding. - # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction. - # - # The effect on memory consumption is not usually significant. No extra memory is allocated if this - # module is called only once, reduction happens quickly, or the tensor is bucketed. If the module is - # called multiple times, and the backwards pass runs far enough ahead of the `post_backward` stream, - # then we can end up with multiple unsharded gradients allocated and queued for reduction. - # - # We could guard against this by using CUDA events (see record_event, wait_event in torch.cuda.Stream). - # This ensures the `default` stream will wait for the `post_backward` stream to complete the last - # reduction for this module, before scheduling additional reduction work. Then at most there are two - # unsharded gradients allocated; one for a pending reduction, and one for gradient computation. - callback_fn = functools.partial(self._reduce_scatter_callback, param) - grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size()) - self._reducer.reduce_scatter_async(grad_chunks, - group=self.reduce_scatter_process_group, - callback_fn=callback_fn) - else: - # Currently the only way for _is_sharded to be False is if - # world_size == 1. This could be relaxed in the future, in which - # case grads should be all-reduced here. - assert self.world_size == 1 - self._reduce_scatter_callback(param, new_grad) - - # After _post_backward_hook returns, orig_grad_data will eventually - # go out of scope, at which point it could otherwise be freed for - # further reuse by the main stream while the div/reduce_scatter/copy - # are underway in the post_backward stream. See: - # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py - orig_grad_data.record_stream(self._streams["post_backward"]) - - def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: - """Hook to call on each param after the reduce-scatter.""" - assert torch.cuda.current_stream() == self._streams["post_backward"] - self._assert_state(TrainingState.POST_BACKWARD) - if self.gradient_postdivide_factor > 1: - # Average grad by world_size for consistency with PyTorch DDP. - reduced_grad.data.div_(self.gradient_postdivide_factor) - # Cast grad to param's dtype (typically FP32). Note: we do this - # before the cpu offload step so that this entire hook remains - # non-blocking. The downside is a bit more D2H transfer in that case. - if self.mixed_precision: - orig_param_grad_data = reduced_grad.data - reduced_grad.data = reduced_grad.data.to(dtype=param.zero_fp32_shard.dtype) - # Don't let this memory get reused until after the transfer. - orig_param_grad_data.record_stream(torch.cuda.current_stream()) - - if param.zero_is_sharded: - # Accumulate into the gradient shard. 
- if getattr(param, "zero_saved_grad_shard", None) is None: - param.zero_saved_grad_shard = reduced_grad.data - else: - assert ( - param.zero_saved_grad_shard.shape == reduced_grad.shape), f"{param.zero_saved_grad_shard.shape} \ - vs {reduced_grad.shape}" - - param.zero_saved_grad_shard.data += reduced_grad.data - reduced_grad = param.zero_saved_grad_shard.data - else: - # We can't modify the dtype of grad in this function - # So we use `param.zero_saved_grad` to store gradient - # This is useful when using mixed precision mode on single node - if getattr(param, 'zero_saved_grad', None) is None: - param.zero_saved_grad = reduced_grad.data - else: - param.zero_saved_grad.data += reduced_grad.data - - # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full - # backwards pass completes, we will set `.grad` to the CPU copy. - if self._cpu_offload: - param.zero_cpu_grad.copy_(reduced_grad.data, non_blocking=True) - # Don't let this memory get reused until after the transfer. - reduced_grad.data.record_stream(torch.cuda.current_stream()) - - def _register_final_backward_hook(self) -> None: - """Try to queue a `_final_backward_hook` callback. - - Only called on root and only queue one callback at the beginning of - outer most backward. - """ - assert self._is_root - if not self._post_backward_callback_queued: - self._assert_state([TrainingState.IDLE]) - self._post_backward_callback_queued = True - Variable._execution_engine.queue_callback(self._final_backward_hook) - - @torch.no_grad() - def _final_backward_hook(self) -> None: - """Wait for post-backward to finish. Only called on root instance.""" - # None, backward runtime swallow the assert error, so we use assert_in_engine() here. - assert_in_engine(self._is_root, "FinalBackwardHook not called on root") - # Check if the root module has params and if any of them has - # the `requires_grad` field set. If `requires_grad=False` for - # all the params, the post_backward hook will not fire and the - # state will remain in `TrainingState.PRE_BACKWARD`. - if any([p.requires_grad for p in self.params]): - self._assert_state(TrainingState.POST_BACKWARD) - else: - self._assert_state(TrainingState.PRE_BACKWARD) - self.param_manager.use_fp32_shards() - if self._require_backward_grad_sync: - # Flush any unreduced buckets in the post_backward stream. - with torch.cuda.stream(self._streams["post_backward"]): - assert_in_engine(self._reducer is not None, "FinalBackwardHook: reducer is None") - assert self._reducer is not None # make mypy happy - self._reducer.flush() - torch.cuda.current_stream().wait_stream(self._streams["post_backward"]) - if self._cpu_offload: - # Wait for the non-blocking GPU -> CPU grad transfers to finish. - torch.cuda.current_stream().synchronize() - - # A backward pass is done, clean up below. - # Free reducer buffers. - if self._reducer is not None: - self._reducer.free() - - def _finalize_parameters(zero_module: ShardedModel) -> None: - """Helper used below on all zero3 modules.""" - for p in zero_module.params: - if not p.requires_grad: - continue - if hasattr(p, "zero_shard_bwd_hook"): - p.zero_shard_bwd_hook.remove() - delattr(p, "zero_shard_bwd_hook") - - # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad - # remains the unsharded gradient accumulated from prior no-sync passes, and p.zero_saved_grad_shard - # remains the sharded gradient from the last synchronized pass. 
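A minimal sketch of the accumulate-or-store behaviour described in these branches (the dict-based `state` below stands in for the attributes hung off the parameter): the first reduction stores the shard, later ones add into it, and since addition commutes the arrival order does not matter.

import torch

def accumulate_reduced_shard(state, reduced_shard):
    # Mirror of the zero_saved_grad_shard bookkeeping: store on first arrival,
    # accumulate afterwards.
    if state.get("saved_grad_shard") is None:
        state["saved_grad_shard"] = reduced_shard.clone()
    else:
        state["saved_grad_shard"] += reduced_shard
    return state["saved_grad_shard"]

state = {}
accumulate_reduced_shard(state, torch.tensor([1.0, 2.0]))
accumulate_reduced_shard(state, torch.tensor([0.5, 0.5]))
print(state["saved_grad_shard"])  # tensor([1.5000, 2.5000])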
This also allows interleaved no-sync and - # sync passes, if desired. - if not self._require_backward_grad_sync: - continue - - # Parameter and gradient devices must match. - if hasattr(p, "zero_cpu_grad"): - assert_in_engine(p.device == torch.device("cpu"), - f"FinalBackwardHook: incorrect cpu_grad device {p.device}") - p.grad = p.zero_cpu_grad - elif hasattr(p, "zero_saved_grad_shard"): - assert_in_engine( - p.device == p.zero_saved_grad_shard.device, - f"FinalBackwardHook: incorrect saved_grad_shard device \ - {p.device} vs {p.zero_saved_grad_shard.device}", - ) - p.grad = p.zero_saved_grad_shard - elif hasattr(p, 'zero_saved_grad'): - p.grad = p.zero_saved_grad - - if hasattr(p, "zero_saved_grad_shard"): - delattr(p, "zero_saved_grad_shard") - if hasattr(p, 'zero_saved_grad'): - delattr(p, "zero_saved_grad") - - # Update root and nested ShardedModel's hooks and flags. - for m in self.modules(): # includes self - if isinstance(m, ShardedModel): - _finalize_parameters(m) - m._pre_backward_hook_has_run = False - if any(p.requires_grad for p in m.parameters()): - # Check if the module has params and if any of them has - # the `requires_grad` field set. If `requires_grad=False` for - # all the params, the post_backward hook will not fire and the - # state will remain in `TrainingState.PRE_BACKWARD`. - if any([p.requires_grad for p in m.params]): - m._assert_state(TrainingState.POST_BACKWARD) - else: - m._assert_state(TrainingState.PRE_BACKWARD) - else: - # When `m` and its children has no params or has params but - # none with `requires_grad==True`, there are two cases: - # 1. output tensors are `requires_grad==True`. In this case, - # pre-backward hook is still registered, so it is in PRE_BACKWARD state. - # 2. output tensors are `requires_grad==False`. In this case, - # pre-backward hook is not registered, so it is in IDLE state. - m._assert_state([TrainingState.PRE_BACKWARD, TrainingState.IDLE]) - m.training_state = TrainingState.IDLE - - if m._is_root: - # reset this flag for cases like "one forward pass + multiple backward passes" - self._post_backward_callback_queued = False - # clear this list for next iteration - assert_in_engine( - self._output_pre_backward_hook_registered is not None, - "FinalBackwardHook: self._output_pre_backward_hook_registered should not be None", - ) - assert self._output_pre_backward_hook_registered is not None # make mypy happy - self._output_pre_backward_hook_registered.clear() - - @contextlib.contextmanager - def gather_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator: - """ - A context manager to expose full params for the current ShardedModel instance. - Can be useful *after* forward/backward for a model to get the params for - additional processing or checking. Parameters will be gathered in full - precision (e.g., FP32). - - .. note:: This can be used on inner ShardedModels. - - .. note:: This can *not* be used within a forward or backward pass. Nor - can forward and backward be started from within this context. - - .. note:: The full parameters will be freed after the context manager - exits; it is up to the caller to clone them if needed. - - .. note:: The full parameters can be modified, but only the portion - corresponding to the local param shard will persist after the - context manager exits (unless ``volatile=True``, in which case there - are no guarantees about persistence). 
- - Args: - recurse (bool, Optional): recursively summon all params for nested - ShardedModel instances (default: True) - volatile (bool, Optional): if ``True``, modifications to params are - not guaranteed to persist after the context manager exists; - enabling this can be slightly more efficient (default: False) - """ - if recurse: - with contextlib.ExitStack() as stack: - # Summon all params for any nested Zero3Model instances. - for module in self.modules(): - if isinstance(module, ShardedModel): - stack.enter_context(module.gather_full_params(recurse=False, volatile=volatile)) - # Yield to the caller, with full params in all nested instances. - yield - # Exiting from the ExitStack will re-shard params. - return - else: - torch.cuda.synchronize() - self._lazy_init() - self._assert_state(TrainingState.IDLE) - # Set the state so that we assert when trying to go into - # forward/backward. - self.training_state = TrainingState.GATHER_FULL_PARAMS - full_tensors = self.param_manager.rebuild_full_params(force_full_precision=True) - assert full_tensors is not None - with contextlib.ExitStack() as stack: - try: - yield - finally: - stack.close() - for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors): - if not volatile: - # Copy any changes made to the full params back into - # the corresponding local shards. - local_shard, _ = get_shard(full_tensor) - p.zero_fp32_shard.copy_(local_shard.view_as(p.zero_fp32_shard)) - if safe_to_free: - free_storage(full_tensor) - self.has_full_params = False - self.param_manager.use_fp32_shards() - self.training_state = TrainingState.IDLE - - def apply(self, fn: Callable[[nn.Module], None]) -> "ShardedModel": - """ - Applies ``fn`` recursively to every submodule (as returned by - ``.children()``) as well as self. Typical use includes initializing the - parameters of a model. - - Compared to ``torch.nn.Module.apply``, this version additionally gathers - the full parameters before applying ``fn``. It should not be called from - within another ``summon_full_params`` context. - - Args: - fn (nn.Module): function to be applied to each submodule - - Returns: - Module: self - """ - is_uninitialized = self._is_root is None - self._assert_state(TrainingState.IDLE) - with self.gather_full_params(recurse=False): - return_value = super().apply(fn) - # summon_full_params will call _lazy_init, which sets _is_root. However, - # apply() may be called directly on children instances to do weight - # init, so we should reset the _is_root flag in this case. - if is_uninitialized and self._is_root: - for module in self.modules(): - if isinstance(module, ShardedModel): - module._reset_lazy_init_info() - return return_value - - def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.module, name) - - def __getstate__(self) -> Dict[str, str]: - """Serialize the state. - - Some properties are not serializable (e.g., process groups, streams), so - we remove them and try to reconstruct them in :func:`__setstate__`. 
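The copy-back step of this context manager (only the local shard of an edited full parameter survives) can be pictured with a simplified stand-in for `get_shard`; the padding and chunking below are assumptions made only to make the slicing concrete.

import torch
import torch.nn.functional as F

def get_local_shard(full_tensor, rank, world_size):
    # Flatten, pad to a multiple of world_size, and return this rank's slice.
    flat = full_tensor.flatten()
    shard_size = (flat.numel() + world_size - 1) // world_size
    flat = F.pad(flat, (0, shard_size * world_size - flat.numel()))
    return flat[rank * shard_size:(rank + 1) * shard_size]

full = torch.arange(10.0)            # gathered (and possibly edited) full weight
local_fp32_shard = torch.zeros(5)    # what this rank actually owns
local_fp32_shard.copy_(get_local_shard(full, rank=0, world_size=2).view_as(local_fp32_shard))
print(local_fp32_shard)              # tensor([0., 1., 2., 3., 4.])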
- """ - state = copy.copy(self.__dict__) - state["is_sharded"] = [p.zero_is_sharded for p in self.params] - state["orig_sizes"] = [p.zero_orig_size for p in self.params] - if state["process_group"] is not None: - state["process_group"] = "MISSING" # process_group isn't pickleable - if state["process_group_reduce_scatter"] is not None: - state["process_group_reduce_scatter"] = "MISSING" # process_group_reduce_scatter isn't pickleable - self._reset_lazy_init_info() - return state - - def __setstate__(self, state: Dict[str, Any]) -> None: - """Intercept state setting and perform needed changes on params.""" - super().__setstate__(state) - - def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter: - assert isinstance(p, Parameter) - p.data = p.data.clone() # move tensors out of shared memory - p.zero_is_sharded = is_sharded - p.zero_orig_size = size - return p - - self.params = [ - fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes) - ] - del self.is_sharded - del self.orig_sizes - self._reset_lazy_init_info() - - def __getitem__(self, key: int) -> Any: - """Forward indexing calls in case the module is a nn.Sequential.""" - return self.module.__getitem__(key) - - @contextlib.contextmanager - def no_sync(self) -> Generator: - """ - A context manager to disable gradient synchronizations across ShardedModel - processes. Within this context, gradients will be accumulated on module - variables, which will later be synchronized in the first - forward-backward pass after exiting the context. - - .. note:: This likely results in higher memory usage because ShardedModel will - accumulate the full model gradients (instead of gradient shards) - until the eventual sync. - - .. note:: Gradient accumulation can be done without this context, - avoiding the extra GPU memory overhead, but with the extra - networking overhead. - """ - self._lazy_init() - assert self._is_root, "no_sync on inner ShardedModel is not supported" - self._assert_state(TrainingState.IDLE) - # This instance may wrap other ShardedModel instances and we - # need to set all of them to accumulate gradients. - old_flags = [] - for m in self.modules(): # includes self - if isinstance(m, ShardedModel): - old_flags.append((m, m._require_backward_grad_sync)) - m._require_backward_grad_sync = False - try: - yield - finally: - for m, old_flag in old_flags: - assert m._require_backward_grad_sync is False - m._require_backward_grad_sync = old_flag - - def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None: - """Assert we are in the given state.""" - # Since assert can be turned off and this error checking - # is really important, we use explicit error checking - # and raise a ValueError if needed. - if isinstance(state, TrainingState): - state = [state] - if self.training_state not in state: - msg = f"expected to be in states {state} but current state " f"is {self.training_state}" - # In case we are failing in the context of autograd hook, asserting - # may not generate useful msg. So, let's print it to be sure. 
- self.logger.error(f'Zero3 instance {self} got error: {msg}', ranks=[0]) - if self.rank == 0: - traceback.print_stack() - raise ValueError(msg) - - def extra_repr(self) -> str: - repr = (f"world_size={self.world_size}, " - f"mixed_precision={self.mixed_precision}, ") - if self.verbose: - repr = (f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, " - f"compute_dtype={self.compute_dtype}, " - f"buffer_dtype={self.buffer_dtype}, " - f"fp32_reduce_scatter={self.fp32_reduce_scatter}, " - f"compute_device={self.compute_device}" - f"reduce_scatter_bucket_size_mb={self.reduce_scatter_bucket_size_mb}, " - f"clear_autocast_cache={self.clear_autocast_cache}" - f"force_input_to_fp32={self.force_input_to_fp32}" - f"offload_config={self.offload_config}") - return repr - - def state_dict(self, destination=None, prefix='', keep_vars=False): - """ - Returns the whole (unsharded) state of the module. Parameters are not - sharded, so the resulting state_dict can be loaded directly by the - wrapped Module without any sharding-specific logic. Returned tensors - will be full precision (e.g., FP32). - - .. warning:: This needs to be called on all ranks, since synchronization - primitives will be used. - """ - if torch.cuda.is_available(): - torch.cuda.synchronize() - self._lazy_init() - - def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None: - if self.mixed_precision: - self._cast_buffers(dtype=dtype) - - assert self._return_full_state_dict is True, 'Only support return full state dict now' - if self.training_state != TrainingState.GATHER_FULL_PARAMS: - with self.gather_full_params(recurse=False, volatile=True): - maybe_cast_buffers(torch.float32) - state_dict = super().state_dict() - else: - maybe_cast_buffers(torch.float32) - state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - - if self._cpu_offload: - for k, tensor in state_dict.items(): - state_dict[k] = tensor.cpu() - - # In case we are in mixed precision, restore buffers back to buffer_dtype. - maybe_cast_buffers() - return state_dict - - def load_state_dict(self, - state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], - strict: bool = True) -> NamedTuple: - """ - Load a whole (unsharded) state_dict. - - .. warning:: This needs to be called on all ranks, since synchronization - primitives will be used. - """ - if self._return_full_state_dict: - with self.gather_full_params(): - return self.module.load_state_dict(state_dict, strict) - else: - torch.cuda.synchronize() - self._lazy_init() - return self.module.load_state_dict(state_dict, strict) - - -def _post_state_dict_hook( - state_dict_on_rank_0_only: bool, - module: Zero3ParameterManager, - state_dict: "OrderedDict[str, torch.Tensor]", - prefix: str, - *args: Any, -) -> "OrderedDict[str, torch.Tensor]": - # When state_dict_on_rank_0_only is ``True``, ``model.state_dict()`` will only - # returns full state dict on rank 0 and return empty dict non-rank 0, - # which allow ShardedModel to skip the GPU -> CPU copy on - # non-rank 0 altogether and prevent OOM. - if state_dict_on_rank_0_only and dist.get_rank() != 0: - state_dict.clear() - return state_dict - # Assuming we are in a ``gather_full_params()`` context, we need to clone - # each tensor so that it does not get freed (in-place) when the context - # exits. At the same time, this hook can be called multiple times - # recursively, so we need to make sure that we only clone each tensor at - # most once. 
Thus we add an attribute on the tensor called "_has_been_cloned" - # which keeps track of tensors that are no longer at risk of being freed. - for key in state_dict.keys(): - if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False): - continue - if state_dict[key].device.type != module.state_dict_device.type: - state_dict[key] = state_dict[key].to(device=module.state_dict_device) - state_dict[key]._has_been_cloned = True - elif module.training_state == TrainingState.GATHER_FULL_PARAMS: - # We copy the state_dict since full param will be freed after we - # exit the ``summon_full_params()`` context. - state_dict[key] = state_dict[key].clone() - state_dict[key]._has_been_cloned = True - - # Remove "_zero3_module." prefix - replace_state_dict_prefix(state_dict, prefix + "_zero3_module.", prefix) - return state_dict - - -def _pre_load_state_dict_hook(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, - *args: Any) -> None: - replace_state_dict_prefix(state_dict, prefix, prefix + "_zero3_module.") diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 392d25226..9860f91dd 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -18,8 +18,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter -from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, - get_gradient_predivide_factor) +from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, + get_gradient_predivide_factor) class ShardedModelV2(nn.Module): diff --git a/colossalai/zero/sharded_optim/__init__.py b/colossalai/zero/sharded_optim/__init__.py index 94ef14b95..b71a70aef 100644 --- a/colossalai/zero/sharded_optim/__init__.py +++ b/colossalai/zero/sharded_optim/__init__.py @@ -1,4 +1,3 @@ -from .sharded_optim import ShardedOptimizer from .sharded_optim_v2 import ShardedOptimizerV2 -__all__ = ['ShardedOptimizer', 'ShardedOptimizerV2'] +__all__ = ['ShardedOptimizerV2'] diff --git a/colossalai/zero/sharded_optim/bookkeeping/__init__.py b/colossalai/zero/sharded_optim/bookkeeping/__init__.py deleted file mode 100644 index a96c6b147..000000000 --- a/colossalai/zero/sharded_optim/bookkeeping/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .gradient_store import GradientStore -from .parameter_store import ParameterStore -from .bucket_store import BucketStore -from .tensor_bucket import TensorBucket - -__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket'] \ No newline at end of file diff --git a/colossalai/zero/sharded_optim/bookkeeping/base_store.py b/colossalai/zero/sharded_optim/bookkeeping/base_store.py deleted file mode 100644 index 78cc0479b..000000000 --- a/colossalai/zero/sharded_optim/bookkeeping/base_store.py +++ /dev/null @@ -1,17 +0,0 @@ -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode - - -class BaseStore: - - def __init__(self, dp_parallel_mode=ParallelMode.DATA): - self._world_size = gpc.get_world_size(dp_parallel_mode) - self._local_rank = gpc.get_local_rank(dp_parallel_mode) - - @property - def world_size(self): - return self._world_size - - @property - def local_rank(self): - return self._local_rank diff --git 
a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py deleted file mode 100644 index 37f5a3b99..000000000 --- a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py +++ /dev/null @@ -1,43 +0,0 @@ -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from .base_store import BaseStore - -class BucketStore(BaseStore): - - def __init__(self, dp_parallel_mode): - super().__init__(dp_parallel_mode) - self._grads = dict() - self._params = dict() - self._num_elements_in_bucket = dict() - - self.reset() - - def num_elements_in_bucket(self, reduce_rank: int = None): - return self._num_elements_in_bucket[reduce_rank] - - def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): - self._num_elements_in_bucket[reduce_rank] += num_elements - - def add_grad(self, tensor, reduce_rank: int = None): - self._grads[reduce_rank].append(tensor) - - def add_param(self, tensor, reduce_rank: int = None): - self._params[reduce_rank].append(tensor) - - def reset(self): - keys = [None] + list(range(self._world_size)) - self._grads = {rank: [] for rank in keys} - self._params = {rank: [] for rank in keys} - self._num_elements_in_bucket = {rank: 0 for rank in keys} - - def reset_by_rank(self, reduce_rank=None): - self._grads[reduce_rank] = [] - self._params[reduce_rank] = [] - self._num_elements_in_bucket[reduce_rank] = 0 - - - def get_grad(self, reduce_rank: int = None): - return self._grads[reduce_rank] - - def get_param(self, reduce_rank: int = None): - return self._params[reduce_rank] diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py deleted file mode 100644 index 0abcbc8c1..000000000 --- a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import List -from torch import Tensor -from .base_store import BaseStore - - -class GradientStore(BaseStore): - - def __init__(self, *args): - super().__init__(*args) - # bookkeeping data structures - self._averaged_gradients = dict() - - # for backward reduction hooks - self._grad_acc_objs = [] - - def add_accumulate_grad_object(self, obj): - """ - Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not - be attached successfully. - - :param obj: An object of :class:`AccumulateGrad` class - :type obj: :class:`AccumulateGrad` - """ - - self._grad_acc_objs.append(obj) - - def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]: - """ - Return average gradients of a parameter group - - :param group_id: The index of parameter group - :type group_id: int - - :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. 
- :rtype: List[torch.Tensor] - """ - - return self._averaged_gradients[group_id] - - def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: - """ - Append an average gradient to the list of averaged gradients of a parameter group - - :param group_id: The index of a parameter group - :param tensor: A :class:`torch.Tensor` object - :type group_id: int - :type tensor: torch.Tensor - - """ - - if group_id in self._averaged_gradients: - self._averaged_gradients[group_id].append(tensor) - else: - self._averaged_gradients[group_id] = [tensor] - - def reset_average_gradients_by_group(self, group_id: int) -> None: - """ - Reset the bookkeeping data structure for averaged gradients to an empty list - - :param group_id: The index of a parameter group - :type group_id: int - """ - - self._averaged_gradients[group_id] = [] - - diff --git a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py deleted file mode 100644 index 6a7cf7513..000000000 --- a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py +++ /dev/null @@ -1,96 +0,0 @@ -from .base_store import BaseStore -from torch import Tensor -from typing import List - - -class ParameterStore(BaseStore): - - def __init__(self, dp_paralle_mode): - super().__init__(dp_paralle_mode) - # param partitioning data structures - self._fp16_param_to_rank = dict() - self._rank_groupid_to_fp16_param_list = dict() - self._rank_group_id_to_flat_fp16_param = dict() - - # param reduction data structures - self._is_param_reduced = dict() - self._reduced_param = [] - - def set_param_to_rank(self, tensor: Tensor, rank: int) -> None: - """ - Set the mapping between parameter to rank, each parameter should be owned by a rank. - - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - :param rank: The rank of which the process is responsible for updating the parameter - :type rank: int - """ - - self._fp16_param_to_rank[tensor] = rank - - def get_param_rank(self, tensor: Tensor) -> int: - """ - Gives the rank which the parameter belongs to - - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - """ - return self._fp16_param_to_rank[tensor] - - def belongs_to_current_rank(self, tensor) -> bool: - """ - Check whether a parameter is supposed to be updated by the process of the current rank - - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - - :return: True if the parameter should be updated by the current rank. Otherwise false. 
- :rtype: bool - """ - - tensor_rank = self._fp16_param_to_rank[tensor] - return tensor_rank == self._local_rank - - def add_fp16_param_list_by_rank_group(self, rank, group_id, - tensor_list) -> None: - if rank not in self._rank_groupid_to_fp16_param_list: - self._rank_groupid_to_fp16_param_list[rank] = dict() - - if group_id not in self._rank_groupid_to_fp16_param_list[rank]: - self._rank_groupid_to_fp16_param_list[rank][group_id] = [] - - self._rank_groupid_to_fp16_param_list[rank][group_id].extend( - tensor_list) - - def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]: - return self._rank_groupid_to_fp16_param_list[rank][group_id] - - def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None: - if rank not in self._rank_group_id_to_flat_fp16_param: - self._rank_group_id_to_flat_fp16_param[rank] = dict() - - self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor - - def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor: - return self._rank_group_id_to_flat_fp16_param[rank][group_id] - - def is_param_reduced(self, tensor): - return self._is_param_reduced[tensor] - - def set_param_reduction_state(self, tensor, state): - self._is_param_reduced[tensor] = state - - def get_param_reduction_states(self): - return self._is_param_reduced - - def reset_previous_reduced_params(self): - self._reduced_param = [] - - def add_previous_reduced_param(self, tensor): - self._reduced_param.append(tensor) - - def clear_grads_of_previous_reduced_params(self): - if len(self._reduced_param) > 0: - for param in self._reduced_param: - param.grad = None - self.reset_previous_reduced_params() diff --git a/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py b/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py deleted file mode 100644 index c07f03263..000000000 --- a/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py +++ /dev/null @@ -1,54 +0,0 @@ -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - - -class TensorBucket: - - def __init__(self, size): - self._max_size = size - self._current_size = 0 - self._bucket = [] - - @property - def max_size(self): - return self._max_size - - @property - def current_size(self): - return self._current_size - - def is_full_or_oversized(self): - return self._current_size >= self._max_size - - def is_empty(self): - return len(self._bucket) == 0 - - def add_to_bucket(self, tensor, allow_oversize=False): - tensor_size = tensor.numel() - - if not allow_oversize and self.will_exceed_max_size(tensor_size): - msg = f"The param bucket max size {self._max_size} is exceeded" \ - + f"by tensor (size {tensor_size})" - raise RuntimeError(msg) - - self._bucket.append(tensor) - self._current_size += tensor_size - - def will_exceed_max_size(self, tensor_size): - expected_size = self._current_size + tensor_size - return expected_size > self._max_size - - def get_bucket(self): - return self._bucket - - def empty(self): - self._bucket = [] - self._size = 0 - - def flatten(self): - return _flatten_dense_tensors(self._bucket) - - def unflatten_and_copy(self, flat_tensor): - unflattened_tensor_list = _unflatten_dense_tensors( - flat_tensor, self._bucket) - for old, new in zip(self._bucket, unflattened_tensor_list): - old.copy_(new) diff --git a/colossalai/zero/sharded_optim/sharded_optim.py b/colossalai/zero/sharded_optim/sharded_optim.py deleted file mode 100644 index 2ea2feaf6..000000000 --- a/colossalai/zero/sharded_optim/sharded_optim.py +++ /dev/null @@ -1,563 +0,0 @@ -from 
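The flatten/copy-back round trip that `TensorBucket` above performs can be exercised on its own with the same private torch helpers; a small sketch:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2, 2), 2.0)]

flat = _flatten_dense_tensors(grads)   # one contiguous buffer for a single collective
flat.mul_(0.5)                         # stand-in for the reduction/averaging step

# write the reduced values back into the original tensors
for old, new in zip(grads, _unflatten_dense_tensors(flat, grads)):
    old.copy_(new)

print(grads[0])  # tensor([0.5000, 0.5000, 0.5000])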
colossalai.utils.cuda import get_current_device -import torch -import torch.distributed as dist -from colossalai.logging import get_dist_logger -from torch.optim import Optimizer -from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.nn.optimizer import ColossalaiOptimizer -from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor, - release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan) -from functools import partial - - -class ShardedOptimizer(ColossalaiOptimizer): - - def __init__(self, - optimizer: Optimizer, - initial_scale=2**32, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2, - max_scale: int = 2**32, - clip_grad_norm=2.0, - verbose=False, - reduce_bucket_size=500000000, - communication_dtype=torch.float16, - overlap_communication=False, - partition_grad=False, - dp_parallel_mode=ParallelMode.DATA, - mp_parallel_mode=ParallelMode.MODEL, - cpu_offload=False, - cpu_fp16_param=False, - cpu_fp16_grad=False): - - # TODO: add support for - # 1. fp16 master weights - # 2. contiguous gradients - # 3. cpu offload - # 4. support when some parameters requires_grad = False - - self._optimizer = optimizer - self._dtype = self._optimizer.param_groups[0]['params'][0].dtype - self._logger = get_dist_logger() - self._verbose = verbose - - # stage 2 - self._partition_grads = partition_grad - - # cpu_offload - self._cpu_offload = cpu_offload - self._cpu_fp16_param = cpu_fp16_param - self._cpu_fp16_grad = cpu_fp16_grad - - # get process groups - self._dp_parallel_mode = dp_parallel_mode - self._mp_parallel_mode = mp_parallel_mode - self._local_rank = gpc.get_local_rank(dp_parallel_mode) - self._world_size = gpc.get_world_size(dp_parallel_mode) - - self._dp_group = gpc.get_group(dp_parallel_mode) - if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: - self._mp_group = gpc.get_group(mp_parallel_mode) - else: - self._mp_group = None - - # fp16 and fp32 params for mixed precision training - self._fp16_param_groups = dict() - self._fp32_flat_param_groups_of_current_rank = dict() - - # communication params - self._overlap_communication = overlap_communication - self._reduce_bucket_size = reduce_bucket_size - self._communication_dtype = communication_dtype - - # gradient scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - verbose=verbose) - self._found_overflow = torch.FloatTensor([0]).to(get_current_device()) - - # gradient clipping - self._clip_grad_norm = clip_grad_norm - - # check argument conflict - self._sanity_checks() - - # ParameterStore will manage the tensor buffers used for zero - # it will not manage the tensors used by mixed precision training - self._param_store = ParameterStore(self._dp_parallel_mode) - self._grad_store = GradientStore(self._dp_parallel_mode) - self._bucket_store = BucketStore(self._dp_parallel_mode) - - # iterate over the param group in the optimizer - # partition these param groups for data parallel training - # and add buffers to parameter store for future access - for group_id, param_group in 
enumerate(self._optimizer.param_groups): - params = param_group['params'] - - # add the fp16 params to fp16_param_groups for bookkeeping - self._fp16_param_groups[group_id] = params - - # assign parameters to ranks - # the params in the list are sorted - params_per_rank = self._partition_param_list(params) - - # store the mapping between param to rank - # each param should belong to only one rank - for rank, params in enumerate(params_per_rank): - self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params) - for param in params: - self._param_store.set_param_to_rank(param, rank) - - # move to cpu to make room to create the flat tensor - move_tensor(params, device='cpu') - - # flatten the reordered tensors - for rank in range(self._world_size): - tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) - flat_tensor = flatten(tensor_list) - flat_tensor = flat_tensor.cuda() - self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor) - - # sync parameters - for rank in range(self._world_size): - flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id) - tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) - sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) - - # create a copy of fp32 weights of the parameters for which this rank is responsible - fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id) - # when using cpu offload, our cpu adam support fp16 paramters - if self._cpu_fp16_param: - fp32_flat_current_rank = fp16_flat_current_rank.detach() - else: - fp32_flat_current_rank = fp16_flat_current_rank.detach().float() - device = 'cpu' if self._cpu_offload else get_current_device() - fp32_flat_current_rank = fp32_flat_current_rank.to(device) - fp32_flat_current_rank.requires_grad = True - self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank - - # need to replace the params in the `params` field in the optimizer - # so that when the optimizer calls step(), it only updates the tensors - # managed by this data parallel rank - param_group['params'] = [fp32_flat_current_rank] - - # set reduction state - for param in self._fp16_param_groups[group_id]: - self._param_store.set_param_reduction_state(param, False) - - # intialize communication stream for - # communication-compuation overlapping - if self._overlap_communication: - self._comm_stream = torch.cuda.Stream() - - # reduction hook is only used if overlapping communication - # or stage 2 is used - # if it is stage 1 without overlapping, no hook will be attached - if self._overlap_communication or self._partition_grads: - self._attach_reduction_hook() - - self._initialize_optimizer_states() - - @property - def loss_scale(self): - return self.grad_scaler.scale - - @property - def num_param_groups(self): - return len(self._fp16_param_groups) - - def _partition_param_list(self, param_list): - params_per_rank = [[] for _ in range(self._world_size)] - numel_per_rank = [0 for _ in range(self._world_size)] - - # partititon the parameters in a greedy fashion - sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) - for param in sorted_params: - # allocate this parameter to the rank with - # the smallest numel for load balancing purpose - rank_to_go = numel_per_rank.index(min(numel_per_rank)) - params_per_rank[rank_to_go].append(param) - numel_per_rank[rank_to_go] += param.numel() - - if self._verbose: - self._logger.info(f'Number of 
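The greedy, numel-balanced assignment in `_partition_param_list` above is easiest to see in isolation: sort tensors by size, then always hand the next one to the least-loaded rank. A standalone sketch (hypothetical helper name):

import torch

def partition_by_numel(params, world_size):
    # Largest tensors first; each goes to the rank that currently owns the
    # fewest elements, which keeps the per-rank shard sizes balanced.
    per_rank = [[] for _ in range(world_size)]
    numel_per_rank = [0] * world_size
    for p in sorted(params, key=lambda t: t.numel(), reverse=True):
        rank = numel_per_rank.index(min(numel_per_rank))
        per_rank[rank].append(p)
        numel_per_rank[rank] += p.numel()
    return per_rank, numel_per_rank

tensors = [torch.empty(n) for n in (10, 7, 5, 3, 1)]
_, counts = partition_by_numel(tensors, world_size=2)
print(counts)  # [13, 13]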
elements on ranks: {numel_per_rank}', - ranks=[0], - parallel_mode=self._dp_parallel_mode) - return params_per_rank - - def _initialize_optimizer_states(self): - # create a dummy zero tensor which has the same shape as that of the param - # set this dummpy zero tensor as grad - for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)): - fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id] - fp32_partition_grad = torch.zeros_like(fp32_partition_param) - fp32_partition_param.grad = fp32_partition_grad - - # update the parameter with zero gradients for initialization of optimizer stateus - self._optimizer.step() - - # remove the grad of the paramter to save memory - for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items(): - fp32_flat_tensor.grad = None - - def _sanity_checks(self): - assert torch.cuda.is_available(), 'CUDA is required' - assert self._dtype == torch.float16, \ - f'Parameters are expected to be of type torch.float16, but got {self._dtype}' - - ########################################################### - # Backward Reduction Hook - ########################################################### - - def _attach_reduction_hook(self): - # we iterate over the fp16 params - # on each param, we register a hook to its AccumulateGrad object - for group_id in range(self.num_param_groups): - param_group = self._fp16_param_groups[group_id] - for param in param_group: - if param.requires_grad: - # determines the reduction destionation rank - # this is only valid for stage 2 - # dst_rank = None means using all-reduce - # else using reduce - if self._partition_grads: - reduce_rank = self._param_store.get_param_rank(param) - else: - reduce_rank = None - - def _define_and_attach(param, reduce_rank): - # get the AccumulateGrad object of the param itself - accum_grad_obj = get_grad_accumulate_object(param) - self._grad_store.add_accumulate_grad_object(accum_grad_obj) - - reduction_func = partial(self._reduce_and_remove_grads_by_bucket, - param=param, - reduce_rank=reduce_rank) - - # define hook - # NOT IMPORTANT BUT GOOD TO KNOW: - # args here is not grad, but allow_unreacable and accumulate_grad - def reduce_grad_hook(*args): - reduction_func() - - accum_grad_obj.register_hook(reduce_grad_hook) - - _define_and_attach(param, reduce_rank) - - def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None): - param_size = param.numel() - - # check if the bucket is full - # if full, will reduce the grads already in the bucket - # after reduction, the bucket will be empty - if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: - self._reduce_grads_in_bucket(reduce_rank) - - # the param must not be reduced to ensure correctness - is_param_reduced = self._param_store.is_param_reduced(param) - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ - + 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) - - # the param must have grad for reduction - assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' - - self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) - self._bucket_store.add_grad(param.grad, reduce_rank) - self._bucket_store.add_param(param, reduce_rank) - - def _reduce_grads_in_bucket(self, reduce_rank=None): - # reduce grads - self._reduce_grads_by_rank(reduce_rank=reduce_rank, - grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), 
- bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) - - # use communication stream if overlapping - # communication with computation - if self._overlap_communication: - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank) - - for param in params_in_bucket: - # the is_param_reduced flag should be False showing that - # this param is not reduced before calling self._reduce_grads_by_rank - is_param_reduced = self._param_store.is_param_reduced(param) - - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has been reduced, ' + \ - 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) - - # update the flag - self._param_store.set_param_reduction_state(param, True) - - # if partition grads = True - # we do not keep the gradient after reduction - if self._partition_grads and not self._param_store.belongs_to_current_rank(param): - if self._overlap_communication: - # we need to keep this gradient for now as reduction may - # be completed yet since it is using a different cuda stream - self._param_store.add_previous_reduced_param(param) - else: - param.grad = None - - self._bucket_store.reset_by_rank(reduce_rank) - - def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size): - grad_buckets_by_dtype = split_half_float_double(grads) - - for tensor_list in grad_buckets_by_dtype: - self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank) - - ############################## - # Reduction Utility Function # - ############################## - def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank): - param_bucket = TensorBucket(size=bucket_size) - - for tensor in tensor_list: - param_bucket.add_to_bucket(tensor, allow_oversize=True) - - if param_bucket.is_full_or_oversized(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) - param_bucket.empty() - - if not param_bucket.is_empty(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) - - def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank): - if self._overlap_communication: - torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - flat = bucket.flatten() - reduced_flat = reduce_tensor(tensor=flat, - dtype=self._communication_dtype, - dst_rank=reduce_rank, - parallel_mode=self._dp_parallel_mode) - - # update the reduced tensor - if reduce_rank is None or reduce_rank == self._local_rank: - bucket.unflatten_and_copy(reduced_flat) - - ################################ - # torch.optim.Optimizer methods - ################################ - - def backward(self, loss, retain_graph=True): - loss = self.loss_scale * loss - loss.backward(retain_graph=retain_graph) - - def zero_grad(self, set_to_none=True): - """ - Set parameter gradients to zero. If set_to_none = True, gradient - will be set to None to save memory. - - :param set_to_none: Whether set the gradient to None. Default value is True. 
- :type set_to_none: bool - """ - for group_id, param_group in self._fp16_param_groups.items(): - for param in param_group: - if set_to_none: - param.grad = None - else: - if param.grad is not None: - param.grad.detach() - param.grad.zero_() - - #################### - # Update Parameter # - #################### - - def step(self, closure=None): - assert closure is None, 'closure is not supported by step()' - - # check for overflow - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) - - # update loss scale if overflow occurs - if found_inf: - self._grad_store._averaged_gradients = dict() - self.zero_grad() - return - - # copy the grad of fp16 param to fp32 param - single_grad_partition_groups = [] - norm_groups = [] - - for group_id in range(self.num_param_groups): - # compute norm - norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id], - params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id, - rank=self._local_rank), - dp_group=self._dp_group, - mp_group=self._mp_group) - norm_groups.append(norm_group) - - # create flat gradient for the flat fp32 params - fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id) - flat_fp16_avg_grads = flatten(fp16_avg_grads) - - dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype - flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype) - - param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape - assert param_shape == flat_fp32_avg_grads.shape, \ - f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}' - - single_grad_partition_groups.append(flat_fp32_avg_grads) - device = self._fp32_flat_param_groups_of_current_rank[group_id].device - self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) - self._grad_store._averaged_gradients[group_id] = [] - self._grad_store._averaged_gradients[group_id] = [] - - # unscale and clip grads - global_norm = calculate_global_norm_from_list(norm_list=norm_groups) - self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) - - # update the parameters - self._optimizer.step() - # release the fp32 grad - release_param_grad(self._fp32_flat_param_groups_of_current_rank.values()) - - # update fp16 partition updated by the current rank - for group_id in range(len(self._fp16_param_groups)): - fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id) - fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device) - fp16_param.data.copy_(fp32_param) - - # broadcast the updated model weights - handles = [] - for group_id in range(self.num_param_groups): - for rank in range(self._world_size): - fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) - handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True) - handles.append(handle) - - for handle in handles: - handle.wait() - - ################## - # FP16 Utilities # - ################## - - def _check_overflow(self): - # clear previous overflow record - self._found_overflow.fill_(0.0) - - # check for overflow - for group_id in range(len(self._fp16_param_groups)): - for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): - if avg_grad is not None and has_inf_or_nan(avg_grad): - self._found_overflow.fill_(1.0) - break - - # all-reduce across dp group - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, 
group=self._dp_group) - - # all-reduce over model parallel group - if self._mp_group: - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group) - - if self._found_overflow.item() > 0: - return True - else: - return False - - def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): - # compute combined scale factor for this group - combined_scale = self.loss_scale - - if self._clip_grad_norm > 0.: - # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm - if clip > 1: - combined_scale = clip * self.loss_scale - - for grad in grad_groups_flat: - grad.data.mul_(1. / combined_scale) - - ############################ - # Gradient Synchronization # - ############################ - - def sync_grad(self): - if not self._partition_grads: - self._reduce_grad_stage1() - else: - # TODO: support async comm in reduce - self._reduce_grad_stage2() - - # update param already reduced flag - reduction_states = self._param_store.get_param_reduction_states() - for tensor, state in reduction_states.items(): - reduction_states[tensor] = False - - # clear reduced grads - if self._overlap_communication: - torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - - # accumulate gradient - avg_gradients = self._grad_store._averaged_gradients - for group_id in range(self.num_param_groups): - param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id) - - if group_id not in avg_gradients: - avg_gradients[group_id] = [] - - param_idx = 0 - for param in param_group: - if param.grad is not None: - if len(avg_gradients[group_id]) == param_idx: - avg_gradients[group_id].append(param.grad) - else: - avg_gradients[group_id][param_idx].add_(param.grad) - param_idx += 1 - - # the gradients needed are stored in the avg_gradients buffer - # thus, can clear this - self.zero_grad() - - def _reduce_grad_stage1(self): - # if not overlapping communication (no reduction hook is attached) - # we need to manually reduce these gradients - if not self._overlap_communication: - for group_id in range(len(self._fp16_param_groups)): - param_group = self._fp16_param_groups[group_id] - for param in param_group: - if param.grad is not None: - self._reduce_and_remove_grads_by_bucket(param) - - # we need to reduce the gradients - # left in the communication bucket - self._reduce_grads_in_bucket() - - def _reduce_grad_stage2(self): - # when partition_grads is True, reduction hooks - # are attached in the __init__ function, so we - # only need to reduce the gradients - # left in the communication bucket - for reduce_rank in range(self._world_size): - self._reduce_grads_in_bucket(reduce_rank) diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index e69cf6654..3ba5fa4bd 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32 +from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index 
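The `_unscale_and_clip_grads` arithmetic above folds loss-scale removal and gradient clipping into a single division; the sketch below checks that dividing by the combined factor leaves a gradient whose true norm is at most `clip_grad_norm`.

import torch

def combined_scale(total_norm, loss_scale, max_norm, eps=1e-6):
    # total_norm is computed on the *scaled* gradients, so total_norm/loss_scale
    # is the true norm; scale up the divisor only when clipping is needed.
    scale = loss_scale
    clip = ((total_norm / loss_scale) + eps) / max_norm
    if clip > 1:
        scale = clip * loss_scale
    return scale

grad = torch.full((4,), 1024.0)   # gradient produced with loss_scale = 256
scale = combined_scale(grad.norm(), loss_scale=256.0, max_norm=1.0)
print((grad / scale).norm())      # ~1.0: unscaled and clipped in one step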
dad0cacfc..b8ed42433 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -7,9 +7,7 @@ import colossalai import torch -from functools import partial import torch.multiprocessing as mp -import pytest def run_tensor_move(rank): diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py index 508e7d33c..b60377cc8 100644 --- a/tests/test_utils/test_zero_gradient_clippling.py +++ b/tests/test_utils/test_zero_gradient_clippling.py @@ -2,11 +2,9 @@ # -*- encoding: utf-8 -*- import copy -import operator as op -from functools import partial, reduce -from typing import List import colossalai +from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 import pytest import torch import torch.distributed as dist @@ -14,10 +12,11 @@ import torch.multiprocessing as mp import torch.nn as nn from colossalai.logging import disable_existing_loggers from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port -from colossalai.zero.sharded_model import ShardedModel from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from colossalai.testing import parameterize +from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy +from functools import partial def checkpoint_wrapper(module, enable=True): @@ -97,41 +96,9 @@ def check_params(model, zero_model, loose=False): assert allclose(p, zero_p, loose=loose) -@parameterize('checkpoint', [False, True]) -@parameterize('fp16', [False, True]) -@parameterize('offload', [False, True]) -@parameterize('norm_type', [1.0, 2.0, float('inf')]) -def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0): - model = Net(checkpoint=checkpoint).cuda() - zero_model = copy.deepcopy(model) - ddp_model = DDP(model) - - offload_config = {} - if offload: - offload_config['device'] = 'cpu' - zero_model = zero_model.cpu() - zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config) - - optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3) - zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3) - for _ in range(5): - x = torch.rand(2, 5).cuda() - run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type) - run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type) - check_grads(ddp_model, zero_model) - check_params(ddp_model, zero_model) - for _ in range(5): - x = torch.rand(2, 5).cuda() - run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type) - run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type) - check_grads(ddp_model, zero_model, loose=True) - check_params(ddp_model, zero_model, loose=True) - - def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_config() @pytest.mark.dist diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py index b22c2d86d..e0f3ca1b6 100644 --- a/tests/test_zero_data_parallel/test_shard_model_v2.py +++ b/tests/test_zero_data_parallel/test_shard_model_v2.py @@ -12,7 +12,7 @@ from colossalai.utils import free_port from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.sharded_model import ShardedModelV2 -from 
colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16 +from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 from colossalai.zero.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs from torch.nn.parallel import DistributedDataParallel as DDP diff --git a/tests/test_zero_data_parallel/test_sharded_optim.py b/tests/test_zero_data_parallel/test_sharded_optim.py deleted file mode 100644 index 6720b8e39..000000000 --- a/tests/test_zero_data_parallel/test_sharded_optim.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import colossalai -import copy -import pytest -import torch.multiprocessing as mp -from colossalai.zero import ShardedOptimizer -from torch.nn.parallel import DistributedDataParallel as DDP - -from colossalai.utils import free_port -from functools import partial -from common import allclose -from tests.components_to_test.registry import non_distributed_component_funcs - - -def check_completely_equal(a, b): - """ - This function checks if two tensors are completely equal - """ - assert torch.all(a == b), f'a = {a}, b = {b}' - - -def check_sharded_param_consistency(): - """ - In this test, we want to test whether zero stage 1 and 2 - deliver the same numerical results despite different communication - pattern - - we use these prefixes to differentiate the zero stage - oss: partition optimizer states - pg: partition gradients and optimizer states - - """ - test_models = ['repeated_computed_layers', 'resnet18', 'nested_model'] - - for name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(name) - model_builder, train_dataloader, *_ = get_components_func() - - # create model - oss_model = model_builder(checkpoint=True).cuda().half() - pg_model = copy.deepcopy(oss_model) - - # create optimizer - oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001) - pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001) - oss_optimizer = ShardedOptimizer(oss_optimizer, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0) - pg_optimizer = ShardedOptimizer(pg_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=1, - clip_grad_norm=0.0) - - # create - data, label = next(iter(train_dataloader)) - input_data = data.cuda().half() - - # forward - oss_output = oss_model(input_data) - pg_output = pg_model(input_data) - check_completely_equal(oss_output, pg_output) - - # backward - oss_optimizer.backward(oss_output.mean().float()) - pg_optimizer.backward(pg_output.mean().float()) - - # check grad - # as this param is small, the backward reduction - # will not be fired - for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()): - check_completely_equal(oss_param.grad, pg_param.grad) - - # step - oss_optimizer.sync_grad() - pg_optimizer.sync_grad() - - # step - oss_optimizer.step() - pg_optimizer.step() - - # check updated param - for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()): - check_completely_equal(oss_param, pg_param) - - -def check_sharded_optim_against_torch_ddp(): - """ - In this test, two pairs of model and optimizers are created. - 1. zero: use sharded optimizer and fp16 parameters - 2. torch: use torch DDP and fp32 parameters - - We feed these two sets of models with the same input and check if the - differences in model output and updated parameters are within tolerance. 
- """ - - test_models = ['repeated_computed_layers', 'resnet18', 'nested_model'] - - for name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(name) - model_builder, train_dataloader, *_ = get_components_func() - - # create model - zero_model = model_builder(checkpoint=True).cuda() - torch_model = copy.deepcopy(zero_model) - - zero_model = zero_model.half() - torch_model = DDP(torch_model.cuda()) - - # create optimizer - zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001) - - # we only test stage 1 here - # in `check_sharded_param_consistency.py`, we will test whether - # level 1 and 2 will produce exactly the same results - zero_optimizer = ShardedOptimizer(zero_optimizer, - overlap_communication=True, - initial_scale=1, - clip_grad_norm=0.0) - torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001) - - # create - input_data, _ = next(iter(train_dataloader)) - input_data = input_data.cuda() - - # zero-dp forward - zero_output = zero_model(input_data.half()) - - # torch-ddp forward - torch_output = torch_model(input_data) - allclose(zero_output, torch_output.half()) - - # zero-dp backward - zero_optimizer.backward(zero_output.mean().float()) - - # torch-ddp backward - torch_output.mean().backward() - - # check grad - for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - allclose(oss_param.grad, torch_param.grad.half()) - - # zero-dp step - zero_optimizer.sync_grad() - zero_optimizer.step() - - # torch ddp step - torch_optimizer.step() - - # check updated param - for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - allclose(oss_param, torch_param.half()) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - - check_sharded_optim_against_torch_ddp() - check_sharded_param_consistency() - - -@pytest.mark.dist -def test_sharded_optim(): - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_sharded_optim() diff --git a/tests/test_zero_data_parallel/test_zero_param_mgr.py b/tests/test_zero_data_parallel/test_zero_param_mgr.py deleted file mode 100644 index 8171a0946..000000000 --- a/tests/test_zero_data_parallel/test_zero_param_mgr.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial -import pytest - -import torch -import torch.multiprocessing as mp - -import colossalai -from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager -from colossalai.core import global_context as gpc -from colossalai.context.parallel_mode import ParallelMode -from colossalai.utils import free_port -from common import CONFIG - - -def run_shard_shape_check(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - model = torch.nn.Linear(2, 4 * world_size) - gpc.init_parallel_groups() - Zero3ParameterManager(module=model, - process_group=gpc.get_group(ParallelMode.DATA), - offload_config=CONFIG.get('offload_param_config')) - - assert (model.weight.numel() == 4 * 2) - assert (model.bias.numel() == 4) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2, 4]) -def test_run_shard_shape(world_size): - run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port()) - mp.spawn(run_func, 
nprocs=world_size) - - -if __name__ == '__main__': - test_run_shard_shape(2) diff --git a/tests/test_zero_tensor_parallel/components.py b/tests/test_zero_tensor_parallel/components.py deleted file mode 100644 index 69a4c9a95..000000000 --- a/tests/test_zero_tensor_parallel/components.py +++ /dev/null @@ -1,19 +0,0 @@ - -import sys -from pathlib import Path -repo_path = Path(__file__).absolute().parents[2] -sys.path.append(str(repo_path)) - -try: - import model_zoo.vit.vision_transformer_from_config -except ImportError: - raise ImportError("model_zoo is not found, please check your path") - -BATCH_SIZE = 8 -IMG_SIZE = 32 -PATCH_SIZE = 4 -DIM = 512 -NUM_ATTENTION_HEADS = 8 -SUMMA_DIM = 2 -NUM_CLASSES = 10 -DEPTH = 6 diff --git a/tests/test_zero_tensor_parallel/test_vit_2d_level_2.py b/tests/test_zero_tensor_parallel/test_vit_2d_level_2.py deleted file mode 100644 index f87ea7c68..000000000 --- a/tests/test_zero_tensor_parallel/test_vit_2d_level_2.py +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import os -from functools import partial -from pathlib import Path - -import colossalai -import pytest -import torch -import torch.autograd -import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.utils import free_port, get_dataloader -from model_zoo.vit import vit_lite_depth7_patch4_32 -from torchvision import transforms -from torchvision.datasets import CIFAR10 - -from components import * - -CONFIG = dict(parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=4, mode='2d'), -), - fp16=dict(mode=None, ), - zero=dict(level=2)) - - -def train_epoch(engine, train_dataloader): - engine.train() - accumulated_loss = 0 - num_steps = len(train_dataloader) - data_iter = iter(train_dataloader) - for i in range(num_steps): - output, label, loss = engine.step(data_iter) - accumulated_loss += loss.detach().cpu().numpy() - avg_loss = accumulated_loss / num_steps - return avg_loss - - -def run_2d_parallel_vision_transformer_level_2(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - # build model - model = vit_lite_depth7_patch4_32() - - # build dataloader# build dataloaders - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), - transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ])) - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - drop_last=True) - - # build optimizer and loss - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = CrossEntropyLoss() - - engine, train_dataloader, *args = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - logger = get_dist_logger() - - logger.info('start training') - engine.train() - - for img, label in train_dataloader: - engine.zero_grad() - img = img.cuda() - label = label.cuda() - out = engine(img) - loss = engine.criterion(out, label) - engine.backward(loss) - engine.step() - break - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero") -def test_2d_vit_zero_level_2(): - world_size = 8 - run_func = 
partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_2d_vit_zero_level_2() diff --git a/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py b/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py deleted file mode 100644 index 2f6416a17..000000000 --- a/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import os -from functools import partial -from pathlib import Path - -import colossalai -import pytest -import torch -import torch.autograd -import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.utils import free_port, get_dataloader -from model_zoo.vit import vit_lite_depth7_patch4_32 -from torchvision import transforms -from torchvision.datasets import CIFAR10 - -from components import * - -CONFIG = dict(parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=4, mode='2d'), -), - fp16=dict(mode=None, ), - zero=dict(level=3)) - - -def train_epoch(engine, train_dataloader): - engine.train() - accumulated_loss = 0 - num_steps = len(train_dataloader) - data_iter = iter(train_dataloader) - for i in range(num_steps): - output, label, loss = engine.step(data_iter) - accumulated_loss += loss.detach().cpu().numpy() - avg_loss = accumulated_loss / num_steps - return avg_loss - - -def run_2d_parallel_vision_transformer_level_3(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - # build model - model = vit_lite_depth7_patch4_32() - - # build dataloader# build dataloaders - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), - transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ])) - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - drop_last=True) - - # build optimizer and loss - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = CrossEntropyLoss() - - engine, train_dataloader, *args = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - logger = get_dist_logger() - - logger.info('start training') - engine.train() - - for img, label in train_dataloader: - engine.zero_grad() - img = img.cuda() - label = label.cuda() - out = engine(img) - loss = engine.criterion(out, label) - engine.backward(loss) - engine.step() - break - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero") -def test_3d_vit_zero_level_3(): - world_size = 8 - run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_3d_vit_zero_level_3()
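
For reference, the sketch below outlines how one of the deleted tests might be recast against the ShardedModelV2 path that this patch keeps. It is a minimal, hypothetical sketch and not part of the patch: the launch/spawn scaffolding mirrors the deleted tests above, while the ZeroInitContext and ShardedModelV2 keyword arguments (convert_fp16, target_device, shard_strategy, shard_param) and the model.backward() call are assumptions about the refactored API rather than confirmed signatures.

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2


def run_sharded_forward_backward(rank, world_size, port):
    # Same launch pattern as the deleted tests in this patch.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    shard_strategy = TensorShardStrategy()
    # Build the module inside the init context so its parameters are sharded at creation time
    # (keyword names here are assumptions, not verified against this commit).
    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device(f'cuda:{torch.cuda.current_device()}'),
                         shard_strategy=shard_strategy,
                         shard_param=True):
        model = torch.nn.Linear(5, 5)
    model = ShardedModelV2(model, shard_strategy)

    # Inputs are cast to fp16 to match the converted parameters
    # (cf. cast_tensor_to_fp16 in the surviving test_shard_model_v2.py).
    x = torch.rand(2, 5).cuda().half()
    loss = model(x).sum()
    model.backward(loss)    # assumed: ShardedModelV2 exposes backward() to drive gradient sharding


@pytest.mark.dist
def test_sharded_forward_backward():
    world_size = 2
    run_func = partial(run_sharded_forward_backward, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_forward_backward()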