mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -13,7 +13,6 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class Bucket:
|
||||
|
||||
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
|
||||
self.buffer = torch.zeros(size, dtype=dtype, device=device)
|
||||
self.group = group
|
||||
@@ -26,7 +25,7 @@ class Bucket:
|
||||
assert len(self.callbacks) == 0
|
||||
return
|
||||
# reduce-scatter bucket
|
||||
dist.all_reduce(self.buffer[:self.offset], group=self.group)
|
||||
dist.all_reduce(self.buffer[: self.offset], group=self.group)
|
||||
|
||||
# execute post-reduction callbacks
|
||||
for callback_fn in self.callbacks:
|
||||
@@ -37,24 +36,22 @@ class Bucket:
|
||||
self.buffer = torch.zeros_like(self.buffer)
|
||||
|
||||
def alloc(self) -> None:
|
||||
|
||||
if self.buffer.storage().size() == 0:
|
||||
self.buffer.storage().resize_(self.buffer.numel())
|
||||
|
||||
def free(self) -> None:
|
||||
|
||||
assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
|
||||
self.buffer.storage().resize_(0)
|
||||
|
||||
def append(self, tensor: Tensor, callback_fn: Callable):
|
||||
tensor_size = tensor.numel()
|
||||
offset = self.offset
|
||||
self.buffer[offset:offset + tensor_size].copy_(tensor.flatten())
|
||||
self.buffer[offset : offset + tensor_size].copy_(tensor.flatten())
|
||||
self.offset += tensor_size
|
||||
|
||||
# callback will be given the reduced result
|
||||
if callback_fn is not None:
|
||||
result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape)
|
||||
result_view = self.buffer[offset : offset + tensor_size].view(tensor.shape)
|
||||
self.callbacks.append(functools.partial(callback_fn, result_view))
|
||||
|
||||
@property
|
||||
@@ -63,7 +60,6 @@ class Bucket:
|
||||
|
||||
|
||||
class Reducer:
|
||||
|
||||
def __init__(self, bucket_size_mb: int = 25):
|
||||
self.bucket_size_mb = bucket_size_mb
|
||||
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
|
||||
@@ -101,7 +97,7 @@ class Reducer:
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_bucket_size(self, element_size: int) -> int:
|
||||
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
|
||||
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
|
||||
return 0
|
||||
MB = 1024 * 1024
|
||||
bucket_size = self.bucket_size_mb * MB / element_size
|
||||
|
Reference in New Issue
Block a user