Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 21:22:04 +00:00)

[refactor] remove old zero code (#517)

Commit: 4d322b79da
Parent: 6a3f9fda83
@@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ShardedModel, ShardedOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
@@ -92,8 +91,6 @@ class PipelineSchedule(BaseSchedule):
 
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
-        if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel):
-            raise TypeError("Pipeline schedule is currently not compatible with ZeRO")
         model = engine.model
         if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
             self.dtype = torch.half
@@ -2,14 +2,9 @@ from typing import Tuple
 
 import torch
 import torch.nn as nn
-from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
-from torch.optim import Optimizer
-
-from .sharded_model import ShardedModel
-from .sharded_optim import ShardedOptimizer
 
 
 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
@@ -40,35 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
     return zero_model, zero_optimizer
 
 
-def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
-    """
-    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
-
-    :param model: Your model object
-    :type model: :class:`torch.nn.Module`
-    :param optimizer: Your optimizer object
-    :type optimizer: :class:`torch.optim.Optimizer`
-    :param level: Optimizer level, can be 2 or 3
-    :type level: int
-    :param zero_config: Configuration for zero
-    :type zero_config: dict
-
-    :return: (model, optimizer)
-    :rtype: Tuple
-    """
-    assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
-    if level in [1, 2]:
-        if level == 2:
-            if 'partition_grad' in zero_config:
-                assert zero_config['partition_grad'], \
-                    'Sharded Optimizer requires partition_grad to be True'
-            else:
-                zero_config['partiton_grad'] = True
-        model = NaiveAMPModel(model, output_to_fp32=True)
-        optimizer = ShardedOptimizer(optimizer, **zero_config)
-    else:
-        model = ShardedModel(module=model, **zero_config)
-    return model, optimizer
-
-
-__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
+__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
@@ -8,7 +8,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers
colossalai/zero/shard_utils/commons.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import torch
import torch.nn.functional as F
from typing import Tuple


def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor."""
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(world_size))
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))

    # Determine number of padding elements.
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    assert num_to_pad >= 0, num_to_pad

    shard = chunks[rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard, num_to_pad
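Note: as an aside, here is a minimal usage sketch of the new get_shard helper (my own illustration, not part of the commit); the helper body is copied locally so the snippet runs without colossalai installed, and the tensor size and world size are arbitrary assumptions.

import torch
import torch.nn.functional as F
from typing import Tuple

# Local copy of the helper above so the sketch runs standalone.
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    chunks = list(torch.flatten(tensor).chunk(world_size))
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    shard = chunks[rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard, num_to_pad

full = torch.arange(10.0)               # 10 elements, world_size 4 -> chunk sizes 3, 3, 3, 1
for rank in range(4):
    shard, pad = get_shard(full, rank, world_size=4)
    print(rank, shard.tolist(), pad)    # every shard ends up with 3 elements; the last one is right-padded by 2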
@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_model._zero3_utils import get_shard
+from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 
 
@@ -1,4 +1,3 @@
-from .sharded_model import ShardedModel
 from .sharded_model_v2 import ShardedModelV2
 
-__all__ = ['ShardedModel', 'ShardedModelV2']
+__all__ = ['ShardedModelV2']
@@ -1,5 +1,4 @@
-from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, List, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -12,23 +11,6 @@ def get_gradient_predivide_factor(world_size: int) -> float:
     return float(factor)
 
 
-def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
-    """Return the local shard of a full tensor."""
-    # Shard using torch.chunk to match all-gather/reduce-scatter.
-    chunks = list(torch.flatten(tensor).chunk(world_size))
-    while len(chunks) < world_size:
-        chunks.append(chunks[0].new_empty(0))
-
-    # Determine number of padding elements.
-    num_to_pad = chunks[0].numel() - chunks[rank].numel()
-    assert num_to_pad >= 0, num_to_pad
-
-    shard = chunks[rank].clone()
-    if num_to_pad > 0:
-        shard = F.pad(shard, [0, num_to_pad])
-    return shard, num_to_pad
-
-
 def free_storage(data: torch.Tensor) -> None:
     """Free underlying storage of a Tensor."""
     if data.storage().size() > 0:
@@ -86,31 +68,3 @@ chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     if len(chunks) < num_chunks:
         chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
     return chunks
-
-
-def assert_in_engine(cond: Any, s: Any) -> None:
-    """Used in backward context to make sure error is printed."""
-    if not cond:
-        print(s)
-        raise AssertionError
-
-
-def replace_state_dict_prefix(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
-                              old_prefix: str, new_prefix: str) -> None:
-    """
-    Replace all keys that match a given old_prefix with a new_prefix (in-place).
-
-    Usage::
-
-        state_dict = {"layer.xyz": torch.tensor(1)}
-        replace_state_dict_prefix(state_dict, "layer.", "module.layer.")
-        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
-    """
-    if old_prefix == new_prefix:
-        raise ValueError("old_prefix and new_prefix must be distinct")
-    for key in list(state_dict.keys()):
-        if not key.startswith(old_prefix):
-            continue
-        new_key = new_prefix + key[len(old_prefix):]
-        state_dict[new_key] = state_dict[key]
-        del state_dict[key]
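Note: a small sketch of the padding rule implemented by the retained chunk_and_pad tail above (my own illustration, under the assumption that the unshown first lines simply call torch.chunk): torch.chunk can return fewer chunks than requested, so the missing ones are filled with zero tensors.

import torch

def chunk_and_pad_sketch(tensor: torch.Tensor, num_chunks: int):
    # torch.chunk may return fewer than num_chunks chunks when the tensor is small.
    chunks = list(torch.flatten(tensor).chunk(num_chunks))
    if len(chunks) < num_chunks:
        chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
    return chunks

print([c.tolist() for c in chunk_and_pad_sketch(torch.arange(5.0), 4)])
# [[0.0, 1.0], [2.0, 3.0], [4.0], [0.0, 0.0]]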
(deleted file, 385 lines)
@@ -1,385 +0,0 @@
import os
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from ._zero3_utils import alloc_storage, free_storage, get_shard

# TODO: Remove the toggle-enable_nccl_base_collectives in the future
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
    enable_nccl_base_collectives = False
else:
    enable_nccl_base_collectives = True

# TODO: add flatten params


class Zero3ParameterManager:
    def __init__(self,
                 module: nn.Module,
                 process_group: Optional[ProcessGroup],
                 mixed_precision: bool = False,
                 flatten_parameters: bool = True,
                 compute_dtype: Optional[torch.dtype] = None,
                 compute_device: Optional[torch.device] = None,
                 offload_config: Optional[dict] = None
                 ) -> None:
        """Manage parameter shards. We manage several attributes on each Parameter instance:
            ``zero_is_sharded``: ``True`` if the Parameter is sharded or ``False``
                if the Parameter is intentionally not sharded (in which case we
                will all-reduce grads for this param).
            ``zero_orig_size``: the size of the original Parameter (before sharding)
            ``zero_shard_padding``: the padding size. All paddings are right padding.
            ``zero_fp32_shard``: a single shard of the parameters in full precision
                (typically FP32, but this is dependent on the dtype of the model
                as it's passed in by the user). This can be on CPU or GPU
                depending on the value of *``offload_config``*.
            ``zero_fp16_shard``: This will be a single shard of the parameters in FP16, used for all-gather.
                This can be in FP16 or FP32 depending on the value of *``compute_dtype``* and
                if params are offloaded to CPU.
            ``zero_full_param_padded``: the full weight (padded to be evenly
                divisible by ``world_size``), used for computation in the
                forward and backward pass. This will be resized in place and
                only materialized (via all-gather) as needed.
            ``zero_cpu_grad``: the gradient saved on CPU. It's set only when using CPU offload.

        :param module: original module
        :type module: nn.Module
        :param process_group: typically data parallel process group, defaults to None
        :type process_group: Optional[ProcessGroup], optional
        :param mixed_precision: whether to use mixed precision mode, defaults to False
        :type mixed_precision: bool, optional
        :param flatten_parameters: whether to flatten parameters, useless now, defaults to True
        :type flatten_parameters: bool, optional
        :param compute_dtype: the dtype of parameters when computing, defaults to None
        :type compute_dtype: Optional[torch.dtype], optional
        :param compute_device: the device of parameters when computing, defaults to None
        :type compute_device: Optional[torch.device], optional
        :param offload_config: offload config, defaults to None
        :type offload_config: Optional[dict], optional
        """
        self.process_group = process_group
        self.shard_idx = process_group.rank()
        self.num_shards = process_group.size()
        self.mixed_precision = mixed_precision
        self.compute_dtype = compute_dtype
        self.compute_device = compute_device
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        self.params: List[Parameter] = []
        for param in module.parameters():
            if not hasattr(param, 'zero_is_sharded'):
                self.params.append(param)

        self._has_params = len(self.params) > 0
        self._has_sharded_params = False
        # Flag to indicate if the full params are gathered.
        self.has_full_params: bool = False

        self._shard_params()
        # Maybe no need, reserve to prevent bugs
        # self.delete_fp32_shards()

        self._streams: Dict[str, torch.cuda.Stream] = {}

    def _shard_params(self) -> None:
        for p in self.params:
            assert not hasattr(p, "zero_is_sharded")
            assert p.is_floating_point()
            if self.mixed_precision:
                assert p.dtype == torch.float32

            # If world_size is 1, then we all-reduce grads instead of sharding.
            p.zero_is_sharded = self.num_shards > 1
            p.zero_orig_size = p.data.size()

            if not p.zero_is_sharded:
                p.zero_shard_padding = 0
                continue

            # Replace p.data with the relevant shard.
            orig_data = p.data
            p.data, p.zero_shard_padding = get_shard(p.data, self.shard_idx, self.num_shards)
            free_storage(orig_data)

    @torch.no_grad()
    def reset_param_attr(self, p: Parameter, training: bool) -> None:
        """This should be called by ``ZeroRedundancyLevel3Model._lazy_init()``
        """
        assert hasattr(p, 'zero_is_sharded') and hasattr(p, 'zero_orig_size')
        if hasattr(p, 'zero_fp32_shard'):
            return

        # A single shard of the parameters in full precision.
        p.zero_fp32_shard = p.data

        if self.mixed_precision:
            assert p.zero_fp32_shard.dtype == torch.float32

        if self._cpu_offload:
            assert p.zero_fp32_shard.device == torch.device('cpu')
            # If we plan to keep the FP32 parameters on CPU, then pinning
            # memory allows us to later use non-blocking transfers when moving
            # the FP32 param shard to compute_device.
            p.zero_fp32_shard = p.zero_fp32_shard.pin_memory()
            p.data = p.zero_fp32_shard

        if self.mixed_precision or self._cpu_offload:

            # In mixed precision mode, we maintain a reduced precision
            # (typically FP16) parameter shard on compute_device for performing
            # the computation in the forward/backward pass. We resize the
            # storage to size 0 at init (here) and re-materialize (by copying
            # from _fp32_shard) as needed. If offloading params to CPU, the
            # dtype of the fp16 shard will depend on the *`compute_dtype`*.
            p.zero_fp16_shard = torch.zeros_like(
                p.zero_fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
            free_storage(p.zero_fp16_shard)

        if self.mixed_precision:
            assert p.zero_fp32_shard.dtype == torch.float32

        if not self.mixed_precision and not self._cpu_offload:
            # use _fp32_shard if you are not in using mixed precision or
            # offloading params and grads to CPU.
            p.zero_fp16_shard = None

        # We also maintain a full-sized parameter of type self.compute_dtype
        # (FP16 for mixed_precision or FP32 otherwise). We resize the
        # storage to size 0 at init (here) and only materialize as needed. The
        # storage may contain padding elements so that it is evenly divisible by
        # world_size, although these padding elements will be removed before the
        # relevant computation.
        if p.zero_is_sharded:
            p.zero_full_param_padded = torch.zeros(
                p.data.numel() * self.num_shards, device=self.compute_device, dtype=self.compute_dtype
            )
            free_storage(p.zero_full_param_padded)

        if self._cpu_offload and training:
            p.zero_cpu_grad = torch.zeros_like(p.data, device='cpu').pin_memory()

    def setup_streams(self, streams):
        self._streams = streams

    @torch.no_grad()
    def rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
        """
        Gather all shards of params.

        Note, this is idempotent if full params are already gathered. Callers
        assume the idempotency. So please keep it that way.

        Args:
            force_full_precision (bool, Optional): by default params will be gathered
                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
                ``True``, in which case they will be gathered in full precision
                (e.g., FP32), possibly in fresh storage. The parameter that's being
                rebuilt will end up in full precision as well.

        Returns:
            A list of tuples, where the first element is the full-sized param
            and the second element is a bool indicating if it's safe for the
            caller to free the full-sized param. This will be ``None`` if
            ``force_full_precision=False`` and the full params are already gathered.
        """
        # Store tensor and free flag
        output_tensors: List[Tuple[torch.Tensor, bool]] = []

        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
            """
            Helper function to update p.data pointer.

            Args:
                custom_output_tensor (torch.Tensor, Optional): if not None, this
                    tensor contains the data we just gathered.
            """
            if custom_output_tensor is not None:
                assert p.zero_is_sharded
                p.data = custom_output_tensor
                output_tensors.append((p.data, True))
            elif not p.zero_is_sharded:
                if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                    assert p.zero_fp16_shard is not None
                    p.data = p.zero_fp16_shard
                    output_tensors.append((p.data, True))
                else:
                    # Here p.data == p._fp32_shard, so it's not safe to free.
                    output_tensors.append((p.data, False))
            else:
                p.data = p.zero_full_param_padded
                output_tensors.append((p.data, True))
            # Trim any padding and reshape to match original size.
            p.data = p.data[: p.zero_orig_size.numel()].view(p.zero_orig_size)

        if self._has_sharded_params:
            # self.has_full_params flag can be out of sync if a shared param is
            # sharded by another ZeroRedundancyLevel3Model instance. An example is that in eval case
            # with reshard_after_forward=False but the sharing instance has
            # reshard_after_forward=True. Then, on the second forward, the
            # other instance can shard the shared param and but this instance
            # can mistakenly think the full param is already gathered from the
            # has_full_params flag.
            #
            # Therefore, we update the flag accordingly here.
            self.has_full_params = not any(p.zero_full_param_padded.storage().size() == 0 for p in self.params)

        # Early exit if we already have full params and don't need full precision.
        if self.has_full_params and not force_full_precision:
            for p in self.params:
                update_p_data()
            return output_tensors

        self.has_full_params = True

        with torch.cuda.stream(self._streams["all_gather"]):
            if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                self.use_fp16_shards()

            if self._cpu_offload and force_full_precision:
                # If the compute_dtype and storage dtype are the same,
                # use pinned memory. Otherwise move p.data to the compute
                # device.
                if self.params[0].dtype == self.compute_dtype:
                    self.use_fp16_shards()
                else:
                    for p in self.params:
                        p.data = p.data.to(self.compute_device)

            for p in self.params:
                if not p.zero_is_sharded:    # e.g., when world_size == 1
                    update_p_data()
                else:
                    # Skip if already built. Only shared param can be rebuilt multiple times.
                    # A corner case is p.zero_orig_size = (1,), which means the shape equality is
                    # not a perfect check. But we assume we don't share a param with shape (1,).
                    # if p.data.shape == p.zero_orig_size and hasattr(p, "zero_is_shared") and p.zero_is_shared:
                    #     continue
                    # If self._cpu_offload and force_full_precision, we need to cast
                    # the FP32 CPU param to CUDA for the all-gather.
                    p_data = p.data.to(p.zero_full_param_padded.device, non_blocking=True)

                    p_size = p.zero_full_param_padded.size()
                    assert p_size.numel() % self.num_shards == 0
                    if self.mixed_precision and force_full_precision:
                        # Allocate fresh tensor in full precision since we are in
                        # mixed precision and full precision rebuild is asked.
                        output_tensor = p_data.new_zeros(p_size)
                    else:
                        if p.zero_full_param_padded.storage().size() != p_size.numel():
                            # Allocate based on full size from all shards.
                            alloc_storage(p.zero_full_param_padded, size=p_size)
                        output_tensor = p.zero_full_param_padded

                    # Fill output_tensor with (p.data for each shard in self.world_size)
                    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
                        # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
                        dist._all_gather_base(output_tensor, p_data, group=self.process_group)
                    else:
                        chunks = list(output_tensor.chunk(self.num_shards))
                        dist.all_gather(chunks, p_data, group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

                    if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                        self.free_fp16_shards([p])

                    if self._cpu_offload and (self.params[0].dtype == self.compute_dtype):
                        self.free_fp16_shards([p])

        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        return output_tensors

    @torch.no_grad()
    def use_full_params(self) -> None:
        """
        Switch p.data pointers to use the full params.

        Note: this assumes full params are already gathered.

        Note: this might be called after full_params is already in used. So please
        make sure it is idempotent in that case.
        """
        assert self.has_full_params
        for p in self.params:
            if not p.zero_is_sharded:
                if self.mixed_precision or self._cpu_offload:
                    assert p.zero_fp16_shard is not None
                    assert p.zero_fp16_shard.storage().size() != 0
                    p.data = p.zero_fp16_shard
            else:
                assert p.zero_full_param_padded.storage().size() != 0, f"{p.zero_orig_size} {id(self)}"
                p.data = p.zero_full_param_padded[: p.zero_orig_size.numel()].view(p.zero_orig_size)

    @torch.no_grad()
    def use_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Cast FP32 param shard to FP16 for a list of params."""
        if params is None:
            params = self.params
        with torch.cuda.stream(self._streams["fp32_to_fp16"]):
            for p in params:
                assert p.zero_fp16_shard is not None
                alloc_storage(p.zero_fp16_shard, size=p.zero_fp32_shard.size())
                p.zero_fp16_shard.copy_(
                    # If _cpu_offload is True, this will be non-blocking
                    # because _fp32_shard is pinned, otherwise it's a no-op.
                    p.zero_fp32_shard.to(p.zero_fp16_shard.device, non_blocking=True)
                )
                p.data = p.zero_fp16_shard
        torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

    @torch.no_grad()
    def use_fp32_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Use FP32 shard for a list of params."""
        if params is None:
            params = self.params
        for p in params:
            p.data = p.zero_fp32_shard

    @torch.no_grad()
    def free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        """Free up storage for full parameters."""
        if params is None:
            params = self.params
        self.has_full_params = False
        current_stream = torch.cuda.current_stream()
        for p in params:
            if not p.zero_is_sharded:    # e.g., world_size == 1
                if self.mixed_precision or self._cpu_offload:
                    self.free_fp16_shards([p])
                continue
            # Don't let PyTorch reuse this memory until all work in the current
            # stream is complete.
            p.zero_full_param_padded.record_stream(current_stream)
            # There may be external references to the Tensor Storage that we
            # can't modify, such as references that are created by
            # ctx.save_for_backward in the forward pass. Thus when we
            # unshard parameters, we should reuse the original Tensor
            # Storage object and unshard it in-place. For now, just resize
            # the Storage to 0 to save memory.
            free_storage(p.zero_full_param_padded)

    @torch.no_grad()
    def free_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Free storage for FP16 shards for a list of params."""
        if params is None:
            params = self.params
        current_stream = torch.cuda.current_stream()
        for p in params:
            if p.zero_fp16_shard is not None:
                # zero_fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
                # free it until the work in the current stream completes.
                p.zero_fp16_shard.record_stream(current_stream)
                free_storage(p.zero_fp16_shard)

    def delete_fp32_shards(self) -> None:
        for p in self.params:
            if hasattr(p, 'zero_fp32_shard'):
                del p.zero_fp32_shard    # reset _init_param_attr
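Note: to make the per-parameter attribute scheme described in the docstring above concrete, here is a standalone sketch of the fp32-master / low-precision-compute shard pairing (my own illustration, not repository code; the CUDA branch is an assumption so the snippet also runs on CPU-only machines).

import torch

def make_shard_pair(fp32_shard: torch.Tensor, compute_dtype=torch.float16):
    """Keep a full-precision master copy and derive a low-precision compute copy from it."""
    if torch.cuda.is_available():
        # A pinned CPU master enables non-blocking host-to-device copies.
        master = fp32_shard.float().cpu().pin_memory()
        compute = master.to(device='cuda', dtype=compute_dtype, non_blocking=True)
    else:
        master = fp32_shard.float()
        compute = master.to(dtype=compute_dtype)
    return master, compute

master, compute = make_shard_pair(torch.randn(8))
print(master.dtype, compute.dtype)    # torch.float32 torch.float16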
(deleted file, 85 lines)
@@ -1,85 +0,0 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ShardedGradient:
    def __init__(self,
                 param: Parameter,
                 sharded_module: nn.Module,
                 offload_config: Optional[dict] = None
                 ) -> None:
        assert hasattr(
            param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with sharded parameter'

        self.param = param
        self.sharded_module = sharded_module
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        # _gpu_grad is either sharded or not
        # all saved grads are fp32
        self._gpu_grad: Optional[torch.Tensor] = None
        self._cpu_grad: Optional[torch.Tensor] = None

        if self._cpu_offload:
            # this buffer will be held and reused every iteration
            self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()

    @torch.no_grad()
    def setup(self) -> None:
        """This function will be called pre-backward. Save the local accumulated gradient to _gpu_grad.
        When no_sync() is enable (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad

        :raises AssertionError: Raise if grad shape is wrong
        """
        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
            if self.param.grad.device != self.param.data.device:
                # TODO: offload?
                raise RuntimeError(
                    'grad and param are on different device, grad {self.param.grad.device} vs. param {self.param.data.device}')
            else:
                self._gpu_grad = self.param.grad.data
            self.param.grad = None

    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
        """This function will be called in post-backward hook, so we cannot modify param.grad directly

        :param reduced_grad: the reduced grad
        :type reduced_grad: torch.Tensor
        """
        # Make sure we store fp32 grad
        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
            reduced_grad.data = reduced_grad.data.to(torch.float)

        if self._gpu_grad is None:
            self._gpu_grad = reduced_grad.data
        else:
            self._gpu_grad += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            self._gpu_grad.data.record_stream(torch.cuda.current_stream())

    @torch.no_grad()
    def write_back(self) -> None:
        """This function will be called in final backward hook
        """
        if self._cpu_grad is not None:
            assert self.param.device == torch.device(
                'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
            self.param.grad.data = self._cpu_grad
        elif self._gpu_grad is not None:
            assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
            self.param.grad.data = self._gpu_grad
        else:
            raise RuntimeError('No grad to write back')
        # If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
        # They should be released here
        self._gpu_grad = None
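Note: a minimal sketch of the accumulate-in-fp32 rule that reduce_scatter_callback applied to incoming reduced gradients (illustrative only, not repository code).

import torch
from typing import Optional

def accumulate_fp32(total: Optional[torch.Tensor], reduced_grad: torch.Tensor) -> torch.Tensor:
    # Reduced gradients may arrive in fp16; keep the running sum in fp32.
    if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
        reduced_grad = reduced_grad.to(torch.float)
    return reduced_grad.clone() if total is None else total + reduced_grad

total = None
for step_grad in (torch.ones(4, dtype=torch.float16), torch.ones(4, dtype=torch.float16)):
    total = accumulate_fp32(total, step_grad)
print(total.dtype, total.tolist())    # torch.float32 [2.0, 2.0, 2.0, 2.0]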
(One file diff was suppressed because it is too large.)
@@ -18,7 +18,7 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
+from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
 
 
@@ -1,4 +1,3 @@
-from .sharded_optim import ShardedOptimizer
 from .sharded_optim_v2 import ShardedOptimizerV2
 
-__all__ = ['ShardedOptimizer', 'ShardedOptimizerV2']
+__all__ = ['ShardedOptimizerV2']
(deleted file, 6 lines)
@@ -1,6 +0,0 @@
from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .bucket_store import BucketStore
from .tensor_bucket import TensorBucket

__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
(deleted file, 17 lines)
@@ -1,17 +0,0 @@
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


class BaseStore:

    def __init__(self, dp_parallel_mode=ParallelMode.DATA):
        self._world_size = gpc.get_world_size(dp_parallel_mode)
        self._local_rank = gpc.get_local_rank(dp_parallel_mode)

    @property
    def world_size(self):
        return self._world_size

    @property
    def local_rank(self):
        return self._local_rank
(deleted file, 43 lines)
@@ -1,43 +0,0 @@
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .base_store import BaseStore

class BucketStore(BaseStore):

    def __init__(self, dp_parallel_mode):
        super().__init__(dp_parallel_mode)
        self._grads = dict()
        self._params = dict()
        self._num_elements_in_bucket = dict()

        self.reset()

    def num_elements_in_bucket(self, reduce_rank: int = None):
        return self._num_elements_in_bucket[reduce_rank]

    def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
        self._num_elements_in_bucket[reduce_rank] += num_elements

    def add_grad(self, tensor, reduce_rank: int = None):
        self._grads[reduce_rank].append(tensor)

    def add_param(self, tensor, reduce_rank: int = None):
        self._params[reduce_rank].append(tensor)

    def reset(self):
        keys = [None] + list(range(self._world_size))
        self._grads = {rank: [] for rank in keys}
        self._params = {rank: [] for rank in keys}
        self._num_elements_in_bucket = {rank: 0 for rank in keys}

    def reset_by_rank(self, reduce_rank=None):
        self._grads[reduce_rank] = []
        self._params[reduce_rank] = []
        self._num_elements_in_bucket[reduce_rank] = 0


    def get_grad(self, reduce_rank: int = None):
        return self._grads[reduce_rank]

    def get_param(self, reduce_rank: int = None):
        return self._params[reduce_rank]
(deleted file, 66 lines)
@@ -1,66 +0,0 @@
from typing import List
from torch import Tensor
from .base_store import BaseStore


class GradientStore(BaseStore):

    def __init__(self, *args):
        super().__init__(*args)
        # bookkeeping data structures
        self._averaged_gradients = dict()

        # for backward reduction hooks
        self._grad_acc_objs = []

    def add_accumulate_grad_object(self, obj):
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
        be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """

        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return average gradients of a parameter group

        :param group_id: The index of parameter group
        :type group_id: int

        :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """

        return self._averaged_gradients[group_id]

    def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an average gradient to the list of averaged gradients of a parameter group

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor

        """

        if group_id in self._averaged_gradients:
            self._averaged_gradients[group_id].append(tensor)
        else:
            self._averaged_gradients[group_id] = [tensor]

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients to an empty list

        :param group_id: The index of a parameter group
        :type group_id: int
        """

        self._averaged_gradients[group_id] = []
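Note: the bookkeeping pattern behind GradientStore is a plain dict from group id to a list of averaged gradient tensors; a minimal standalone sketch (not repository code) of the add/get/reset cycle:

import torch
from typing import Dict, List

averaged_gradients: Dict[int, List[torch.Tensor]] = {}

def add_average_gradient(group_id: int, tensor: torch.Tensor) -> None:
    # Append to the group's list, creating it on first use.
    averaged_gradients.setdefault(group_id, []).append(tensor)

add_average_gradient(0, torch.ones(3))
add_average_gradient(0, torch.zeros(3))
print(len(averaged_gradients[0]))    # 2
averaged_gradients[0] = []           # equivalent of reset_average_gradients_by_group(0)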
(deleted file, 96 lines)
@@ -1,96 +0,0 @@
from .base_store import BaseStore
from torch import Tensor
from typing import List


class ParameterStore(BaseStore):

    def __init__(self, dp_paralle_mode):
        super().__init__(dp_paralle_mode)
        # param partitioning data structures
        self._fp16_param_to_rank = dict()
        self._rank_groupid_to_fp16_param_list = dict()
        self._rank_group_id_to_flat_fp16_param = dict()

        # param reduction data structures
        self._is_param_reduced = dict()
        self._reduced_param = []

    def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
        """
        Set the mapping between parameter to rank, each parameter should be owned by a rank.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        :param rank: The rank of which the process is responsible for updating the parameter
        :type rank: int
        """

        self._fp16_param_to_rank[tensor] = rank

    def get_param_rank(self, tensor: Tensor) -> int:
        """
        Gives the rank which the parameter belongs to

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        """
        return self._fp16_param_to_rank[tensor]

    def belongs_to_current_rank(self, tensor) -> bool:
        """
        Check whether a parameter is supposed to be updated by the process of the current rank

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor

        :return: True if the parameter should be updated by the current rank. Otherwise false.
        :rtype: bool
        """

        tensor_rank = self._fp16_param_to_rank[tensor]
        return tensor_rank == self._local_rank

    def add_fp16_param_list_by_rank_group(self, rank, group_id,
                                          tensor_list) -> None:
        if rank not in self._rank_groupid_to_fp16_param_list:
            self._rank_groupid_to_fp16_param_list[rank] = dict()

        if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
            self._rank_groupid_to_fp16_param_list[rank][group_id] = []

        self._rank_groupid_to_fp16_param_list[rank][group_id].extend(
            tensor_list)

    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
        return self._rank_groupid_to_fp16_param_list[rank][group_id]

    def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
        if rank not in self._rank_group_id_to_flat_fp16_param:
            self._rank_group_id_to_flat_fp16_param[rank] = dict()

        self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor

    def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
        return self._rank_group_id_to_flat_fp16_param[rank][group_id]

    def is_param_reduced(self, tensor):
        return self._is_param_reduced[tensor]

    def set_param_reduction_state(self, tensor, state):
        self._is_param_reduced[tensor] = state

    def get_param_reduction_states(self):
        return self._is_param_reduced

    def reset_previous_reduced_params(self):
        self._reduced_param = []

    def add_previous_reduced_param(self, tensor):
        self._reduced_param.append(tensor)

    def clear_grads_of_previous_reduced_params(self):
        if len(self._reduced_param) > 0:
            for param in self._reduced_param:
                param.grad = None
            self.reset_previous_reduced_params()
(deleted file, 54 lines)
@@ -1,54 +0,0 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


class TensorBucket:

    def __init__(self, size):
        self._max_size = size
        self._current_size = 0
        self._bucket = []

    @property
    def max_size(self):
        return self._max_size

    @property
    def current_size(self):
        return self._current_size

    def is_full_or_oversized(self):
        return self._current_size >= self._max_size

    def is_empty(self):
        return len(self._bucket) == 0

    def add_to_bucket(self, tensor, allow_oversize=False):
        tensor_size = tensor.numel()

        if not allow_oversize and self.will_exceed_max_size(tensor_size):
            msg = f"The param bucket max size {self._max_size} is exceeded" \
                + f"by tensor (size {tensor_size})"
            raise RuntimeError(msg)

        self._bucket.append(tensor)
        self._current_size += tensor_size

    def will_exceed_max_size(self, tensor_size):
        expected_size = self._current_size + tensor_size
        return expected_size > self._max_size

    def get_bucket(self):
        return self._bucket

    def empty(self):
        self._bucket = []
        self._size = 0

    def flatten(self):
        return _flatten_dense_tensors(self._bucket)

    def unflatten_and_copy(self, flat_tensor):
        unflattened_tensor_list = _unflatten_dense_tensors(
            flat_tensor, self._bucket)
        for old, new in zip(self._bucket, unflattened_tensor_list):
            old.copy_(new)
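Note: for context, the flatten / unflatten-and-copy round trip that TensorBucket relies on, shown standalone (my own illustration; torch._utils is a private PyTorch module, so this is illustrative only):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

bucket = [torch.zeros(2), torch.zeros(3)]
flat = _flatten_dense_tensors(bucket)        # one contiguous buffer of 5 elements
flat += 1.0                                  # e.g. the result of a fused all-reduce
for old, new in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
    old.copy_(new)                           # write the result back into the original tensors
print(bucket[0].tolist(), bucket[1].tolist())    # [1.0, 1.0] [1.0, 1.0, 1.0]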
@ -1,563 +0,0 @@
|
|||||||
from colossalai.utils.cuda import get_current_device
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
|
|
||||||
from colossalai.context import ParallelMode
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
|
||||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
|
||||||
from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
|
|
||||||
release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
|
|
||||||
class ShardedOptimizer(ColossalaiOptimizer):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
optimizer: Optimizer,
|
|
||||||
initial_scale=2**32,
|
|
||||||
min_scale=1,
|
|
||||||
growth_factor=2,
|
|
||||||
backoff_factor=0.5,
|
|
||||||
growth_interval=1000,
|
|
||||||
hysteresis=2,
|
|
||||||
max_scale: int = 2**32,
|
|
||||||
clip_grad_norm=2.0,
|
|
||||||
verbose=False,
|
|
||||||
reduce_bucket_size=500000000,
|
|
||||||
communication_dtype=torch.float16,
|
|
||||||
overlap_communication=False,
|
|
||||||
partition_grad=False,
|
|
||||||
dp_parallel_mode=ParallelMode.DATA,
|
|
||||||
mp_parallel_mode=ParallelMode.MODEL,
|
|
||||||
cpu_offload=False,
|
|
||||||
cpu_fp16_param=False,
|
|
||||||
cpu_fp16_grad=False):
|
|
||||||
|
|
||||||
# TODO: add support for
|
|
||||||
# 1. fp16 master weights
|
|
||||||
# 2. contiguous gradients
|
|
||||||
# 3. cpu offload
|
|
||||||
# 4. support when some parameters requires_grad = False
|
|
||||||
|
|
||||||
self._optimizer = optimizer
|
|
||||||
self._dtype = self._optimizer.param_groups[0]['params'][0].dtype
|
|
||||||
self._logger = get_dist_logger()
|
|
||||||
self._verbose = verbose
|
|
||||||
|
|
||||||
# stage 2
|
|
||||||
self._partition_grads = partition_grad
|
|
||||||
|
|
||||||
# cpu_offload
|
|
||||||
self._cpu_offload = cpu_offload
|
|
||||||
self._cpu_fp16_param = cpu_fp16_param
|
|
||||||
self._cpu_fp16_grad = cpu_fp16_grad
|
|
||||||
|
|
||||||
# get process groups
|
|
||||||
self._dp_parallel_mode = dp_parallel_mode
|
|
||||||
self._mp_parallel_mode = mp_parallel_mode
|
|
||||||
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
|
|
||||||
self._world_size = gpc.get_world_size(dp_parallel_mode)
|
|
||||||
|
|
||||||
self._dp_group = gpc.get_group(dp_parallel_mode)
|
|
||||||
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
|
|
||||||
self._mp_group = gpc.get_group(mp_parallel_mode)
|
|
||||||
else:
|
|
||||||
self._mp_group = None
|
|
||||||
|
|
||||||
# fp16 and fp32 params for mixed precision training
|
|
||||||
self._fp16_param_groups = dict()
|
|
||||||
self._fp32_flat_param_groups_of_current_rank = dict()
|
|
||||||
|
|
||||||
# communication params
|
|
||||||
self._overlap_communication = overlap_communication
|
|
||||||
self._reduce_bucket_size = reduce_bucket_size
|
|
||||||
self._communication_dtype = communication_dtype
|
|
||||||
|
|
||||||
# gradient scaler
|
|
||||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
|
||||||
min_scale=min_scale,
|
|
||||||
growth_factor=growth_factor,
|
|
||||||
backoff_factor=backoff_factor,
|
|
||||||
growth_interval=growth_interval,
|
|
||||||
hysteresis=hysteresis,
|
|
||||||
max_scale=max_scale,
|
|
||||||
verbose=verbose)
|
|
||||||
self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
|
|
||||||
|
|
||||||
# gradient clipping
|
|
||||||
self._clip_grad_norm = clip_grad_norm
|
|
||||||
|
|
||||||
# check argument conflict
|
|
||||||
self._sanity_checks()
|
|
||||||
|
|
||||||
# ParameterStore will manage the tensor buffers used for zero
|
|
||||||
# it will not manage the tensors used by mixed precision training
|
|
||||||
self._param_store = ParameterStore(self._dp_parallel_mode)
|
|
||||||
self._grad_store = GradientStore(self._dp_parallel_mode)
|
|
||||||
self._bucket_store = BucketStore(self._dp_parallel_mode)
|
|
||||||
|
|
||||||
# iterate over the param group in the optimizer
|
|
||||||
# partition these param groups for data parallel training
|
|
||||||
# and add buffers to parameter store for future access
|
|
||||||
for group_id, param_group in enumerate(self._optimizer.param_groups):
|
|
||||||
params = param_group['params']
|
|
||||||
|
|
||||||
# add the fp16 params to fp16_param_groups for bookkeeping
|
|
||||||
self._fp16_param_groups[group_id] = params
|
|
||||||
|
|
||||||
# assign parameters to ranks
|
|
||||||
# the params in the list are sorted
|
|
||||||
params_per_rank = self._partition_param_list(params)
|
|
||||||
|
|
||||||
# store the mapping between param to rank
|
|
||||||
# each param should belong to only one rank
|
|
||||||
for rank, params in enumerate(params_per_rank):
|
|
||||||
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
|
|
||||||
for param in params:
|
|
||||||
self._param_store.set_param_to_rank(param, rank)
|
|
||||||
|
|
||||||
# move to cpu to make room to create the flat tensor
|
|
||||||
move_tensor(params, device='cpu')
|
|
||||||
|
|
||||||
# flatten the reordered tensors
|
|
||||||
for rank in range(self._world_size):
|
|
||||||
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
|
||||||
flat_tensor = flatten(tensor_list)
|
|
||||||
flat_tensor = flat_tensor.cuda()
|
|
||||||
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
|
|
||||||
|
|
||||||
# sync parameters
|
|
||||||
for rank in range(self._world_size):
|
|
||||||
flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id)
|
|
||||||
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
|
||||||
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
|
|
||||||
|
|
||||||
# create a copy of fp32 weights of the parameters for which this rank is responsible
|
|
||||||
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
|
|
||||||
# when using cpu offload, our cpu adam support fp16 paramters
|
|
||||||
if self._cpu_fp16_param:
|
|
||||||
fp32_flat_current_rank = fp16_flat_current_rank.detach()
|
|
||||||
else:
|
|
||||||
fp32_flat_current_rank = fp16_flat_current_rank.detach().float()
|
|
||||||
device = 'cpu' if self._cpu_offload else get_current_device()
|
|
||||||
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
|
|
||||||
fp32_flat_current_rank.requires_grad = True
|
|
||||||
self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
|
|
||||||
|
|
||||||
# need to replace the params in the `params` field in the optimizer
|
|
||||||
# so that when the optimizer calls step(), it only updates the tensors
|
|
||||||
# managed by this data parallel rank
|
|
||||||
param_group['params'] = [fp32_flat_current_rank]
|
|
||||||
|
|
||||||
# set reduction state
|
|
||||||
for param in self._fp16_param_groups[group_id]:
|
|
||||||
self._param_store.set_param_reduction_state(param, False)
|
|
||||||
|
|
||||||
# intialize communication stream for
|
|
||||||
# communication-compuation overlapping
|
|
||||||
if self._overlap_communication:
|
|
||||||
self._comm_stream = torch.cuda.Stream()
|
|
||||||
|
|
||||||
# reduction hook is only used if overlapping communication
|
|
||||||
# or stage 2 is used
|
|
||||||
# if it is stage 1 without overlapping, no hook will be attached
|
|
||||||
if self._overlap_communication or self._partition_grads:
|
|
||||||
self._attach_reduction_hook()
|
|
||||||
|
|
||||||
self._initialize_optimizer_states()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def loss_scale(self):
|
|
||||||
return self.grad_scaler.scale
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_param_groups(self):
|
|
||||||
return len(self._fp16_param_groups)
|
|
||||||
|
|
||||||
def _partition_param_list(self, param_list):
|
|
||||||
params_per_rank = [[] for _ in range(self._world_size)]
|
|
||||||
numel_per_rank = [0 for _ in range(self._world_size)]
|
|
||||||
|
|
||||||
# partititon the parameters in a greedy fashion
|
|
||||||
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
|
||||||
for param in sorted_params:
|
|
||||||
# allocate this parameter to the rank with
|
|
||||||
# the smallest numel for load balancing purpose
|
|
||||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
|
||||||
params_per_rank[rank_to_go].append(param)
|
|
||||||
numel_per_rank[rank_to_go] += param.numel()
|
|
||||||
|
|
||||||
if self._verbose:
|
|
||||||
self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
|
|
||||||
ranks=[0],
|
|
||||||
parallel_mode=self._dp_parallel_mode)
|
|
||||||
return params_per_rank
|
|
||||||
|
|
||||||
def _initialize_optimizer_states(self):
|
|
||||||
# create a dummy zero tensor which has the same shape as that of the param
|
|
||||||
# set this dummpy zero tensor as grad
|
|
||||||
for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
|
|
||||||
fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
|
||||||
fp32_partition_grad = torch.zeros_like(fp32_partition_param)
|
|
||||||
fp32_partition_param.grad = fp32_partition_grad
|
|
||||||
|
|
||||||
# update the parameter with zero gradients for initialization of optimizer stateus
|
|
||||||
self._optimizer.step()
|
|
||||||
|
|
||||||
# remove the grad of the paramter to save memory
|
|
||||||
for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
|
|
||||||
fp32_flat_tensor.grad = None
|
|
||||||
|
|
||||||
def _sanity_checks(self):
|
|
||||||
assert torch.cuda.is_available(), 'CUDA is required'
|
|
||||||
assert self._dtype == torch.float16, \
|
|
||||||
f'Parameters are expected to be of type torch.float16, but got {self._dtype}'
|
|
||||||
|
|
||||||
###########################################################
|
|
||||||
# Backward Reduction Hook
|
|
||||||
###########################################################
|
|
||||||
|
|
||||||
    def _attach_reduction_hook(self):
        # we iterate over the fp16 params
        # on each param, we register a hook to its AccumulateGrad object
        for group_id in range(self.num_param_groups):
            param_group = self._fp16_param_groups[group_id]
            for param in param_group:
                if param.requires_grad:
                    # determine the reduction destination rank
                    # this is only valid for stage 2
                    # dst_rank = None means using all-reduce
                    # else using reduce
                    if self._partition_grads:
                        reduce_rank = self._param_store.get_param_rank(param)
                    else:
                        reduce_rank = None

                    def _define_and_attach(param, reduce_rank):
                        # get the AccumulateGrad object of the param itself
                        accum_grad_obj = get_grad_accumulate_object(param)
                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)

                        reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
                                                 param=param,
                                                 reduce_rank=reduce_rank)

                        # define hook
                        # NOT IMPORTANT BUT GOOD TO KNOW:
                        # args here is not grad, but allow_unreachable and accumulate_grad
                        def reduce_grad_hook(*args):
                            reduction_func()

                        accum_grad_obj.register_hook(reduce_grad_hook)

                    _define_and_attach(param, reduce_rank)
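The hook is attached to the parameter's AccumulateGrad autograd node rather than to the tensor, so it fires once the gradient has actually been written into param.grad. A standalone plain-PyTorch sketch of that trick (not part of this diff; get_grad_accumulate_object is assumed to resolve the node in a similar way):

    import torch

    p = torch.nn.Parameter(torch.randn(4))

    # expand_as builds a tiny graph whose next function is p's AccumulateGrad node
    accum_grad_node = p.expand_as(p).grad_fn.next_functions[0][0]

    def on_grad_accumulated(*_):
        # by the time this runs, p.grad has been populated and could be
        # pushed into a reduction bucket
        print('grad ready, norm =', p.grad.norm().item())

    accum_grad_node.register_hook(on_grad_accumulated)
    (p * 2.0).sum().backward()   # triggers the hook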
    def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
        param_size = param.numel()

        # check if the bucket is full
        # if full, will reduce the grads already in the bucket
        # after reduction, the bucket will be empty
        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
            self._reduce_grads_in_bucket(reduce_rank)

        # the param must not be reduced twice to ensure correctness
        is_param_reduced = self._param_store.is_param_reduced(param)
        if is_param_reduced:
            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
                + 'duplicate reduction will lead to arithmetic incorrectness'
            raise RuntimeError(msg)

        # the param must have a grad for reduction
        assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'

        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
        self._bucket_store.add_grad(param.grad, reduce_rank)
        self._bucket_store.add_param(param, reduce_rank)
    def _reduce_grads_in_bucket(self, reduce_rank=None):
        # reduce grads
        self._reduce_grads_by_rank(reduce_rank=reduce_rank,
                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))

        # use the communication stream if overlapping
        # communication with computation
        if self._overlap_communication:
            stream = self._comm_stream
        else:
            stream = torch.cuda.current_stream()

        with torch.cuda.stream(stream):
            params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)

            for param in params_in_bucket:
                # the is_param_reduced flag should be False, showing that
                # this param was not reduced before calling self._reduce_grads_by_rank
                is_param_reduced = self._param_store.is_param_reduced(param)

                if is_param_reduced:
                    msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
                        'duplicate reduction will lead to arithmetic incorrectness'
                    raise RuntimeError(msg)

                # update the flag
                self._param_store.set_param_reduction_state(param, True)

                # if partition_grads is True,
                # we do not keep the gradient after reduction
                if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
                    if self._overlap_communication:
                        # we need to keep this gradient for now as the reduction may
                        # not be completed yet since it is using a different cuda stream
                        self._param_store.add_previous_reduced_param(param)
                    else:
                        param.grad = None

        self._bucket_store.reset_by_rank(reduce_rank)
    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
        grad_buckets_by_dtype = split_half_float_double(grads)

        for tensor_list in grad_buckets_by_dtype:
            self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)

    ##############################
    # Reduction Utility Function #
    ##############################
    def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
        param_bucket = TensorBucket(size=bucket_size)

        for tensor in tensor_list:
            param_bucket.add_to_bucket(tensor, allow_oversize=True)

            if param_bucket.is_full_or_oversized():
                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
                param_bucket.empty()

        if not param_bucket.is_empty():
            self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
        if self._overlap_communication:
            torch.cuda.synchronize()
            self._param_store.clear_grads_of_previous_reduced_params()
            stream = self._comm_stream
        else:
            stream = torch.cuda.current_stream()

        with torch.cuda.stream(stream):
            flat = bucket.flatten()
            reduced_flat = reduce_tensor(tensor=flat,
                                         dtype=self._communication_dtype,
                                         dst_rank=reduce_rank,
                                         parallel_mode=self._dp_parallel_mode)

            # update the reduced tensor
            if reduce_rank is None or reduce_rank == self._local_rank:
                bucket.unflatten_and_copy(reduced_flat)
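The flatten, reduce, then unflatten-and-copy round trip can be illustrated with plain torch.distributed primitives. A hedged sketch of the all-reduce case only (not the ColossalAI helpers; it assumes a process group has already been initialised and averages the gradients as one possible convention):

    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def reduce_bucket(grads, world_size):
        flat = _flatten_dense_tensors(grads)   # one contiguous buffer, one collective
        dist.all_reduce(flat)                  # reduce_rank=None branch -> all-reduce
        flat.div_(world_size)                  # average the gradients
        for g, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
            g.copy_(synced)                    # write the results back into each grad tensor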
    ################################
    # torch.optim.Optimizer methods
    ################################

    def backward(self, loss, retain_graph=True):
        loss = self.loss_scale * loss
        loss.backward(retain_graph=retain_graph)

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, the gradient
        will be set to None to save memory.

        :param set_to_none: Whether to set the gradient to None. Defaults to True.
        :type set_to_none: bool
        """
        for group_id, param_group in self._fp16_param_groups.items():
            for param in param_group:
                if set_to_none:
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach_()
                        param.grad.zero_()

    ####################
    # Update Parameter #
    ####################
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        # check for overflow
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        # if overflow occurs, drop the accumulated gradients and skip this step
        if found_inf:
            self._grad_store._averaged_gradients = dict()
            self.zero_grad()
            return

        # copy the grad of the fp16 param to the fp32 param
        single_grad_partition_groups = []
        norm_groups = []

        for group_id in range(self.num_param_groups):
            # compute norm
            norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
                                      params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                             rank=self._local_rank),
                                      dp_group=self._dp_group,
                                      mp_group=self._mp_group)
            norm_groups.append(norm_group)

            # create a flat gradient for the flat fp32 params
            fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
            flat_fp16_avg_grads = flatten(fp16_avg_grads)

            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)

            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
            assert param_shape == flat_fp32_avg_grads.shape, \
                f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}'

            single_grad_partition_groups.append(flat_fp32_avg_grads)
            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
            self._grad_store._averaged_gradients[group_id] = []

        # unscale and clip grads
        global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
        self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)

        # update the parameters
        self._optimizer.step()
        # release the fp32 grads
        release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())

        # update the fp16 partition owned by the current rank
        for group_id in range(len(self._fp16_param_groups)):
            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
            fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device)
            fp16_param.data.copy_(fp32_param)

        # broadcast the updated model weights
        handles = []
        for group_id in range(self.num_param_groups):
            for rank in range(self._world_size):
                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
                handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
                handles.append(handle)

        for handle in handles:
            handle.wait()
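After the local shard is updated, each rank holds the only fresh copy of its own flat fp16 partition, so the broadcast loop above is effectively an all-gather of shards implemented with one per-rank broadcast each. A standalone sketch of that pattern with plain torch.distributed (not part of this diff; the argument names are illustrative and src is taken as the global rank, assuming the data-parallel group spans all ranks):

    import torch.distributed as dist

    def broadcast_updated_partitions(flat_partitions, dp_group=None):
        # flat_partitions: list indexed by rank, one flat fp16 tensor per shard
        handles = []
        for src_rank, shard in enumerate(flat_partitions):
            # every rank participates; only src_rank actually provides fresh data
            handles.append(dist.broadcast(shard, src=src_rank, group=dp_group, async_op=True))
        for h in handles:
            h.wait()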
    ##################
    # FP16 Utilities #
    ##################
    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow
        for group_id in range(len(self._fp16_param_groups)):
            for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    self._found_overflow.fill_(1.0)
                    break

        # all-reduce across dp group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)

        # all-reduce over model parallel group
        if self._mp_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)

        if self._found_overflow.item() > 0:
            return True
        else:
            return False
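For reference, roughly what a has_inf_or_nan-style check can look like for a single tensor (a sketch in plain torch, not necessarily the exact helper used here):

    import torch

    def has_inf_or_nan(t: torch.Tensor) -> bool:
        # a single sum reduction turns any inf/nan element into a non-finite scalar
        s = t.float().sum()
        return not bool(torch.isfinite(s))

    # has_inf_or_nan(torch.tensor([1.0, float('inf')])) -> True
    # has_inf_or_nan(torch.tensor([1.0, 2.0]))          -> False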
    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute the combined scale factor for this group
        combined_scale = self.loss_scale

        if self._clip_grad_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
            if clip > 1:
                combined_scale = clip * self.loss_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1. / combined_scale)
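A quick worked example of the combined scale, with made-up numbers (not from this diff):

    # suppose loss_scale = 1024, clip_grad_norm = 2.0 and the scaled grad norm is 4096
    loss_scale, clip_grad_norm = 1024.0, 2.0
    total_norm = 4096.0                                          # norm of the *scaled* grads
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm   # ~2.0 > 1, so clipping kicks in
    combined_scale = clip * loss_scale                           # ~2048
    # multiplying every grad by 1/combined_scale both removes the loss scale and
    # clips the unscaled norm to ~2.0 in a single pass (4096 / 2048 = 2.0)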
    ############################
    # Gradient Synchronization #
    ############################
    def sync_grad(self):
        if not self._partition_grads:
            self._reduce_grad_stage1()
        else:
            # TODO: support async comm in reduce
            self._reduce_grad_stage2()

        # reset the param-already-reduced flags
        reduction_states = self._param_store.get_param_reduction_states()
        for tensor, state in reduction_states.items():
            reduction_states[tensor] = False

        # clear reduced grads
        if self._overlap_communication:
            torch.cuda.synchronize()
            self._param_store.clear_grads_of_previous_reduced_params()

        # accumulate gradients
        avg_gradients = self._grad_store._averaged_gradients
        for group_id in range(self.num_param_groups):
            param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)

            if group_id not in avg_gradients:
                avg_gradients[group_id] = []

            param_idx = 0
            for param in param_group:
                if param.grad is not None:
                    if len(avg_gradients[group_id]) == param_idx:
                        avg_gradients[group_id].append(param.grad)
                    else:
                        avg_gradients[group_id][param_idx].add_(param.grad)
                    param_idx += 1

        # the gradients needed are stored in the avg_gradients buffer,
        # so the fp16 grads can be cleared here
        self.zero_grad()
    def _reduce_grad_stage1(self):
        # if not overlapping communication (no reduction hook is attached)
        # we need to manually reduce these gradients
        if not self._overlap_communication:
            for group_id in range(len(self._fp16_param_groups)):
                param_group = self._fp16_param_groups[group_id]
                for param in param_group:
                    if param.grad is not None:
                        self._reduce_and_remove_grads_by_bucket(param)

        # we need to reduce the gradients
        # left in the communication bucket
        self._reduce_grads_in_bucket()
    def _reduce_grad_stage2(self):
        # when partition_grads is True, reduction hooks
        # are attached in the __init__ function, so we
        # only need to reduce the gradients
        # left in the communication bucket
        for reduce_rank in range(self._world_size):
            self._reduce_grads_in_bucket(reduce_rank)
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -7,9 +7,7 @@ import colossalai

 import torch

-from functools import partial
 import torch.multiprocessing as mp
-import pytest


 def run_tensor_move(rank):
@@ -2,11 +2,9 @@
 # -*- encoding: utf-8 -*-

 import copy
-import operator as op
-from functools import partial, reduce
-from typing import List

 import colossalai
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 import pytest
 import torch
 import torch.distributed as dist
@@ -14,10 +12,11 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
-from colossalai.zero.sharded_model import ShardedModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from colossalai.testing import parameterize
+from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+from functools import partial


 def checkpoint_wrapper(module, enable=True):
@@ -97,41 +96,9 @@ def check_params(model, zero_model, loose=False):
         assert allclose(p, zero_p, loose=loose)


-@parameterize('checkpoint', [False, True])
-@parameterize('fp16', [False, True])
-@parameterize('offload', [False, True])
-@parameterize('norm_type', [1.0, 2.0, float('inf')])
-def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
-    model = Net(checkpoint=checkpoint).cuda()
-    zero_model = copy.deepcopy(model)
-    ddp_model = DDP(model)
-
-    offload_config = {}
-    if offload:
-        offload_config['device'] = 'cpu'
-        zero_model = zero_model.cpu()
-    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
-
-    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type)
-        check_grads(ddp_model, zero_model)
-        check_params(ddp_model, zero_model)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type)
-        check_grads(ddp_model, zero_model, loose=True)
-        check_params(ddp_model, zero_model, loose=True)
-
-
 def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_config()


 @pytest.mark.dist
@@ -12,7 +12,7 @@ from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -1,168 +0,0 @@
import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
from colossalai.zero import ShardedOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.utils import free_port
from functools import partial
from common import allclose
from tests.components_to_test.registry import non_distributed_component_funcs


def check_completely_equal(a, b):
    """
    This function checks if two tensors are completely equal
    """
    assert torch.all(a == b), f'a = {a}, b = {b}'


def check_sharded_param_consistency():
    """
    In this test, we want to test whether zero stage 1 and 2
    deliver the same numerical results despite different communication
    patterns.

    We use these prefixes to differentiate the zero stages:
    oss: partition optimizer states only
    pg: partition gradients and optimizer states
    """
    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']

    for name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(name)
        model_builder, train_dataloader, *_ = get_components_func()

        # create models
        oss_model = model_builder(checkpoint=True).cuda().half()
        pg_model = copy.deepcopy(oss_model)

        # create optimizers
        oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
        pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
        oss_optimizer = ShardedOptimizer(oss_optimizer, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
        pg_optimizer = ShardedOptimizer(pg_optimizer,
                                        overlap_communication=True,
                                        partition_grad=True,
                                        initial_scale=1,
                                        clip_grad_norm=0.0)

        # create data
        data, label = next(iter(train_dataloader))
        input_data = data.cuda().half()

        # forward
        oss_output = oss_model(input_data)
        pg_output = pg_model(input_data)
        check_completely_equal(oss_output, pg_output)

        # backward
        oss_optimizer.backward(oss_output.mean().float())
        pg_optimizer.backward(pg_output.mean().float())

        # check grad
        # as these params are small, the backward reduction
        # will not have been fired yet
        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
            check_completely_equal(oss_param.grad, pg_param.grad)

        # sync grad
        oss_optimizer.sync_grad()
        pg_optimizer.sync_grad()

        # step
        oss_optimizer.step()
        pg_optimizer.step()

        # check updated params
        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
            check_completely_equal(oss_param, pg_param)


def check_sharded_optim_against_torch_ddp():
    """
    In this test, two pairs of model and optimizers are created.
    1. zero: use sharded optimizer and fp16 parameters
    2. torch: use torch DDP and fp32 parameters

    We feed these two sets of models with the same input and check if the
    differences in model output and updated parameters are within tolerance.
    """
    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']

    for name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(name)
        model_builder, train_dataloader, *_ = get_components_func()

        # create models
        zero_model = model_builder(checkpoint=True).cuda()
        torch_model = copy.deepcopy(zero_model)

        zero_model = zero_model.half()
        torch_model = DDP(torch_model.cuda())

        # create optimizers
        zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)

        # we only test stage 1 here
        # in `check_sharded_param_consistency.py`, we will test whether
        # level 1 and 2 will produce exactly the same results
        zero_optimizer = ShardedOptimizer(zero_optimizer,
                                          overlap_communication=True,
                                          initial_scale=1,
                                          clip_grad_norm=0.0)
        torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

        # create data
        input_data, _ = next(iter(train_dataloader))
        input_data = input_data.cuda()

        # zero-dp forward
        zero_output = zero_model(input_data.half())

        # torch-ddp forward
        torch_output = torch_model(input_data)
        allclose(zero_output, torch_output.half())

        # zero-dp backward
        zero_optimizer.backward(zero_output.mean().float())

        # torch-ddp backward
        torch_output.mean().backward()

        # check grad
        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
            allclose(oss_param.grad, torch_param.grad.half())

        # zero-dp step
        zero_optimizer.sync_grad()
        zero_optimizer.step()

        # torch-ddp step
        torch_optimizer.step()

        # check updated params
        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
            allclose(oss_param, torch_param.half())


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    check_sharded_optim_against_torch_ddp()
    check_sharded_param_consistency()


@pytest.mark.dist
def test_sharded_optim():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim()
@@ -1,39 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial
import pytest

import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import free_port
from common import CONFIG


def run_shard_shape_check(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    model = torch.nn.Linear(2, 4 * world_size)
    gpc.init_parallel_groups()
    Zero3ParameterManager(module=model,
                          process_group=gpc.get_group(ParallelMode.DATA),
                          offload_config=CONFIG.get('offload_param_config'))

    assert (model.weight.numel() == 4 * 2)
    assert (model.bias.numel() == 4)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_run_shard_shape(world_size):
    run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_run_shard_shape(2)
@@ -1,19 +0,0 @@
import sys
from pathlib import Path
repo_path = Path(__file__).absolute().parents[2]
sys.path.append(str(repo_path))

try:
    import model_zoo.vit.vision_transformer_from_config
except ImportError:
    raise ImportError("model_zoo is not found, please check your path")

BATCH_SIZE = 8
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
@@ -1,99 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.autograd
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.utils import free_port, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10

from components import *

CONFIG = dict(parallel=dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
),
              fp16=dict(mode=None, ),
              zero=dict(level=2))


def train_epoch(engine, train_dataloader):
    engine.train()
    accumulated_loss = 0
    num_steps = len(train_dataloader)
    data_iter = iter(train_dataloader)
    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()
    avg_loss = accumulated_loss / num_steps
    return avg_loss


def run_2d_parallel_vision_transformer_level_2(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
    model = vit_lite_depth7_patch4_32()

    # build dataloader
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                            ]))
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)
    logger = get_dist_logger()

    logger.info('start training')
    engine.train()

    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        out = engine(img)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()
        break

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
def test_2d_vit_zero_level_2():
    world_size = 8
    run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_2d_vit_zero_level_2()
@@ -1,99 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.autograd
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.utils import free_port, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10

from components import *

CONFIG = dict(parallel=dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
),
              fp16=dict(mode=None, ),
              zero=dict(level=3))


def train_epoch(engine, train_dataloader):
    engine.train()
    accumulated_loss = 0
    num_steps = len(train_dataloader)
    data_iter = iter(train_dataloader)
    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()
    avg_loss = accumulated_loss / num_steps
    return avg_loss


def run_2d_parallel_vision_transformer_level_3(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
    model = vit_lite_depth7_patch4_32()

    # build dataloader
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                            ]))
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)
    logger = get_dist_logger()

    logger.info('start training')
    engine.train()

    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        out = engine(img)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()
        break

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
def test_3d_vit_zero_level_3():
    world_size = 8
    run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_3d_vit_zero_level_3()