Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[zero] use bucket during allgather (#5860)
* [zero] use bucket during allgather
* [zero] rename api
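In short, the patch lets every tensor added to a TensorBucket carry an optional write-back target, and gathers the whole bucket with a single collective, copying each tensor's gathered shards back into its registered target afterwards; a hedged usage sketch follows the diff.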
@@ -1,3 +1,7 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 
@@ -6,6 +10,7 @@ class TensorBucket:
         self._max_size = size
         self._current_size = 0
         self._bucket = []
+        self._write_back_pairs = {}
 
     @property
     def max_size(self):
@@ -21,7 +26,7 @@ class TensorBucket:
     def is_empty(self):
         return len(self._bucket) == 0
 
-    def add_to_bucket(self, tensor, allow_oversize=False):
+    def add_to_bucket(self, tensor, allow_oversize=False, write_back_tensor: Optional[torch.Tensor] = None):
         tensor_size = tensor.numel()
 
         if not allow_oversize and self.will_exceed_max_size(tensor_size):
@@ -30,6 +35,8 @@ class TensorBucket:
 
         self._bucket.append(tensor)
         self._current_size += tensor_size
+        write_back_tensor = write_back_tensor if write_back_tensor is not None else tensor
+        self._write_back_pairs[tensor] = write_back_tensor
 
     def will_exceed_max_size(self, tensor_size):
         expected_size = self._current_size + tensor_size
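Note the default established by the two added lines above: when write_back_tensor is omitted, the tensor is registered as its own write-back target, so the all_gather added below simply overwrites the bucketed tensors in place.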
@@ -40,12 +47,30 @@ class TensorBucket:
 
     def empty(self):
         self._bucket = []
-        self._size = 0
+        self._current_size = 0
+        self._write_back_pairs = {}
 
     def flatten(self):
         return _flatten_dense_tensors(self._bucket)
 
+    def unflatten(self, flat_tensor):
+        return _unflatten_dense_tensors(flat_tensor, self._bucket)
+
     def unflatten_and_copy(self, flat_tensor):
-        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
+        unflattened_tensor_list = self.unflatten(flat_tensor)
         for old, new in zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
+
+    def all_gather(self, group=None):
+        flat = self.flatten()
+        buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
+        dist.all_gather(buffers, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
+        # transpose the list of list
+        unflat_buffers = list(map(list, zip(*unflat_buffers)))
+        for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
+            write_back_tensor = self._write_back_pairs[tensor]
+            write_back_tensor.data.copy_(
+                _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
+            )
+        self.empty()
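A minimal usage sketch of the new API, under stated assumptions: torch.distributed is already initialized (e.g. via torchrun), every rank adds the same tensors in the same order, and each 1-D shard is padded so that shard.numel() * world_size covers the write-back tensor's numel(). The import path, the gather_params helper, and the bucket size are illustrative assumptions, not taken from the PR.

import torch
import torch.distributed as dist

# Import path is an assumption about where TensorBucket lives in ColossalAI.
from colossalai.zero.low_level.bookkeeping.tensor_bucket import TensorBucket


def gather_params(shards, full_params, group=None):
    # shards[i] is this rank's padded 1-D shard of full_params[i].
    bucket = TensorBucket(size=40 * 1024 * 1024)  # capacity in elements; value is arbitrary
    for shard, param in zip(shards, full_params):
        # Flush with a single collective before the bucket would overflow.
        if bucket.will_exceed_max_size(shard.numel()):
            bucket.all_gather(group)  # all_gather() empties the bucket when it finishes
        bucket.add_to_bucket(shard, write_back_tensor=param)
    if not bucket.is_empty():
        bucket.all_gather(group)

Compared with issuing one all_gather per parameter, flattening the bucket amortizes kernel-launch and communication latency over many small tensors, while the write-back map lets the gathered, padded shards land directly in the full-size tensors.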