# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup

# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
    enable_nccl_base_collectives = False
else:
    enable_nccl_base_collectives = True


class Bucket:
    def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
        # one row of shard_size elements per rank; reduce-scatter sums across rows
        self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
        self.group = group
        self.offset = 0  # number of elements already queued per rank
        self.callbacks: List[Callable] = []
        self.output_shard = torch.zeros_like(self.buffer[0])

    def flush(self) -> None:
        """Flush content of the bucket."""
        if self.offset == 0:
            assert len(self.callbacks) == 0
            return
        # reduce-scatter bucket
        if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
            dist._reduce_scatter_base(
                self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
            )
        else:
            dist.reduce_scatter(
                self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
            )
        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()
        # reuse input bucket but allocate a fresh output shard
        self.buffer[:, : self.offset].zero_()
        self.offset = 0
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.buffer[0])

    def alloc(self) -> None:
        """Set up the buffers if they are not allocated.

        Using ``alloc`` and ``free``, we can ensure that the bucket buffers are
        only allocated during the backward pass, leaving more memory for other
        parts of the training process, such as activation memory in the
        forward pass.
        """
        for tensor in [self.buffer, self.output_shard]:
            if tensor.storage().size() == 0:
                tensor.storage().resize_(tensor.size().numel())

    def free(self) -> None:
        """Tear down the bucket by freeing its memory."""
        assert self.offset == 0 and self.callbacks == [], "Incorrect call of free"
        for tensor in [self.buffer, self.output_shard]:
            tensor.storage().resize_(0)

    def append(self, tensor_list: List[Tensor], callback_fn: Optional[Callable]):
        # copy data from tensor_list into the bucket
        tensor_size = tensor_list[0].numel()
        stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
        offset = self.offset
        self.buffer[:, offset: offset + tensor_size].copy_(stacked_input)
        self.offset += tensor_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = self.output_shard[offset: offset + tensor_size].view_as(tensor_list[0])
            self.callbacks.append(functools.partial(callback_fn, result_view))

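# Illustrative layout sketch (example values assumed, not from the code above):
# with group.size() == 4 and two queued tensors of per-rank sizes s0 and s1,
# ``buffer`` holds
#
#     rank 0 -> [ t0_r0 (s0) | t1_r0 (s1) | ...free... ]
#     rank 1 -> [ t0_r1 (s0) | t1_r1 (s1) | ...free... ]
#     ...
#
# ``flush()`` reduce-scatters across the rank dimension, leaving this rank's
# summed shard in ``output_shard[: offset]``; each registered callback then
# receives a view of its own slice of that shard.
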
class ReduceScatterBucketer:
    """
    Helper for bucketing multiple reduce-scatter operations on small tensors
    into larger reduce-scatter ops to improve communication efficiency.

    Usage::

        bucketer = ReduceScatterBucketer()
        bucketer.reduce_scatter_async(
            small_tensors, callback_fn=lambda result: print("small")
        )
        bucketer.reduce_scatter_async(
            big_tensors, callback_fn=lambda result: print("big")
        )
        bucketer.reduce_scatter_async(
            more_small_tensors, callback_fn=lambda result: print("small2")
        )
        bucketer.flush()  # callbacks only guaranteed to be called after flush()
        # Example output (note that it is out of order, due to bucketing):
        # big
        # small
        # small2

    Args:
        bucket_size_mb (int, Optional): bucket size in megabytes (MB) for
            communication. Buckets are sub-divided based on world_size.
            Values <= 0 disable bucketing.
    """

    def __init__(self, bucket_size_mb: int = 25):
        self.bucket_size_mb = bucket_size_mb
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
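        # Note: buckets are keyed by (dtype, device, group) (see _get_bucket),
        # so e.g. fp16 and fp32 gradients on the same device use separate buckets.
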
    @torch.no_grad()
    def reduce_scatter_async(
        self,
        input_list: List[Tensor],
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. The
                list should contain ``group.size()`` tensors, each with
                identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()

        assert (
            len(input_list) == world_size
        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
        if first_input_size > bucket_shard_size:
            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
            if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened, group=group)
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.buffer.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()
        bucket.append(input_list, callback_fn)

    @torch.no_grad()
    def flush(self) -> None:
        """Reduce-scatter any partial buckets."""
        for bucket in self.buckets.values():
            bucket.flush()

    @torch.no_grad()
    def free(self) -> None:
        """Free buffers from all buckets."""
        for bucket in self.buckets.values():
            bucket.free()

    @functools.lru_cache()
    def _get_shard_size(self, element_size: int, num_shards: int) -> int:
        if self.bucket_size_mb <= 0:  # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_size_mb * MB / element_size
        return int(bucket_size // num_shards)

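    # Worked example (values assumed for illustration): with the default
    # bucket_size_mb=25, fp16 gradients (element_size=2) and num_shards=8,
    # the shard size is 25 * 1024 * 1024 / 2 // 8 = 1_638_400 elements, so the
    # full (world_size, shard_size) bucket buffer holds 8 * 1_638_400 fp16
    # values, i.e. exactly the configured 25 MB.
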
    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.buffer shaped (world_size, shard_size)
            world_size = group.size()
            shard_size = self._get_shard_size(tensor.element_size(), world_size)
            self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group)
        self.buckets[key].alloc()
        return self.buckets[key]
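

# Minimal usage sketch (illustrative, not part of the module above): assumes a
# multi-GPU NCCL setup launched with ``torchrun --nproc_per_node=N``. Each rank
# queues ``world_size`` identically shaped gradient-like tensors; the bucketer
# coalesces them into a single reduce-scatter per bucket and hands this rank's
# reduced shard to the callback once ``flush()`` runs.
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    bucketer = ReduceScatterBucketer(bucket_size_mb=25)

    # rank i ends up with the sum, over all ranks, of every rank's i-th tensor
    grads = [torch.full((1024,), float(rank), device="cuda") for _ in range(world_size)]
    bucketer.reduce_scatter_async(
        grads, group=dist.group.WORLD, callback_fn=lambda shard: print(f"rank {rank}: {shard[:4]}")
    )

    bucketer.flush()  # callbacks are only guaranteed to have run after flush()
    bucketer.free()   # release the bucket buffers once all reductions are done
    dist.destroy_process_group()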