mirror of
https://github.com/hpcaitech/ColossalAI.git
Polish sharded parameter (#297)
* init shard param from shape tuple
* add more unit tests for shard param
* add more unit tests for sharded param
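In short, this commit renames ShardParam to ShardedParam, moves it from colossalai.zero.shard_param to colossalai.zero.sharded_param, and allows the wrapper to be initialized from a bare shape tuple as well as from an existing parameter. A minimal before/after sketch of caller code, assuming a colossalai distributed context has already been launched (illustrative only, not taken from the repository):

    # Before this commit:
    #   from colossalai.zero.shard_param import ShardParam
    #   param.ca_attr = ShardParam(param)
    # After this commit:
    import torch
    from colossalai.zero.sharded_param import ShardedParam

    param = torch.nn.Parameter(torch.randn(64, 64))
    param.ca_attr = ShardedParam(param)    # wrap the parameter's data payload
    param.ca_attr.shard()                  # keep only this rank's shard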
@@ -1,3 +0,0 @@
-from .shard_param import ShardParam
-
-__all__ = ['ShardParam']
@@ -1,4 +1,3 @@
 import functools
 from typing import Any, Optional

@@ -7,11 +6,10 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
-                                       register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
 from torch.distributed import ProcessGroup
@@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor


 class ShardedModelV2(nn.Module):
-    def __init__(self,
-                 module: nn.Module,
-                 process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_bucket_size_mb: int = 25,
-                 reshard_after_forward: bool = True,
-                 mixed_precision: bool = False,
-                 fp32_reduce_scatter: bool = False,
-                 offload_config: Optional[dict] = None,
-                 gradient_predivide_factor: Optional[float] = 1.0,
-                 ):
+
+    def __init__(
+        self,
+        module: nn.Module,
+        process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_bucket_size_mb: int = 25,
+        reshard_after_forward: bool = True,
+        mixed_precision: bool = False,
+        fp32_reduce_scatter: bool = False,
+        offload_config: Optional[dict] = None,
+        gradient_predivide_factor: Optional[float] = 1.0,
+    ):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
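For reference, a hedged sketch of constructing the wrapper with the reformatted signature above; the import path colossalai.zero.sharded_model.sharded_model_v2 is assumed from the file being edited, and the values mirror the defaults shown in the diff:

    import torch.nn as nn
    from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2  # assumed module path

    model = nn.Linear(1024, 1024).cuda()
    sharded_model = ShardedModelV2(
        model,
        process_group=None,
        reduce_scatter_bucket_size_mb=25,
        reshard_after_forward=True,
        mixed_precision=False,
        fp32_reduce_scatter=False,
        offload_config={'device': 'cpu'},   # enables the CPU-offload branch checked in a later hunk
        gradient_predivide_factor=1.0,
    )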
@@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):

         # Shard the parameters at first
         for _, param in self.module.named_parameters():
-            param.ca_attr = ShardParam(param)
+            param.ca_attr = ShardedParam(param)
             param.ca_attr.shard()
             param._sharded_grad = ShardedGradient(param, self, offload_config)

@@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
-        # However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
-        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+        # However, if you set gradient_predivide_factor to None,
+        # we will set gradient_predivide_factor to a value >= 1.0 automatically
+        self.gradient_predivide_factor: float = \
+            gradient_predivide_factor if gradient_predivide_factor is not None else \
             get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

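The two factors above bracket the gradient reduction: gradients are divided by gradient_predivide_factor before being summed across ranks and by gradient_postdivide_factor afterwards, so together they divide by world_size, i.e. an average. A standalone sketch of that arithmetic (illustrative only, not the hook code itself):

    import torch

    def scaled_gradient_average(grad: torch.Tensor, world_size: int, predivide: float) -> torch.Tensor:
        """Net effect of predivide -> sum over ranks -> postdivide is a mean over world_size."""
        postdivide = world_size / predivide
        pre = grad / predivide                  # shrink before summation to limit fp16 overflow
        summed = pre * world_size               # stand-in for the cross-rank sum (all ranks equal here)
        return summed / postdivide              # == grad, i.e. the average of identical gradients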
@@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
         """
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
-        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
+        full gradient for the local batch. The reduce-scatter op will save
+        a single shard of the summed gradient across all
         GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::

         before reduce_scatter:
@@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
         orig_grad_data = new_grad.data
         if self.world_size > 1:
             grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
-            self.reducer.reduce_scatter_async(
-                grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+            self.reducer.reduce_scatter_async(grad_chunks,
+                                              group=self.reduce_scatter_process_group,
+                                              callback_fn=functools.partial(self._reduce_scatter_callback, param))
         else:
             self._reduce_scatter_callback(param, new_grad)
         orig_grad_data.record_stream(self.comm_stream)
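chunk_and_pad above is expected to split the flat gradient into as many equally sized pieces as there are ranks in the reduce-scatter group, padding the tail so every shard matches. A rough standalone sketch of that behaviour (an assumption about the helper, not the library code):

    from typing import List
    import torch
    import torch.nn.functional as F

    def chunk_and_pad_sketch(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
        flat = tensor.flatten()
        chunks = list(flat.chunk(num_chunks))
        target = chunks[0].numel()                        # torch.chunk makes the first chunk the largest
        chunks = [F.pad(c, (0, target - c.numel())) for c in chunks]
        while len(chunks) < num_chunks:                   # tiny tensors may yield fewer chunks
            chunks.append(flat.new_zeros(target))
        return chunks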
colossalai/zero/sharded_param/__init__.py  (new file, +3)
@@ -0,0 +1,3 @@
+from .sharded_param import ShardedParam
+
+__all__ = ['ShardedParam']
@@ -1,41 +1,59 @@
 from enum import Enum

 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
+from typing import Union, Tuple, Optional
+import numpy


 class TensorType(Enum):
     GRAD = 1
     DATA = 2


-class ShardParam(object):
+class ShardedParam(object):
     r"""
     A wrapper to torch.nn.Parameter. Shard a param
-    on different processes.
+    on memory space of different processes.
     """

-    def __init__(
-        self,
-        param: torch.nn.Parameter,
-        tensor_type: TensorType = TensorType.DATA,
-        process_group=None,
-    ) -> None:
+    def __init__(self,
+                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 is_sharded: bool = False,
+                 device: Optional[torch.device] = None) -> None:
+        r"""
+        other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
+        process_group: the process group storing the shared data.
+        is_sharded: is shared the param during __init__.
+        device: the device to place param data payload on
+        """
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
-        self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
-        self._payload_shape = None
-        self._payload_numel = None
-        self._origin_shape = param.shape
-        self._origin_numel = param.numel()
-        self._origin_dtype = param.dtype
         self.is_sharded = False

+        # Hijack the data payload of param
+        if isinstance(other, torch.nn.Parameter):
+            self._param_payload = other.data.to(device)
+            self._origin_shape = other.shape
+            self._origin_numel = other.numel()
+            if is_sharded:
+                self.shard()
+        elif isinstance(other, tuple):
+            self._origin_shape = other
+            self._origin_numel = numpy.prod(other)
+
+            # TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
+            assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
+            self._param_payload = torch.empty(self._origin_shape, device=device)
+            if is_sharded:
+                self.shard()
+        else:
+            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")
+
+        self._payload_numel = None
+
     def payload(self, target_device: torch.device):
         r"""
         get the payload and move it to target device
         """
         return self._param_payload.to(target_device)

     def shard(self):
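The first bullet of the commit message ("init shard param from shape tuple") corresponds to the elif branch above. A short sketch of both construction paths, assuming a distributed context is already initialized (illustrative only):

    import torch
    from colossalai.zero.sharded_param import ShardedParam

    # 1) wrap an existing parameter and shard it right away
    p = torch.nn.Parameter(torch.randn(1024, 1024))
    sp = ShardedParam(p, is_sharded=True)

    # 2) allocate from a bare shape tuple; a device is mandatory on this path
    sp_from_shape = ShardedParam((1024, 1024), device=torch.device('cpu'))
    print(sp_from_shape.payload(torch.device('cpu')).shape)   # torch.Size([1024, 1024]) while un-sharded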
@@ -50,6 +68,7 @@ class ShardParam(object):
     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
+        The payload has to be moved to cuda memory before communication.
         """
         if not self.is_sharded:
             return
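The new unit tests mentioned in the commit message are not part of this excerpt; a hypothetical check in the same spirit, exercising only the behaviour visible in this diff, might look like:

    import torch
    from colossalai.zero.sharded_param import ShardedParam

    def check_init_from_shape_tuple():
        shape = (32, 16)
        sp = ShardedParam(shape, device=torch.device('cpu'))
        payload = sp.payload(torch.device('cpu'))
        assert tuple(payload.shape) == shape            # full-size payload until shard() is called

    def check_init_from_parameter():
        p = torch.nn.Parameter(torch.randn(8, 4))
        sp = ShardedParam(p)
        assert torch.equal(sp.payload(torch.device('cpu')), p.data)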