mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
This commit is contained in:
200
colossalai/legacy/zero/sharded_model/reduce_scatter.py
Normal file
200
colossalai/legacy/zero/sharded_model/reduce_scatter.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
# The flag is on unless the environment variable ENABLE_NCCL_BASE_COLLECTIVES
# is set to exactly "0"; any other value (or no value at all) leaves it enabled.
enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"
|
||||
|
||||
|
||||
class Bucket:
    """Staging area that packs several small reduce-scatter inputs together.

    ``buffer`` has one row per rank of ``group`` and ``shard_size`` columns;
    ``output_shard`` has ``shard_size`` elements and receives the reduced
    result. ``offset`` is the number of columns currently filled, and
    ``callbacks`` holds the functions to run once the bucket has been reduced.
    """

    def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
        self.group = group
        self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
        self.output_shard = torch.zeros_like(self.buffer[0])
        # number of columns of ``buffer`` that currently hold queued data
        self.offset = 0
        # zero-argument callables, each already bound to its slice of ``output_shard``
        self.callbacks: List[Callable] = []

    def flush(self) -> None:
        """Reduce-scatter the queued data, run the callbacks and reset the bucket.

        Does nothing when no data has been appended since the last flush.
        """
        if self.offset == 0:
            assert len(self.callbacks) == 0
            return

        filled = self.buffer[:, :self.offset]
        reduced = self.output_shard[:self.offset]

        # Prefer the flat-tensor collective when torch exposes it and the
        # module-level toggle allows it; otherwise pass one tensor per rank.
        if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
            dist._reduce_scatter_base(reduced, filled.contiguous(), group=self.group)
        else:
            dist.reduce_scatter(reduced, list(filled.unbind(0)), group=self.group)

        # hand the reduced slices to whoever queued them
        for run_callback in self.callbacks:
            run_callback()

        # The input buffer is zeroed and reused; the output shard is replaced
        # by a new tensor because callbacks were given views into the old one.
        self.buffer[:, :self.offset].zero_()
        self.offset = 0
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.buffer[0])

    def alloc(self) -> None:
        """Give the bucket tensors backing storage if they currently have none.

        Together with ``free`` this lets the bucket memory be held only while
        it is needed (e.g. during the backward pass) and released otherwise.
        Tensors whose storage is already non-empty are left untouched.
        """
        for held in (self.buffer, self.output_shard):
            backing = held.storage()
            if backing.size() == 0:
                backing.resize_(held.size().numel())

    def free(self) -> None:
        """Release the storage behind both bucket tensors.

        The bucket must be empty (nothing queued, no pending callbacks).
        """
        assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
        for held in (self.buffer, self.output_shard):
            held.storage().resize_(0)

    def append(self, tensor_list: List[Tensor], callback_fn: Callable):
        """Queue one tensor per rank in the bucket and register a callback.

        Args:
            tensor_list: the tensors to reduce-scatter; they are stacked and
                viewed as ``(group.size(), numel)`` before being copied in.
            callback_fn: called on flush with a view of the output shard
                shaped like ``tensor_list[0]``; ``None`` registers nothing.
        """
        sample = tensor_list[0]
        count = sample.numel()
        begin = self.offset
        end = begin + count

        # one row per rank, written into the next free columns
        rows = torch.stack(tensor_list).view(self.group.size(), count)
        self.buffer[:, begin:end].copy_(rows)
        self.offset = end

        if callback_fn is None:
            return
        # The view is taken now and only filled in when the bucket is flushed.
        reduced_view = self.output_shard[begin:end].view_as(sample)
        self.callbacks.append(functools.partial(callback_fn, reduced_view))
|
||||
|
||||
|
||||
class ReduceScatterBucketer:
    """
    Helper for bucketing multiple reduce-scatter operations on small tensors
    into larger reduce-scatter ops to improve communication efficiency.

    Usage::

        bucketer = ReduceScatterBucketer()
        bucketer.reduce_scatter_async(
            small_tensors, group, callback_fn=lambda result: print("small")
        )
        bucketer.reduce_scatter_async(
            big_tensors, group, callback_fn=lambda result: print("big")
        )
        bucketer.reduce_scatter_async(
            more_small_tensors, group, callback_fn=lambda result: print("small2")
        )
        bucketer.flush()  # callbacks only guaranteed to be called after flush()
        # Example output (note that it is out of order, due to bucketing):
        # big
        # small
        # small2

    Args:
        bucket_size_mb (int, Optional): bucket size for communicating. Buckets
            are sub-divided based on world_size. Values <= 0 disable bucketing.
    """

    def __init__(self, bucket_size_mb: int = 25):
        self.bucket_size_mb = bucket_size_mb
        # one bucket per (dtype, device, process group) combination, created lazily
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

    @torch.no_grad()
    def reduce_scatter_async(
        self,
        input_list: List[Tensor],
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.

        Raises:
            AssertionError: if ``len(input_list)`` differs from ``group.size()``.
        """
        world_size = group.size()

        assert (len(input_list) == world_size
               ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
        if first_input_size > bucket_shard_size:
            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(first_input)
            if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened, group=group)
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.buffer.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()
        bucket.append(input_list, callback_fn)

    @torch.no_grad()
    def flush(self) -> None:
        """Reduce-scatter any partial buckets."""
        for bucket in self.buckets.values():
            bucket.flush()

    @torch.no_grad()
    def free(self) -> None:
        """Free buffers from all buckets."""
        for bucket in self.buckets.values():
            bucket.free()

    def _get_shard_size(self, element_size: int, num_shards: int) -> int:
        """Return how many elements each rank's row of a bucket can hold.

        Deliberately not memoised with ``functools.lru_cache``: a cache on a
        method keys on ``self``, which keeps every bucketer alive for the
        lifetime of the cache and returns stale sizes if ``bucket_size_mb`` is
        changed after the first call. The computation is trivial arithmetic.

        Args:
            element_size: size in bytes of one tensor element.
            num_shards: number of shards the bucket is divided into.

        Returns:
            Elements per shard, or ``0`` when ``bucket_size_mb <= 0``.
        """
        if self.bucket_size_mb <= 0:    # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_size_mb * MB / element_size
        return int(bucket_size // num_shards)

    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> "Bucket":
        """Return the bucket for ``tensor``'s dtype/device and ``group``.

        The bucket is created on first use and its storage is (re)allocated
        on every lookup.
        """
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
            world_size = group.size()
            shard_size = self._get_shard_size(tensor.element_size(), world_size)
            self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group)
        # alloc() is a no-op while storage exists; it matters after free().
        self.buckets[key].alloc()
        return self.buckets[key]
|
Reference in New Issue
Block a user