diff --git a/colossalai/zero/sharded_model/reduce_scatter.py b/colossalai/zero/sharded_model/reduce_scatter.py
index 8225b7566..4fb507382 100644
--- a/colossalai/zero/sharded_model/reduce_scatter.py
+++ b/colossalai/zero/sharded_model/reduce_scatter.py
@@ -20,6 +20,7 @@ else:
 
 
 class Bucket:
+
     def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
         self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
         self.group = group
@@ -34,18 +35,18 @@ class Bucket:
             return
         # reduce-scatter bucket
         if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
-            dist._reduce_scatter_base(
-                self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
-            )
+            dist._reduce_scatter_base(self.output_shard[:self.offset],
+                                      self.buffer[:, :self.offset].contiguous(),
+                                      group=self.group)
         else:
-            dist.reduce_scatter(
-                self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
-            )
+            dist.reduce_scatter(self.output_shard[:self.offset],
+                                list(self.buffer[:, :self.offset].unbind(0)),
+                                group=self.group)
         # execute post-reduction callbacks
         for callback_fn in self.callbacks:
             callback_fn()
         # reuse input bucket but allocate a fresh output shard
-        self.buffer[:, : self.offset].zero_()
+        self.buffer[:, :self.offset].zero_()
         self.offset = 0
         self.callbacks.clear()
         self.output_shard = torch.zeros_like(self.buffer[0])
@@ -73,12 +74,12 @@ class Bucket:
         tensor_size = tensor_list[0].numel()
         stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
         offset = self.offset
-        self.buffer[:, offset: offset + tensor_size].copy_(stacked_input)
+        self.buffer[:, offset:offset + tensor_size].copy_(stacked_input)
         self.offset += tensor_size
 
         # callback will be given the reduced result
         if callback_fn is not None:
-            result_view = self.output_shard[offset: offset + tensor_size].view_as(tensor_list[0])
+            result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
             self.callbacks.append(functools.partial(callback_fn, result_view))
 
 
@@ -141,9 +142,8 @@ class ReduceScatterBucketer:
         """
         world_size = group.size()
 
-        assert (
-            len(input_list) == world_size
-        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
+        assert (len(input_list) == world_size
+                ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
 
         first_input = input_list[0]
         first_input_size = first_input.numel()
@@ -183,7 +183,7 @@ class ReduceScatterBucketer:
 
     @functools.lru_cache()
     def _get_shard_size(self, element_size: int, num_shards: int) -> int:
-        if self.bucket_size_mb <= 0:  # Values <= 0 disable bucketing.
+        if self.bucket_size_mb <= 0:    # Values <= 0 disable bucketing.
             return 0
         MB = 1024 * 1024
         bucket_size = self.bucket_size_mb * MB / element_size
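
The hunks above only restyle a bucketed reduce-scatter: each rank stacks its per-peer contributions into a (group.size(), shard_size) buffer, and flushing the bucket reduce-scatters the filled columns so every rank ends up with the sum of the shard assigned to it. A minimal sketch of that pattern, assuming torch.distributed is already initialized (e.g. via torchrun) and using only the public dist.reduce_scatter API; flush_bucket is a hypothetical helper for illustration, not part of this file:

import torch
import torch.distributed as dist


def flush_bucket(buffer: torch.Tensor, offset: int, group=None) -> torch.Tensor:
    # buffer has shape (world_size, shard_size); only the first `offset`
    # columns are filled. Each row is this rank's contribution to one peer.
    output_shard = torch.zeros_like(buffer[0, :offset])
    # Every rank receives the element-wise sum of its assigned slice.
    # (The patched code prefers dist._reduce_scatter_base when available;
    # newer PyTorch exposes the flat-tensor variant as dist.reduce_scatter_tensor.)
    dist.reduce_scatter(output_shard, list(buffer[:, :offset].unbind(0)), group=group)
    return output_shard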