diff --git a/colossalai/engine/ophooks/__init__.py b/colossalai/engine/ophooks/__init__.py
index b1130dc5d..1f3b2b38f 100644
--- a/colossalai/engine/ophooks/__init__.py
+++ b/colossalai/engine/ophooks/__init__.py
@@ -1,10 +1,13 @@
-from ._base_ophook import BaseOpHook
-from ._memtracer_ophook import MemTracerOpHook
-from ._shard_param_ophook import ShardParamHook
-import torch
 from typing import List
 
-all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook"]
+import torch
+
+from ._base_ophook import BaseOpHook
+from ._memtracer_ophook import MemTracerOpHook
+from ._shard_grad_ophook import ShardGradHook
+from ._shard_param_ophook import ShardParamHook
+
+all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
 
 
 # apply torch.autograd.Function that calls a backward_function to tensors in output
diff --git a/colossalai/engine/ophooks/_shard_grad_ophook.py b/colossalai/engine/ophooks/_shard_grad_ophook.py
new file mode 100644
index 000000000..582f95802
--- /dev/null
+++ b/colossalai/engine/ophooks/_shard_grad_ophook.py
@@ -0,0 +1,31 @@
+import torch
+from colossalai.registry import OPHOOKS
+
+from . import BaseOpHook
+
+
+@OPHOOKS.register_module
+class ShardGradHook(BaseOpHook):
+    """
+    A hook to set up sharded gradients before and after FWD and BWD operator execution.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        pass
+
+    def post_fwd_exec(self, module: torch.nn.Module, *args):
+        pass
+
+    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        for param in module.parameters():
+            assert hasattr(param, '_sharded_grad')
+            param._sharded_grad.setup()
+
+    def post_bwd_exec(self, module: torch.nn.Module, input):
+        pass
+
+    def post_iter(self):
+        pass
diff --git a/colossalai/zero/shard_param/shard_param.py b/colossalai/zero/shard_param/shard_param.py
index c575767b8..7bc36470f 100644
--- a/colossalai/zero/shard_param/shard_param.py
+++ b/colossalai/zero/shard_param/shard_param.py
@@ -1,9 +1,10 @@
 from enum import Enum
+
 import torch
-from colossalai.zero.sharded_model._zero3_utils import get_shard
+import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-import torch.distributed as dist
+from colossalai.zero.sharded_model._zero3_utils import get_shard
 
 
 class TensorType(Enum):
@@ -27,9 +28,11 @@ class ShardParam(object):
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
         self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
+        self._payload_shape = None
         self._payload_numel = None
         self._origin_shape = param.shape
         self._origin_numel = param.numel()
+        self._origin_dtype = param.dtype
         self.is_sharded = False
 
     def payload(self, target_device: torch.device):
@@ -65,3 +68,7 @@ class ShardParam(object):
                         async_op=False)
         self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
         self.is_sharded = False
+
+    @property
+    def origin_dtype(self):
+        return self._origin_dtype
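The `origin_dtype` property and the gather logic above rely on the flatten-pad-shard scheme used throughout the ZeRO-3 code (`get_shard` in `_zero3_utils`). For reviewers new to that scheme, here is a minimal single-process sketch of the round trip; the helper names below are illustrative stand-ins, not the actual `get_shard` API.

```python
import torch


def shard_flat_tensor(t: torch.Tensor, world_size: int, rank: int):
    # Illustrative stand-in for get_shard(): flatten, zero-pad so the length is a
    # multiple of world_size, then keep only this rank's contiguous slice.
    flat = t.flatten()
    shard_numel = (flat.numel() + world_size - 1) // world_size
    padding = shard_numel * world_size - flat.numel()
    if padding > 0:
        flat = torch.cat([flat, flat.new_zeros(padding)])
    return flat[rank * shard_numel:(rank + 1) * shard_numel].clone(), padding


def gather_flat_shards(shards, origin_numel, origin_shape):
    # Single-process analogue of the all_gather + narrow + view pattern above:
    # concatenate the shards, drop the padding, restore the original shape.
    return torch.narrow(torch.cat(shards), 0, 0, origin_numel).view(origin_shape)


param = torch.randn(3, 5)  # 15 elements, not divisible by world_size
world_size = 4
shards = [shard_flat_tensor(param, world_size, r)[0] for r in range(world_size)]
restored = gather_flat_shards(shards, param.numel(), param.shape)
assert torch.equal(param, restored)
```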
diff --git a/colossalai/zero/sharded_model/reduce_scatter.py b/colossalai/zero/sharded_model/reduce_scatter.py
index 25f76daf5..8225b7566 100644
--- a/colossalai/zero/sharded_model/reduce_scatter.py
+++ b/colossalai/zero/sharded_model/reduce_scatter.py
@@ -190,10 +190,6 @@ class ReduceScatterBucketer:
         return int(bucket_size // num_shards)
 
     def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
-        # TODO (Min): the `group` used here in the key is the object hash, not the content
-        # hash. That means if FSDP instances are initialized with different process groups,
-        # even when the group members are in fact the same, we end up creating different
-        # buckets here.
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
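The bucketer above exists to batch many small gradient reduce-scatters into a few large collectives. As a refresher on the collective's semantics (the docstring in `sharded_model_v2.py` below gives the same 2-GPU example), here is a single-process simulation; real code would call `torch.distributed.reduce_scatter` inside an initialized process group.

```python
import torch


def simulated_reduce_scatter(inputs_per_rank):
    # Each rank contributes a list of world_size chunks; rank r receives the
    # element-wise sum of every rank's r-th chunk.
    world_size = len(inputs_per_rank)
    return [
        torch.stack([inputs_per_rank[src][r] for src in range(world_size)]).sum(dim=0)
        for r in range(world_size)
    ]


# Two ranks, each holding its full local gradient split into two chunks.
rank0 = [torch.tensor([1., 2.]), torch.tensor([3., 4.])]
rank1 = [torch.tensor([5., 6.]), torch.tensor([7., 8.])]
out = simulated_reduce_scatter([rank0, rank1])
print(out[0])  # tensor([6., 8.])   -> shard kept by rank 0
print(out[1])  # tensor([10., 12.]) -> shard kept by rank 1
```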
diff --git a/colossalai/zero/sharded_model/sharded_grad.py b/colossalai/zero/sharded_model/sharded_grad.py
new file mode 100644
index 000000000..7c8667f1b
--- /dev/null
+++ b/colossalai/zero/sharded_model/sharded_grad.py
@@ -0,0 +1,85 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+
+
+class ShardedGradient:
+    def __init__(self,
+                 param: Parameter,
+                 sharded_module: nn.Module,
+                 offload_config: Optional[dict] = None
+                 ) -> None:
+        assert hasattr(
+            param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with a sharded parameter'
+
+        self.param = param
+        self.sharded_module = sharded_module
+        self.offload_config = offload_config
+
+        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
+
+        # _gpu_grad holds either a sharded or an unsharded gradient
+        # all saved grads are fp32
+        self._gpu_grad: Optional[torch.Tensor] = None
+        self._cpu_grad: Optional[torch.Tensor] = None
+
+        if self._cpu_offload:
+            # this buffer will be held and reused every iteration
+            self._cpu_grad = torch.zeros_like(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()
+
+    @torch.no_grad()
+    def setup(self) -> None:
+        """This function will be called pre-backward. Save the locally accumulated gradient to _gpu_grad.
+        When no_sync() is enabled (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad.
+
+        :raises RuntimeError: Raised if the grad and the param are on different devices
+        """
+        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
+            if self.param.grad.device != self.param.data.device:
+                # TODO: offload?
+                raise RuntimeError(
+                    f'grad and param are on different devices, grad on {self.param.grad.device} vs. param on {self.param.data.device}')
+            else:
+                self._gpu_grad = self.param.grad.data
+            self.param.grad = None
+
+    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
+        """This function will be called in the post-backward hook, so we cannot modify param.grad directly.
+
+        :param reduced_grad: the reduced grad
+        :type reduced_grad: torch.Tensor
+        """
+        # Make sure we store fp32 grad
+        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
+            reduced_grad.data = reduced_grad.data.to(torch.float)
+
+        if self._gpu_grad is None:
+            self._gpu_grad = reduced_grad.data
+        else:
+            self._gpu_grad += reduced_grad.data
+
+        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
+        # backwards pass completes, we will set `.grad` to the CPU copy.
+        if self._cpu_offload:
+            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
+            # Don't let this memory get reused until after the transfer.
+            self._gpu_grad.data.record_stream(torch.cuda.current_stream())
+
+    @torch.no_grad()
+    def write_back(self) -> None:
+        """This function will be called in the final backward hook.
+        """
+        if self._cpu_grad is not None:
+            assert self.param.device == torch.device(
+                'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
+            self.param.grad.data = self._cpu_grad
+        elif self._gpu_grad is not None:
+            assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
+            self.param.grad.data = self._gpu_grad
+        else:
+            raise RuntimeError('No grad to write back')
+        # If using CPU offload, _cpu_grad keeps the CPU copy of _gpu_grad,
+        # so _gpu_grad can be released here
+        self._gpu_grad = None
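Putting the class above in context (this is a reading of the code, not additional API): `setup()` is invoked from `ShardGradHook.pre_bwd_exec` to stash any locally accumulated `param.grad`, `reduce_scatter_callback()` accumulates the reduced shard in fp32 and optionally mirrors it into the pinned CPU buffer, and `write_back()` finally publishes the shard as `param.grad`. Below is a minimal single-process stand-in for that accumulation path, with the CPU-offload branch guarded so the snippet also runs without CUDA.

```python
import torch

cpu_offload = torch.cuda.is_available()  # mimic offload_config={'device': 'cpu'}
gpu_grad = None
cpu_grad = torch.zeros(2).pin_memory() if cpu_offload else None

# Two reduced shards arriving across accumulation steps (e.g. fp16 from the bucketer).
for reduced in (torch.tensor([6., 8.], dtype=torch.half), torch.tensor([1., 1.], dtype=torch.half)):
    reduced = reduced.float()                      # "Make sure we store fp32 grad"
    gpu_grad = reduced if gpu_grad is None else gpu_grad + reduced
    if cpu_grad is not None:
        cpu_grad.copy_(gpu_grad, non_blocking=True)

param = torch.nn.Parameter(torch.zeros(2))
param.grad = cpu_grad if cpu_grad is not None else gpu_grad   # roughly what write_back() does
print(param.grad)  # tensor([7., 9.])
```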
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 0168d443e..a32afdff2 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -1,31 +1,37 @@
-import contextlib
-import copy
 import functools
-import os
-import traceback
-from collections import OrderedDict
-from enum import Enum, auto
-from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
-                    Set, Union)
+from typing import Any, Optional
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
+                                       register_ophooks_recursively)
+from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
-from torch.distributed import ProcessGroup
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook, ShardParamHook
 from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
+from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+
+from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
+
 
 class ShardedModelV2(nn.Module):
     def __init__(self,
                  module: nn.Module,
                  process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None
-                 ):
+                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
+                 reduce_scatter_bucket_size_mb: int = 25,
+                 reshard_after_forward: bool = True,
+                 mixed_precision: bool = False,
+                 fp32_reduce_scatter: bool = False,
+                 offload_config: Optional[dict] = None,
+                 gradient_predivide_factor: Optional[float] = 1.0,
+                 ):
         r"""
         A demo to reconfigure zero1 shared_model. Currently do not consider the Optimizer States.
@@ -45,19 +51,111 @@ class ShardedModelV2(nn.Module):
         for _, param in self.module.named_parameters():
             param.ca_attr = ShardParam(param)
             param.ca_attr.shard()
+            param._sharded_grad = ShardedGradient(param, self, offload_config)
 
         # Register hooks
-        register_ophooks_recursively(self.module, [ShardParamHook()])
+        register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
+        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
+        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
+
+        self.reshard_after_forward = reshard_after_forward
+        self.mixed_precision = mixed_precision
+        self.fp32_reduce_scatter = fp32_reduce_scatter
+        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
+        # We find that if gradient_predivide_factor != 1.0, there may be a precision problem.
+        # So we use 1.0 as the default gradient_predivide_factor.
+        # However, if you set gradient_predivide_factor to None, we will choose a factor >= 1.0 automatically.
+        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+            get_gradient_predivide_factor(self.world_size)
+        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
+
+        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
+        self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
+        self._require_backward_grad_sync: bool = True
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         outputs = self.module(*args, **kwargs)
         return outputs
 
-    def backward(self, loss):
-        if self.loss_scaler:
-            self.loss_scaler.backward(loss)
-        else:
-            loss.backward()
-
-
\ No newline at end of file
+    def backward(self, loss):
+        loss.backward()
+        self._final_backward_hook()
+
+    @torch.no_grad()
+    def _final_backward_hook(self) -> None:
+        if self._require_backward_grad_sync:
+            # Flush any unreduced buckets in the post_backward stream.
+            with torch.cuda.stream(self.comm_stream):
+                self.reducer.flush()
+            torch.cuda.current_stream().wait_stream(self.comm_stream)
+            if self._cpu_offload:
+                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
+                torch.cuda.current_stream().synchronize()
+        self.reducer.free()
+        for p in self.module.parameters():
+            if not p.requires_grad:
+                continue
+            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
+            # remains the unsharded gradient accumulated from prior no-sync passes, and param._sharded_grad
+            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
+            # sync passes, if desired.
+            if not self._require_backward_grad_sync:
+                continue
+            p._sharded_grad.write_back()
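A quick numeric check of the pre/post-divide bookkeeping set up in `__init__` above: `gradient_postdivide_factor` is defined as `world_size / gradient_predivide_factor`, so whichever split is chosen, the gradient ends up divided by `world_size` overall, matching DDP's averaging; performing part of the division before the reduce-scatter only reduces fp16 overflow risk. The reduce-scatter summation itself is elided in this sketch.

```python
world_size = 8

for predivide in (1.0, 2.0, 4.0):            # whichever factor is configured or chosen
    postdivide = world_size / predivide       # mirrors ShardedModelV2.__init__
    grad = 16.0
    grad /= predivide                         # applied in _grad_post_backward_hook
    # ... reduce-scatter sums the per-rank gradients here ...
    grad /= postdivide                        # applied in _reduce_scatter_callback
    assert grad == 16.0 / world_size          # net effect: mean over world_size, like DDP
```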
+
+    @torch.no_grad()
+    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
+        """
+        At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
+        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
+        GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
+
+            before reduce_scatter:
+                param.grad (GPU #0): [1, 2, 3, 4]
+                param.grad (GPU #1): [5, 6, 7, 8]
+
+            after reduce_scatter:
+                param.grad (GPU #0): [6, 8]     # 1+5, 2+6
+                param.grad (GPU #1): [10, 12]   # 3+7, 4+8
+
+        The local GPU's ``optim.step`` is responsible for updating a single
+        shard of params, also corresponding to the current GPU's rank. This
+        alignment is created by `param._sharded_grad`, which ensures that
+        the local optimizer only sees the relevant parameter shard.
+        """
+        if grad is None:
+            return
+        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
+        if not self._require_backward_grad_sync:
+            return
+        self.comm_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.comm_stream):
+            new_grad = grad.clone()
+            if self.mixed_precision and self.fp32_reduce_scatter:
+                new_grad.data = new_grad.data.to(param.dtype)
+            if self.gradient_predivide_factor > 1.0:
+                # Average grad by world_size for consistency with PyTorch DDP.
+                new_grad.data.div_(self.gradient_predivide_factor)
+            orig_grad_data = new_grad.data
+            if self.world_size > 1:
+                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
+                self.reducer.reduce_scatter_async(
+                    grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+            else:
+                self._reduce_scatter_callback(param, new_grad)
+            orig_grad_data.record_stream(self.comm_stream)
+
+    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        if self.gradient_postdivide_factor > 1:
+            # Average grad by world_size for consistency with PyTorch DDP.
+            reduced_grad.data.div_(self.gradient_postdivide_factor)
+        # Cast grad to param's dtype (typically FP32). Note: we do this
+        # before the cpu offload step so that this entire hook remains
+        # non-blocking. The downside is a bit more D2H transfer in that case.
+        if self.mixed_precision:
+            orig_param_grad_data = reduced_grad.data
+            reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
+            # Don't let this memory get reused until after the transfer.
+            orig_param_grad_data.record_stream(torch.cuda.current_stream())
+
+        param._sharded_grad.reduce_scatter_callback(reduced_grad)
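`chunk_and_pad` (imported from `_zero3_utils` above) splits the flattened gradient into `world_size` equal-sized chunks, zero-padding the tail so every rank contributes the same number of elements to the reduce-scatter. The sketch below mimics that behaviour with plain `torch` calls; the real helper's exact signature and edge-case handling may differ.

```python
import torch
import torch.nn.functional as F


def chunk_and_pad_sketch(tensor: torch.Tensor, num_chunks: int):
    # Split into num_chunks pieces and zero-pad so every chunk has the same numel.
    chunks = list(tensor.flatten().chunk(num_chunks))
    full_numel = chunks[0].numel()
    if chunks[-1].numel() < full_numel:
        chunks[-1] = F.pad(chunks[-1], (0, full_numel - chunks[-1].numel()))
    while len(chunks) < num_chunks:
        chunks.append(chunks[0].new_zeros(full_numel))
    return chunks


grad = torch.arange(10, dtype=torch.float)   # 10 gradient elements across 4 ranks
print([c.tolist() for c in chunk_and_pad_sketch(grad, 4)])
# [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 0.0, 0.0]]
```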
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 353d759cb..351831f97 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -1,9 +1,10 @@
 from functools import partial
-from operator import imod
-from colossalai.utils import checkpoint
-import torch.nn as nn
+
 import torch
+import torch.distributed as dist
+import torch.nn as nn
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import checkpoint
 
 LOGGER = get_dist_logger()
 
@@ -34,6 +35,7 @@ CONFIG = dict(
     )
 )
 
+
 def checkpoint_wrapper(module, enable=True):
     if enable:
         module.forward = partial(checkpoint, module.forward)
@@ -61,6 +63,7 @@ class Net(nn.Module):
             x = layer(x)
         return x
 
+
 def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
     if loose:
         return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
@@ -72,7 +75,8 @@ def check_grads(model, zero_model, loose=False):
         zero_grad = zero_p.grad.clone().to(p.device)
         assert p.grad.dtype == zero_grad.dtype
         assert allclose(p.grad, zero_grad, loose=loose)
-        LOGGER.info(torch.sum(p.grad-zero_grad))
+        LOGGER.info(torch.sum(p.grad - zero_grad))
+
 
 def check_params(model, zero_model, loose=False):
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
@@ -80,3 +84,30 @@ def check_params(model, zero_model, loose=False):
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose)
+
+
+def check_grads_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_grad = zero_p.grad.clone().to(p.device)
+        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        grad = chunks[rank]
+        if zero_grad.size(0) > grad.size(0):
+            zero_grad = zero_grad[:grad.size(0)]
+        assert grad.dtype == zero_grad.dtype
+        assert allclose(grad, zero_grad, loose=loose)
+
+
+def check_params_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_p = zero_p.clone().to(p.device)
+        chunks = torch.flatten(p).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        p = chunks[rank]
+        if zero_p.size(0) > p.size(0):
+            zero_p = zero_p[:p.size(0)]
+        assert p.dtype == zero_p.dtype
+        assert allclose(p, zero_p, loose=loose)
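The two padded checks above exist because each rank of the sharded model only holds its own chunk of every gradient or parameter, with the last chunk zero-padded. A single-process illustration of what `check_grads_padding` compares on one rank (the shard values here are made up for the example):

```python
import torch

world_size, rank = 4, 3                                # pretend we are the last rank
full_grad = torch.arange(10, dtype=torch.float)        # reference model's p.grad
chunks = torch.flatten(full_grad).chunk(world_size)    # same split the check uses
local_ref = chunks[rank]                               # only 1 element lands on rank 3
sharded_grad = torch.tensor([9., 0., 0.])              # zero_p.grad: local shard + padding (illustrative)
if sharded_grad.size(0) > local_ref.size(0):
    sharded_grad = sharded_grad[:local_ref.size(0)]    # trim the padding before comparing
assert torch.allclose(local_ref, sharded_grad)
```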
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index e25224dca..175abac10 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -3,19 +3,18 @@
 
 import copy
 from functools import partial
-from operator import mod
-from pyexpat import model
 
 import colossalai
 import pytest
 import torch
+import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.logging import disable_existing_loggers
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
-from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
+
+from common import CONFIG, Net, check_grads, check_grads_padding
 
 
 def run_fwd_bwd(model, x, enable_autocast=False):
@@ -24,8 +23,11 @@ def run_fwd_bwd(model, x, enable_autocast=False):
         y = model(x)
         loss = y.sum()
     loss = loss.float()
-    loss.backward()
-
+    if isinstance(model, ShardedModelV2):
+        model.backward(loss)
+    else:
+        loss.backward()
+
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG,
@@ -34,7 +36,7 @@ def run_dist(rank, world_size, port):
                       host='localhost',
                       port=port,
                       backend='nccl')
-
+
     model = Net(checkpoint=True).cuda()
     zero_model = copy.deepcopy(model)
     zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
@@ -43,7 +45,10 @@ def run_dist(rank, world_size, port):
     x = torch.rand(2, 5).cuda()
     run_fwd_bwd(zero_model, x, False)
     run_fwd_bwd(model, x, False)
-    check_grads(model, zero_model)
+    if dist.get_world_size() > 1:
+        check_grads_padding(model, zero_model)
+    else:
+        check_grads(model, zero_model)
 
 
 @pytest.mark.dist
diff --git a/tests/test_zero_data_parallel/test_zero_dev_3_mp4.py b/tests/test_zero_data_parallel/test_zero_dev_3_mp4.py
index bfc805a89..a3ce53eeb 100644
--- a/tests/test_zero_data_parallel/test_zero_dev_3_mp4.py
+++ b/tests/test_zero_data_parallel/test_zero_dev_3_mp4.py
@@ -14,7 +14,9 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.utils import checkpoint, free_port
 from colossalai.zero.sharded_model import ShardedModel
 from torch.nn.parallel import DistributedDataParallel as DDP
-from common import Net, allclose
+
+from common import Net, check_grads_padding, check_params_padding
+
 
 def run_step(model, optimizer, x, enable_autocast=False):
     model.train()
@@ -26,34 +28,6 @@ def run_step(model, optimizer, x, enable_autocast=False):
         loss.backward()
     optimizer.step()
 
-def check_grads_padding(model, zero_model, loose=False):
-    rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
-        chunks = torch.flatten(p.grad).chunk(4)
-        if rank >= len(chunks):
-            continue
-        grad = chunks[rank]
-        if zero_p.zero_shard_padding > 0:
-            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
-        assert grad.dtype == zero_grad.dtype
-        assert allclose(grad, zero_grad, loose=loose)
-
-
-def check_params_padding(model, zero_model, loose=False):
-    rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_shard_padding = zero_p.zero_shard_padding
-        zero_p = zero_p.clone().to(p.device)
-        chunks = torch.flatten(p).chunk(4)
-        if rank >= len(chunks):
-            continue
-        p = chunks[rank]
-        if zero_shard_padding > 0:
-            zero_p = zero_p[:-zero_shard_padding]
-        assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)
-
 def decode_booleans(intval, bits):
     res = []