mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] add sharded grad and refactor grad hooks for ShardedModel (#287)
@@ -190,10 +190,6 @@ class ReduceScatterBucketer:
        return int(bucket_size // num_shards)

    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
        # TODO (Min): the `group` used here in the key is the object hash, not the content
        # hash. That means if FSDP instances are initialized with different process groups,
        # even when the group members are in fact the same, we end up creating different
        # buckets here.
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
colossalai/zero/sharded_model/sharded_grad.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ShardedGradient:
    def __init__(self,
                 param: Parameter,
                 sharded_module: nn.Module,
                 offload_config: Optional[dict] = None
                 ) -> None:
        assert hasattr(
            param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with a sharded parameter'

        self.param = param
        self.sharded_module = sharded_module
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        # _gpu_grad is either sharded or not
        # all saved grads are fp32
        self._gpu_grad: Optional[torch.Tensor] = None
        self._cpu_grad: Optional[torch.Tensor] = None

        if self._cpu_offload:
            # this buffer will be held and reused every iteration
            self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()

    @torch.no_grad()
    def setup(self) -> None:
        """This function is called pre-backward. It saves the locally accumulated gradient to _gpu_grad.
        When no_sync() is enabled (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad

        :raises RuntimeError: Raised if grad and param are on different devices
        """
        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
            if self.param.grad.device != self.param.data.device:
                # TODO: offload?
                raise RuntimeError(
                    f'grad and param are on different devices, grad {self.param.grad.device} vs. param {self.param.data.device}')
            else:
                self._gpu_grad = self.param.grad.data
            self.param.grad = None

    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
        """This function is called in the post-backward hook, so we cannot modify param.grad directly

        :param reduced_grad: the reduced grad
        :type reduced_grad: torch.Tensor
        """
        # Make sure we store fp32 grad
        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
            reduced_grad.data = reduced_grad.data.to(torch.float)

        if self._gpu_grad is None:
            self._gpu_grad = reduced_grad.data
        else:
            self._gpu_grad += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backward pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            self._gpu_grad.data.record_stream(torch.cuda.current_stream())

    @torch.no_grad()
    def write_back(self) -> None:
        """This function is called in the final backward hook
        """
        if self._cpu_grad is not None:
            assert self.param.device == torch.device(
                'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
            self.param.grad.data = self._cpu_grad
        elif self._gpu_grad is not None:
            assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
            self.param.grad.data = self._gpu_grad
        else:
            raise RuntimeError('No grad to write back')
        # If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
        # They should be released here
        self._gpu_grad = None
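ShardedGradient is driven entirely by hooks in ShardedModelV2: setup() runs pre-backward, reduce_scatter_callback() runs from the post-backward reduce-scatter, and write_back() runs from the final backward hook. The CPU-only sketch below is an editor's illustration of that call order, not part of the commit; FakeShardAttr and FakeShardedModule are hypothetical stand-ins for ShardParam's ca_attr and for the owning ShardedModelV2.

    import torch
    from torch.nn.parameter import Parameter
    from colossalai.zero.sharded_model.sharded_grad import ShardedGradient

    class FakeShardAttr:            # hypothetical stand-in for param.ca_attr
        is_sharded = True

    class FakeShardedModule:        # hypothetical stand-in for ShardedModelV2
        _require_backward_grad_sync = True

    param = Parameter(torch.zeros(4))
    param.ca_attr = FakeShardAttr()
    param.grad = torch.ones(4)                  # pretend autograd produced a local grad

    sharded_grad = ShardedGradient(param, FakeShardedModule(), offload_config=None)
    sharded_grad.setup()                        # stashes param.grad into _gpu_grad and clears param.grad
    sharded_grad.reduce_scatter_callback(torch.full((4,), 2.))   # accumulate a "reduced" grad

    param.grad = torch.empty(4)                 # write_back() overwrites param.grad.data, so param.grad must exist again here
    sharded_grad.write_back()
    print(param.grad)                           # tensor([3., 3., 3., 3.])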
@@ -1,31 +1,37 @@
import contextlib
import copy
import functools
import os
import traceback
from collections import OrderedDict
from enum import Enum, auto
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
                    Set, Union)
from typing import Any, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
                                       register_ophooks_recursively)
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from torch.distributed import ProcessGroup
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook, ShardParamHook
from colossalai.zero.shard_param import ShardParam
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor

class ShardedModelV2(nn.Module):
    def __init__(self,
                 module: nn.Module,
                 process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_process_group: Optional[ProcessGroup] = None
                 ):
                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_bucket_size_mb: int = 25,
                 reshard_after_forward: bool = True,
                 mixed_precision: bool = False,
                 fp32_reduce_scatter: bool = False,
                 offload_config: Optional[dict] = None,
                 gradient_predivide_factor: Optional[float] = 1.0,
                 ):
        r"""
        A demo to reconfigure the ZeRO-1 sharded model.
        Optimizer states are not considered yet.
@@ -45,19 +51,111 @@ class ShardedModelV2(nn.Module):
        for _, param in self.module.named_parameters():
            param.ca_attr = ShardParam(param)
            param.ca_attr.shard()
            param._sharded_grad = ShardedGradient(param, self, offload_config)

        # Register hooks
        register_ophooks_recursively(self.module, [ShardParamHook()])
        register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

        self.reshard_after_forward = reshard_after_forward
        self.mixed_precision = mixed_precision
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
        # We find that if gradient_predivide_factor != 1.0, there may be precision problems,
        # so we use 1.0 as the default gradient_predivide_factor.
        # However, if gradient_predivide_factor is set to None, it is replaced automatically with a value >= 1.0.
        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
            get_gradient_predivide_factor(self.world_size)
        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
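        # Note (editor's sketch, not part of the commit): the two factors compose to a plain
        # mean over ranks. The local grad is divided by gradient_predivide_factor before the
        # reduce-scatter and by gradient_postdivide_factor = world_size / gradient_predivide_factor
        # after it, so the overall scaling is always 1 / world_size.
        # For example, with world_size = 4 and gradient_predivide_factor = 2.0:
        #     pre-divide by 2, post-divide by 4 / 2 = 2  ->  total division by 4.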

        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
        self._require_backward_grad_sync: bool = True

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        outputs = self.module(*args, **kwargs)
        return outputs

    def backward(self, loss):
        if self.loss_scaler:
            self.loss_scaler.backward(loss)
        else:
            loss.backward()

        loss.backward()
        self._final_backward_hook()
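    # Note (editor's sketch, not part of the commit): training code should call
    # ShardedModelV2.backward(loss) rather than loss.backward() directly, so that
    # _final_backward_hook() below runs after the autograd pass: it flushes any
    # unreduced buckets, waits for the async GPU -> CPU grad copies when offloading,
    # and writes the reduced gradient shards back into p.grad.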
    @torch.no_grad()
    def _final_backward_hook(self) -> None:
        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self.comm_stream):
                self.reducer.flush()
            torch.cuda.current_stream().wait_stream(self.comm_stream)
            if self._cpu_offload:
                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                torch.cuda.current_stream().synchronize()
        self.reducer.free()
        for p in self.module.parameters():
            if not p.requires_grad:
                continue
            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
            # sync passes, if desired.
            if not self._require_backward_grad_sync:
                continue
            p._sharded_grad.write_back()

    @torch.no_grad()
    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
        """
        At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
        GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by `param._sharded_grad`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        if grad is None:
            return
        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
        if not self._require_backward_grad_sync:
            return
        self.comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.comm_stream):
            new_grad = grad.clone()
            if self.mixed_precision and self.fp32_reduce_scatter:
                new_grad.data = new_grad.data.to(param.dtype)
            if self.gradient_predivide_factor > 1.0:
                # Average grad by world_size for consistency with PyTorch DDP.
                new_grad.data.div_(self.gradient_predivide_factor)
            orig_grad_data = new_grad.data
            if self.world_size > 1:
                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
                self.reducer.reduce_scatter_async(
                    grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
            else:
                self._reduce_scatter_callback(param, new_grad)
            orig_grad_data.record_stream(self.comm_stream)

    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the cpu offload step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            orig_param_grad_data = reduced_grad.data
            reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
            # Don't let this memory get reused until after the transfer.
            orig_param_grad_data.record_stream(torch.cuda.current_stream())

        param._sharded_grad.reduce_scatter_callback(reduced_grad)
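For reference, the arithmetic in the `_grad_post_backward_hook` docstring above can be checked without a process group. The standalone sketch below is an editor's illustration, not part of the commit: it sums the per-rank gradients and keeps one chunk per rank, which is what the reduce-scatter produces.

    import torch

    world_size = 2
    full_grads = [torch.tensor([1., 2., 3., 4.]),   # local grad on "GPU #0"
                  torch.tensor([5., 6., 7., 8.])]   # local grad on "GPU #1"

    summed = torch.stack(full_grads).sum(dim=0)     # elementwise sum across ranks: [6., 8., 10., 12.]
    shards = summed.chunk(world_size)               # rank 0 keeps [6., 8.], rank 1 keeps [10., 12.]
    print([shard.tolist() for shard in shards])     # [[6.0, 8.0], [10.0, 12.0]]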