diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
index 7f81a0567..90b447de1 100644
--- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -10,6 +10,11 @@ from .tensor_shard_strategy import TensorShardStrategy
 
 
 class BucketTensorShardStrategy(TensorShardStrategy):
+    """Use the same shard scheme as `TensorShardStrategy`, but gather the tensors of a sub-module together,
+    which makes full use of network bandwidth.
+    This is especially useful when a sub-module contains a bias,
+    since gathering a bias tensor alone utilizes network bandwidth poorly (bias tensors are usually small).
+    """
 
     def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
         tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
index b393d4e88..9383889e9 100644
--- a/colossalai/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -9,6 +9,8 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 
 
 class TensorShardStrategy(BaseShardStrategy):
+    """A naive implementation which shards each tensor evenly over all ranks.
+    """
 
     def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
         for t in tensor_list:
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index f1cda2148..4efd096b1 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -1,6 +1,6 @@
 import functools
 from collections import OrderedDict
-from typing import Any, Optional, Type
+from typing import Any, Optional
 
 import torch
 import torch.distributed as dist
@@ -16,7 +16,6 @@ from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
@@ -25,10 +24,34 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso
 
 
 class ShardedModelV2(nn.Module):
+    """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
+    Parameters, gradients and optimizer states are sharded, so memory efficiency is boosted drastically
+    compared to classic data parallelism, while the computational granularity and communication efficiency are retained.
+    Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
+
+    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
+    :type module: nn.Module
+    :param shard_strategy: A shard strategy to manage shard behavior.
+    :type shard_strategy: BaseShardStrategy
+    :param process_group: Data parallel process group, defaults to None
+    :type process_group: Optional[ProcessGroup], optional
+    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
+    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
+    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
+    :type reduce_scatter_bucket_size_mb: int, optional
+    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
+    :type fp32_reduce_scatter: bool, optional
+    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
+    :type offload_config: Optional[dict], optional
+    :param gradient_predivide_factor: Gradients are divided by this value before reduce-scatter, defaults to 1.0
+    :type gradient_predivide_factor: Optional[float], optional
+    :param use_memory_tracer: Whether to use the memory tracer, defaults to False
+    :type use_memory_tracer: bool, optional
+    """
 
     def __init__(self,
                  module: nn.Module,
-                 shard_strategy: Type[BaseShardStrategy],
+                 shard_strategy: BaseShardStrategy,
                  process_group: Optional[ProcessGroup] = None,
                  reduce_scatter_process_group: Optional[ProcessGroup] = None,
                  reduce_scatter_bucket_size_mb: int = 25,
@@ -36,10 +59,6 @@ class ShardedModelV2(nn.Module):
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
                  use_memory_tracer: bool = False):
-        r"""
-        A demo to reconfigure zero1 shared_model.
-        Currently do not consider the Optimizer States.
-        """
         super().__init__()
         self.logger = get_dist_logger()
 
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index a4d260ed8..4f111921d 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -25,6 +25,46 @@ class OptimState(Enum):
 
 
 class ShardedOptimizerV2(ColossalaiOptimizer):
+    """A wrapper for an optimizer. `ShardedOptimizerV2` and `ShardedModelV2` together implement Zero Redundancy Optimizer (ZeRO) stage 3.
+    You must use `ShardedOptimizerV2` with `ShardedModelV2`.
+
+    :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
+        shard strategy provided by the sharded model to shard param fp32 tensors.
+    :type sharded_model: ShardedModelV2
+
+    :param optimizer: An Optimizer instance.
+    :type optimizer: Optimizer
+
+    :param cpu_offload: Whether to offload the optimizer states to CPU.
+    :type cpu_offload: bool
+
+    :param initial_scale: initial scale used by DynamicGradScaler
+    :type initial_scale: float
+
+    :param min_scale: min scale used by DynamicGradScaler
+    :type min_scale: float
+
+    :param growth_factor: growth_factor used by DynamicGradScaler
+    :type growth_factor: float
+
+    :param backoff_factor: backoff_factor used by DynamicGradScaler
+    :type backoff_factor: float
+
+    :param growth_interval: growth_interval used by DynamicGradScaler
+    :type growth_interval: float
+
+    :param hysteresis: hysteresis used by DynamicGradScaler
+    :type hysteresis: float
+
+    :param max_scale: max_scale used by DynamicGradScaler
+    :type max_scale: float
+
+    :param dp_process_group: Data parallel process group
+    :type dp_process_group: Optional[ProcessGroup]
+
+    :param mp_process_group: Model parallel process group
+    :type mp_process_group: Optional[ProcessGroup]
+    """
 
     def __init__(self,
                  sharded_model: ShardedModelV2,
@@ -39,47 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
-        """
-        :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
-        shard strategy provided by sharded model to shard param fp32 tensors.
-        :type sharded_model: sharded_model
-
-        :param optimizer_class: A class type of Optimizer
-        :type optimizer_class: Type[Optimizer]
-
-        :param cpu_offload: is offloading the optimizer states to CPU.
-        :type cpu_offload: bool
-
-        :param initial_scale: initial scale used by DynamicGradScaler
-        :type initial_scale: float
-
-        :param min_scale: min scale used by DynamicGradScaler
-        :type min_scale: float
-
-        :param growth_factor: growth_factor used by DynamicGradScaler
-        :type growth_factor: float
-
-        :param backoff_factor: backoff_factor used by DynamicGradScaler
-        :type backoff_factor: float
-
-        :param growth_interval: growth_interval used by DynamicGradScaler
-        :type growth_interval: float
-
-        :param hysteresis: hysteresis used by DynamicGradScaler
-        :type hysteresis: float
-
-        :param max_scale: max_scale used by DynamicGradScaler
-        :type max_scale: float
-
-        :param dp_process_group: data paralle process group
-        :type dp_process_group: Optional[ProcessGroup]
-
-        :param mp_process_group: model paralle process group
-        :type mp_process_group: Optional[ProcessGroup]
-
-        :**defaults: any trailing arguments, which are forwarded to the local optimizer.
-        :type defaults: dict()
-        """
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
         super().__init__(optimizer)
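
For reference, a minimal usage sketch of the API documented above. This is not part of the patch: it assumes `TensorShardStrategy` takes no constructor arguments, that the wrapped module returns its own loss, and it uses `torch.optim.Adam` plus placeholder helpers (`build_model_under_zero_init_context`, `train_dataloader`) for the parts outside this diff.

    # Hypothetical end-to-end sketch of ShardedModelV2 + ShardedOptimizerV2.
    import torch

    from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
    from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
    from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2

    shard_strategy = TensorShardStrategy()

    # The module must be created inside `ZeroInitContext` with the same shard strategy,
    # so its parameters are already sharded before `ShardedModelV2` wraps it.
    module = build_model_under_zero_init_context(shard_strategy)  # placeholder helper

    model = ShardedModelV2(module, shard_strategy)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    sharded_optim = ShardedOptimizerV2(model, optimizer)

    for data, label in train_dataloader:      # placeholder dataloader
        sharded_optim.zero_grad()
        loss = model(data, label)             # assumes the module computes its loss
        sharded_optim.backward(loss)          # ColossalaiOptimizer-style backward
        sharded_optim.step()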