[zero] add sharded grad and refactor grad hooks for ShardedModel (#287)

This commit is contained in:
ver217 2022-03-02 18:28:29 +08:00 committed by GitHub
parent 4fbb8db586
commit 9b07ac81d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 305 additions and 75 deletions

View File

@@ -1,10 +1,13 @@
-from ._base_ophook import BaseOpHook
-from ._memtracer_ophook import MemTracerOpHook
-from ._shard_param_ophook import ShardParamHook
-import torch
-from typing import List
-all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook"]
+from typing import List
+
+import torch
+
+from ._base_ophook import BaseOpHook
+from ._memtracer_ophook import MemTracerOpHook
+from ._shard_grad_ophook import ShardGradHook
+from ._shard_param_ophook import ShardParamHook
+
+all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
 # apply torch.autograd.Function that calls a backward_function to tensors in output

View File

@@ -0,0 +1,31 @@
import torch

from colossalai.registry import OPHOOKS

from . import BaseOpHook


@OPHOOKS.register_module
class ShardGradHook(BaseOpHook):
    """
    A hook to process the sharded gradient before and after the FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, '_sharded_grad')
            param._sharded_grad.setup()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        pass

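For orientation, a minimal sketch of how this hook is meant to be attached. The `register_ophooks_recursively` call mirrors the one added to `ShardedModelV2` later in this commit; the bare `nn.Linear` module is only illustrative, since in practice every parameter must already carry the `_sharded_grad` attribute that `ShardedModelV2.__init__` sets up.

import torch.nn as nn

from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
                                       register_ophooks_recursively)

# Illustrative module; ShardedModelV2 normally performs this wiring itself.
module = nn.Linear(8, 8)

# Attach both hooks to every submodule. ShardGradHook only acts on the backward
# side: its pre_bwd_exec calls param._sharded_grad.setup() so that any locally
# accumulated gradient is moved aside before the new backward pass runs.
register_ophooks_recursively(module, [ShardParamHook(), ShardGradHook()])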
View File

@@ -1,9 +1,10 @@
 from enum import Enum
+
 import torch
-from colossalai.zero.sharded_model._zero3_utils import get_shard
+import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-import torch.distributed as dist
+from colossalai.zero.sharded_model._zero3_utils import get_shard


 class TensorType(Enum):
@@ -27,9 +28,11 @@ class ShardParam(object):
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
         self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
+        self._payload_shape = None
         self._payload_numel = None
         self._origin_shape = param.shape
         self._origin_numel = param.numel()
+        self._origin_dtype = param.dtype
         self.is_sharded = False

     def payload(self, target_device: torch.device):
@@ -65,3 +68,7 @@ class ShardParam(object):
                                 async_op=False)
         self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
         self.is_sharded = False
+
+    @property
+    def origin_dtype(self):
+        return self._origin_dtype

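The new `origin_dtype` property records the dtype the parameter was created with; `_reduce_scatter_callback` further down uses it to cast the reduced gradient back under mixed precision. A tiny standalone sketch of that round trip (the fp16 tensor is purely illustrative):

import torch

# Illustrative stand-ins: ShardParam stores param.dtype at construction time,
# and the reduced (possibly fp16) gradient is cast back to it afterwards.
origin_dtype = torch.float32
reduced_grad = torch.randn(4, dtype=torch.float16)
reduced_grad = reduced_grad.to(dtype=origin_dtype)
assert reduced_grad.dtype == origin_dtype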
View File

@@ -190,10 +190,6 @@ class ReduceScatterBucketer:
         return int(bucket_size // num_shards)

     def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
-        # TODO (Min): the `group` used here in the key is the object hash, not the content
-        # hash. That means if FSDP instances are initialized with different process groups,
-        # even when the group members are in fact the same, we end up creating different
-        # buckets here.
         key = (tensor.dtype, tensor.device, group)
         if key not in self.buckets:
             # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)

View File

@@ -0,0 +1,85 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ShardedGradient:
    def __init__(self,
                 param: Parameter,
                 sharded_module: nn.Module,
                 offload_config: Optional[dict] = None
                 ) -> None:
        assert hasattr(
            param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with a sharded parameter'

        self.param = param
        self.sharded_module = sharded_module
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        # _gpu_grad is either sharded or not
        # all saved grads are fp32
        self._gpu_grad: Optional[torch.Tensor] = None
        self._cpu_grad: Optional[torch.Tensor] = None

        if self._cpu_offload:
            # this buffer will be held and reused every iteration
            self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()

    @torch.no_grad()
    def setup(self) -> None:
        """This function will be called pre-backward. Save the locally accumulated gradient to _gpu_grad.
        When no_sync() is enabled (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad.

        :raises AssertionError: Raised if the grad shape is wrong
        """
        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
            if self.param.grad.device != self.param.data.device:
                # TODO: offload?
                raise RuntimeError(
                    f'grad and param are on different devices, grad {self.param.grad.device} vs. param {self.param.data.device}')
            else:
                self._gpu_grad = self.param.grad.data
            self.param.grad = None

    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
        """This function will be called in the post-backward hook, so we cannot modify param.grad directly.

        :param reduced_grad: the reduced grad
        :type reduced_grad: torch.Tensor
        """
        # Make sure we store fp32 grad
        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
            reduced_grad.data = reduced_grad.data.to(torch.float)

        if self._gpu_grad is None:
            self._gpu_grad = reduced_grad.data
        else:
            self._gpu_grad += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            self._gpu_grad.data.record_stream(torch.cuda.current_stream())

    @torch.no_grad()
    def write_back(self) -> None:
        """This function will be called in the final backward hook.
        """
        if self._cpu_grad is not None:
            assert self.param.device == torch.device(
                'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
            self.param.grad.data = self._cpu_grad
        elif self._gpu_grad is not None:
            assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
            self.param.grad.data = self._gpu_grad
        else:
            raise RuntimeError('No grad to write back')
        # If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
        # They should be released here
        self._gpu_grad = None

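To put the three methods in context, the rest of the commit drives them in a fixed order during each backward pass. The outline below is a summary, not code from the commit; `sharded_model` is assumed to be a `ShardedModelV2` whose parameters already carry `_sharded_grad`.

def backward_pass_lifecycle(sharded_model, loss):
    """Hedged outline of how ShardedGradient is exercised in one backward pass.

    1. ShardGradHook.pre_bwd_exec calls param._sharded_grad.setup(), moving any
       locally accumulated param.grad into _gpu_grad and clearing param.grad.
    2. ShardedModelV2._grad_post_backward_hook reduce-scatters the fresh
       gradient and hands the rank-local shard to reduce_scatter_callback(),
       which accumulates it in fp32 and, with CPU offload, copies it into the
       pinned _cpu_grad buffer.
    3. ShardedModelV2.backward() ends with _final_backward_hook(), which flushes
       the reduce-scatter buckets and calls write_back() on every parameter so
       the shard lands in param.grad for the optimizer.
    """
    sharded_model.backward(loss)    # steps 1-3 all happen inside this call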
View File

@ -1,31 +1,37 @@
import contextlib
import copy
import functools import functools
import os from typing import Any, Optional
import traceback
from collections import OrderedDict
from enum import Enum, auto
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
Set, Union)
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
register_ophooks_recursively)
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from torch.distributed import ProcessGroup
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook, ShardParamHook
from colossalai.zero.shard_param import ShardParam from colossalai.zero.shard_param import ShardParam
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
class ShardedModelV2(nn.Module): class ShardedModelV2(nn.Module):
def __init__(self, def __init__(self,
module: nn.Module, module: nn.Module,
process_group: Optional[ProcessGroup] = None, process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None reduce_scatter_process_group: Optional[ProcessGroup] = None,
): reduce_scatter_bucket_size_mb: int = 25,
reshard_after_forward: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
):
r""" r"""
A demo to reconfigure zero1 shared_model. A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States. Currently do not consider the Optimizer States.
@@ -45,19 +51,111 @@ class ShardedModelV2(nn.Module):
         for _, param in self.module.named_parameters():
             param.ca_attr = ShardParam(param)
             param.ca_attr.shard()
+            param._sharded_grad = ShardedGradient(param, self, offload_config)

         # Register hooks
-        register_ophooks_recursively(self.module, [ShardParamHook()])
+        register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
+        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
+        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
+
+        self.reshard_after_forward = reshard_after_forward
+        self.mixed_precision = mixed_precision
+        self.fp32_reduce_scatter = fp32_reduce_scatter
+        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
+        # We found that a gradient_predivide_factor != 1.0 may cause precision problems,
+        # so we use 1.0 as the default gradient_predivide_factor.
+        # However, if you set gradient_predivide_factor to None, a value >= 1.0 is chosen automatically.
+        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+            get_gradient_predivide_factor(self.world_size)
+        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
+
+        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
+        self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
+        self._require_backward_grad_sync: bool = True

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         outputs = self.module(*args, **kwargs)
         return outputs

     def backward(self, loss):
-        if self.loss_scaler:
-            self.loss_scaler.backward(loss)
-        else:
-            loss.backward()
+        loss.backward()
+        self._final_backward_hook()
+
+    @torch.no_grad()
+    def _final_backward_hook(self) -> None:
+        if self._require_backward_grad_sync:
+            # Flush any unreduced buckets in the post_backward stream.
+            with torch.cuda.stream(self.comm_stream):
+                self.reducer.flush()
+            torch.cuda.current_stream().wait_stream(self.comm_stream)
+            if self._cpu_offload:
+                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
+                torch.cuda.current_stream().synchronize()
+        self.reducer.free()
+        for p in self.module.parameters():
+            if not p.requires_grad:
+                continue
+            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
+            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
+            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
+            # sync passes, if desired.
+            if not self._require_backward_grad_sync:
+                continue
+            p._sharded_grad.write_back()
+
+    @torch.no_grad()
+    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
+        """
+        At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
+        full gradient for the local batch. The reduce-scatter op will save a single
+        shard of the summed gradient across all GPUs to param._sharded_grad. This
+        shard will align with the current GPU rank. For example::
+
+            before reduce_scatter:
+                param.grad (GPU #0): [1, 2, 3, 4]
+                param.grad (GPU #1): [5, 6, 7, 8]
+
+            after reduce_scatter:
+                param.grad (GPU #0): [6, 8]     # 1+5, 2+6
+                param.grad (GPU #1): [10, 12]   # 3+7, 4+8
+
+        The local GPU's ``optim.step`` is responsible for updating a single
+        shard of params, also corresponding to the current GPU's rank. This
+        alignment is created by `param._sharded_grad`, which ensures that
+        the local optimizer only sees the relevant parameter shard.
+        """
+        if grad is None:
+            return
+        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
+        if not self._require_backward_grad_sync:
+            return
+        self.comm_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.comm_stream):
+            new_grad = grad.clone()
+            if self.mixed_precision and self.fp32_reduce_scatter:
+                new_grad.data = new_grad.data.to(param.dtype)
+            if self.gradient_predivide_factor > 1.0:
+                # Average grad by world_size for consistency with PyTorch DDP.
+                new_grad.data.div_(self.gradient_predivide_factor)
+            orig_grad_data = new_grad.data
+            if self.world_size > 1:
+                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
+                self.reducer.reduce_scatter_async(
+                    grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+            else:
+                self._reduce_scatter_callback(param, new_grad)
+            orig_grad_data.record_stream(self.comm_stream)
+
+    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        if self.gradient_postdivide_factor > 1:
+            # Average grad by world_size for consistency with PyTorch DDP.
+            reduced_grad.data.div_(self.gradient_postdivide_factor)
+        # Cast grad to param's dtype (typically FP32). Note: we do this
+        # before the cpu offload step so that this entire hook remains
+        # non-blocking. The downside is a bit more D2H transfer in that case.
+        if self.mixed_precision:
+            orig_param_grad_data = reduced_grad.data
+            reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
+            # Don't let this memory get reused until after the transfer.
+            orig_param_grad_data.record_stream(torch.cuda.current_stream())
+        param._sharded_grad.reduce_scatter_callback(reduced_grad)

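The reduce-scatter example from the docstring can be reproduced directly with `torch.distributed`. The snippet below is a standalone illustration, not part of the commit, and assumes an NCCL process group with world_size == 2 has already been initialized (for example via `colossalai.launch`). Note also how the two divisions compose: dividing by `gradient_predivide_factor` before the reduce-scatter and by `world_size / gradient_predivide_factor` afterwards divides the summed gradient by `world_size` overall, matching DDP's averaging.

import torch
import torch.distributed as dist


def reduce_scatter_demo():
    # Assumes dist is initialized with the NCCL backend and world_size == 2.
    rank = dist.get_rank()
    full_grad = (torch.tensor([1., 2., 3., 4.]) if rank == 0
                 else torch.tensor([5., 6., 7., 8.])).cuda()
    shard = torch.empty(2, device='cuda')
    # Each rank contributes its full local gradient and receives one summed shard:
    # rank 0 ends up with [6., 8.] (1+5, 2+6), rank 1 with [10., 12.] (3+7, 4+8).
    dist.reduce_scatter(shard, list(full_grad.chunk(2)))
    return shard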
View File

@@ -1,9 +1,10 @@
 from functools import partial
-from operator import imod
-from colossalai.utils import checkpoint
-import torch.nn as nn
 import torch
+import torch.distributed as dist
+import torch.nn as nn
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import checkpoint

 LOGGER = get_dist_logger()
@@ -34,6 +35,7 @@ CONFIG = dict(
     )
 )

+
 def checkpoint_wrapper(module, enable=True):
     if enable:
         module.forward = partial(checkpoint, module.forward)
@@ -61,6 +63,7 @@ class Net(nn.Module):
             x = layer(x)
         return x

+
 def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
     if loose:
         return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
@@ -72,7 +75,8 @@ def check_grads(model, zero_model, loose=False):
         zero_grad = zero_p.grad.clone().to(p.device)
         assert p.grad.dtype == zero_grad.dtype
         assert allclose(p.grad, zero_grad, loose=loose)
-        LOGGER.info(torch.sum(p.grad-zero_grad))
+        LOGGER.info(torch.sum(p.grad - zero_grad))
+

 def check_params(model, zero_model, loose=False):
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
@@ -80,3 +84,30 @@ def check_params(model, zero_model, loose=False):
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose)
+
+
+def check_grads_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_grad = zero_p.grad.clone().to(p.device)
+        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        grad = chunks[rank]
+        if zero_grad.size(0) > grad.size(0):
+            zero_grad = zero_grad[:grad.size(0)]
+        assert grad.dtype == zero_grad.dtype
+        assert allclose(grad, zero_grad, loose=loose)
+
+
+def check_params_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_p = zero_p.clone().to(p.device)
+        chunks = torch.flatten(p).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        p = chunks[rank]
+        if zero_p.size(0) > p.size(0):
+            zero_p = zero_p[:p.size(0)]
+        assert p.dtype == zero_p.dtype
+        assert allclose(p, zero_p, loose=loose)

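The trimming in the new helpers exists because the sharded model pads every rank's shard to a common length, while `torch.chunk` on the reference gradient may return a shorter final chunk. A small standalone illustration with arbitrary numbers:

import torch

# A 10-element reference gradient split across a hypothetical world size of 4:
# torch.chunk yields pieces of sizes 3, 3, 3 and 1.
full_grad = torch.arange(10, dtype=torch.float32)
chunks = full_grad.chunk(4)

# The sharded model would hold a zero-padded shard of length 3 on the last rank,
# so check_grads_padding compares only the first grad.size(0) elements.
padded_shard = torch.zeros(3)
padded_shard[:chunks[3].size(0)] = chunks[3]
assert torch.allclose(padded_shard[:chunks[3].size(0)], chunks[3])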
View File

@@ -3,19 +3,18 @@
 import copy
 from functools import partial
-from operator import mod
-from pyexpat import model

 import colossalai
 import pytest
 import torch
+import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.logging import disable_existing_loggers
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
-from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
+
+from common import CONFIG, Net, check_grads, check_grads_padding


 def run_fwd_bwd(model, x, enable_autocast=False):
@@ -24,8 +23,11 @@ def run_fwd_bwd(model, x, enable_autocast=False):
         y = model(x)
         loss = y.sum()
     loss = loss.float()
-    loss.backward()
+    if isinstance(model, ShardedModelV2):
+        model.backward(loss)
+    else:
+        loss.backward()


 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG,
@@ -34,7 +36,7 @@ def run_dist(rank, world_size, port):
                       host='localhost',
                       port=port,
                       backend='nccl')
     model = Net(checkpoint=True).cuda()
     zero_model = copy.deepcopy(model)
     zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
@@ -43,7 +45,10 @@ def run_dist(rank, world_size, port):
         x = torch.rand(2, 5).cuda()
         run_fwd_bwd(zero_model, x, False)
         run_fwd_bwd(model, x, False)
-        check_grads(model, zero_model)
+        if dist.get_world_size() > 1:
+            check_grads_padding(model, zero_model)
+        else:
+            check_grads(model, zero_model)


 @pytest.mark.dist

View File

@@ -14,7 +14,9 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.utils import checkpoint, free_port
 from colossalai.zero.sharded_model import ShardedModel
 from torch.nn.parallel import DistributedDataParallel as DDP
-from common import Net, allclose
+
+from common import Net, check_grads_padding, check_params_padding
+

 def run_step(model, optimizer, x, enable_autocast=False):
     model.train()
@@ -26,34 +28,6 @@ def run_step(model, optimizer, x, enable_autocast=False):
     loss.backward()
     optimizer.step()

-def check_grads_padding(model, zero_model, loose=False):
-    rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
-        chunks = torch.flatten(p.grad).chunk(4)
-        if rank >= len(chunks):
-            continue
-        grad = chunks[rank]
-        if zero_p.zero_shard_padding > 0:
-            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
-        assert grad.dtype == zero_grad.dtype
-        assert allclose(grad, zero_grad, loose=loose)
-
-
-def check_params_padding(model, zero_model, loose=False):
-    rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_shard_padding = zero_p.zero_shard_padding
-        zero_p = zero_p.clone().to(p.device)
-        chunks = torch.flatten(p).chunk(4)
-        if rank >= len(chunks):
-            continue
-        p = chunks[rank]
-        if zero_shard_padding > 0:
-            zero_p = zero_p[:-zero_shard_padding]
-        assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)
-

 def decode_booleans(intval, bits):
     res = []