ColossalAI/colossalai/zero/sharded_model/reduce_scatter.py
Jiarui Fang 5a560a060a Feature/zero (#279)
* add zero1 (#209)

* add zero1

* add test zero1

* update zero stage 1 develop (#212)

* Implement naive zero3 (#240)

* naive zero3 works well

* add zero3 param manager

* add TODOs in comments

* add gather full param ctx

* fix sub module streams

* add offload

* fix bugs of hook and add unit tests

* fix bugs of hook and add unit tests (#252)

* add gather full param ctx

* fix sub module streams

* add offload

* fix bugs of hook and add unit tests

* polish code and add state dict hook

* fix bug

* update unit test

* refactor reconstructed zero code

* clip_grad support zero3 and add unit test

* add unit test for Zero3ParameterManager

* [WIP] initialize the shard param class

* [WIP] Yet another sharded model implementation (#274)

* [WIP] initialize the shard param class

* [WIP] Yet another implementation of shardModel. Using a better hook method.

* torch.concat -> torch.cat

* fix test_zero_level_1.py::test_zero_level_1 unittest

* remove deepspeed implementation and refactor for the reconstructed zero module

* polish zero dp unittests

Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
2022-03-11 15:50:28 +08:00

205 lines
8.4 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import os
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
# Setting the environment variable ENABLE_NCCL_BASE_COLLECTIVES=0 forces the list-based
# ``dist.reduce_scatter`` path below; any other value (or leaving it unset) allows
# ``dist._reduce_scatter_base`` to be used whenever this torch build provides it.
enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"
class Bucket:
    """Staging buffer that packs several small reduce-scatter inputs into one collective.

    ``buffer`` has shape ``(group.size(), shard_size)``. Each :meth:`append` writes
    the ``r``-th tensor of its input list into row ``r``, in the next free columns.
    ``offset`` counts the columns filled so far. ``output_shard`` has length
    ``shard_size`` and receives this rank's reduced result when the bucket is
    flushed.
    """

    def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
        num_rows = group.size()
        self.buffer = torch.zeros((num_rows, shard_size), dtype=dtype, device=device)
        self.group = group
        self.offset = 0
        self.callbacks: List[Callable] = []
        self.output_shard = torch.zeros_like(self.buffer[0])

    def flush(self) -> None:
        """Reduce-scatter the filled columns, run queued callbacks, then reset the bucket.

        Does nothing when no data has been appended.
        """
        used = self.offset
        if not used:
            assert not self.callbacks
            return

        out_view = self.output_shard[:used]
        in_view = self.buffer[:, :used]
        if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
            # Single-tensor collective: give it one contiguous block of input.
            dist._reduce_scatter_base(out_view, in_view.contiguous(), group=self.group)
        else:
            # List-based fallback: one row of the buffer per rank.
            dist.reduce_scatter(out_view, list(in_view.unbind(0)), group=self.group)

        # ``append`` bound each callback to a view into the current ``output_shard``.
        for pending in self.callbacks:
            pending()

        # Clear the consumed input columns so the input buffer can be reused.
        # Allocate a brand-new output shard instead of zeroing the old one,
        # presumably because callback consumers may keep the views they were given.
        self.buffer[:, : self.offset].zero_()
        self.offset = 0
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.buffer[0])

    def alloc(self) -> None:
        """Re-materialize storage for ``buffer`` and ``output_shard`` if it was released.

        Paired with :meth:`free`, this lets the bucket hold memory only while it is
        needed (e.g. during the backward pass). A tensor whose storage is already
        non-empty is left untouched, so calling this repeatedly is harmless.
        """
        for tensor in (self.buffer, self.output_shard):
            storage = tensor.storage()
            if storage.size() == 0:
                storage.resize_(tensor.numel())

    def free(self) -> None:
        """Release the storage behind ``buffer`` and ``output_shard`` (tensor shapes are kept).

        Only valid on an empty bucket, i.e. nothing appended since the last flush.
        """
        assert not self.offset and not self.callbacks, "Incorrect call of teardown"
        for tensor in (self.buffer, self.output_shard):
            tensor.storage().resize_(0)

    def append(self, tensor_list: List[Tensor], callback_fn: Callable):
        """Copy one reduce-scatter input (one tensor per rank) into the next free columns.

        ``tensor_list`` must hold ``group.size()`` tensors of identical shape, and the
        caller must make sure they fit in the remaining capacity. If ``callback_fn`` is
        not None, it is queued. On :meth:`flush` it is called with the reduced slice of
        ``output_shard``, reshaped like ``tensor_list[0]``.
        """
        width = tensor_list[0].numel()
        start = self.offset
        stop = start + width

        rows = torch.stack(tensor_list).view(self.group.size(), width)
        self.buffer[:, start:stop].copy_(rows)
        self.offset = stop

        if callback_fn is not None:
            reduced_view = self.output_shard[start:stop].view_as(tensor_list[0])
            self.callbacks.append(functools.partial(callback_fn, reduced_view))
class ReduceScatterBucketer:
    """
    Helper for bucketing multiple reduce-scatter operations on small tensors
    into larger reduce-scatter ops to improve communication efficiency.

    Usage::

        bucketer = ReduceScatterBucketer()
        bucketer.reduce_scatter_async(
            small_tensors, group, callback_fn=lambda result: print("small")
        )
        bucketer.reduce_scatter_async(
            big_tensors, group, callback_fn=lambda result: print("big")
        )
        bucketer.reduce_scatter_async(
            more_small_tensors, group, callback_fn=lambda result: print("small2")
        )
        bucketer.flush()  # callbacks only guaranteed to be called after flush()
        # Example output (note that it is out of order, due to bucketing):
        # big
        # small
        # small2

    Args:
        bucket_size_mb (int, Optional): bucket size for communicating. Buckets
            are sub-divided based on world_size. Values <= 0 disable bucketing.
    """

    def __init__(self, bucket_size_mb: int = 25):
        self.bucket_size_mb = bucket_size_mb
        # One bucket per (dtype, device, process group) key; created lazily by ``_get_bucket``.
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

    @torch.no_grad()
    def reduce_scatter_async(
        self,
        input_list: List[Tensor],
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Large inputs are reduced immediately, and this function may also flush
        the relevant bucket to make room for ``input_list``. Empty inputs
        (``numel() == 0``) need no communication, so their callback is invoked
        right away with an empty result.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()
        assert (
            len(input_list) == world_size
        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()

        if first_input_size == 0:
            # Nothing to reduce. Queuing an empty input would register a callback
            # without advancing the bucket offset, which Bucket.flush asserts
            # against when the bucket is otherwise empty.
            if callback_fn is not None:
                callback_fn(torch.zeros_like(first_input))
            return

        bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
        if first_input_size > bucket_shard_size:
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(first_input)
            if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
                # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
                # Flatten before concatenating: torch.cat rejects 0-dim tensors, and the
                # base collective takes an input whose size is world_size times the
                # output size, so a flat concatenation is sufficient.
                input_flattened = torch.cat([t.reshape(-1) for t in input_list])
                dist._reduce_scatter_base(output, input_flattened, group=group)
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.buffer.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()
        bucket.append(input_list, callback_fn)

    @torch.no_grad()
    def flush(self) -> None:
        """Reduce-scatter any partial buckets."""
        for bucket in self.buckets.values():
            bucket.flush()

    @torch.no_grad()
    def free(self) -> None:
        """Free buffers from all buckets."""
        for bucket in self.buckets.values():
            bucket.free()

    def _get_shard_size(self, element_size: int, num_shards: int) -> int:
        """Return the per-rank bucket capacity in elements, or 0 when bucketing is disabled.

        Args:
            element_size: bytes per element of the tensors being reduced.
            num_shards: number of ranks the bucket is split across.
        """
        # Deliberately not memoized. ``functools.lru_cache`` on a method keys on ``self``,
        # so the class-level cache would keep every bucketer (and its buffers) alive.
        # The computation is cheap enough that caching gains nothing.
        if self.bucket_size_mb <= 0:  # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_size_mb * MB / element_size
        return int(bucket_size // num_shards)

    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> "Bucket":
        """Return the (allocated) bucket for ``tensor``'s dtype/device and ``group``, creating it if needed."""
        # TODO (Min): the `group` used here in the key is the object hash, not the content
        # hash. That means if FSDP instances are initialized with different process groups,
        # even when the group members are in fact the same, we end up creating different
        # buckets here.
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.buffer shaped (world_size, shard_size)
            world_size = group.size()
            shard_size = self._get_shard_size(tensor.element_size(), world_size)
            self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group)
        bucket = self.buckets[key]
        # Allocate on every lookup, not only on creation: ``free()`` releases the storage
        # of existing buckets, and ``Bucket.alloc`` skips tensors that still have storage.
        bucket.alloc()
        return bucket