[NFC] polish colossalai/zero/sharded_model/reduce_scatter.py code style (#1554)
parent 2ac46f7be4
commit 06dccdde44
colossalai/zero/sharded_model/reduce_scatter.py

@@ -20,6 +20,7 @@ else:
 
 
 class Bucket:
+
     def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
         self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
         self.group = group
@@ -34,18 +35,18 @@ class Bucket:
             return
         # reduce-scatter bucket
         if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
-            dist._reduce_scatter_base(
-                self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
-            )
+            dist._reduce_scatter_base(self.output_shard[:self.offset],
+                                      self.buffer[:, :self.offset].contiguous(),
+                                      group=self.group)
         else:
-            dist.reduce_scatter(
-                self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
-            )
+            dist.reduce_scatter(self.output_shard[:self.offset],
+                                list(self.buffer[:, :self.offset].unbind(0)),
+                                group=self.group)
         # execute post-reduction callbacks
         for callback_fn in self.callbacks:
             callback_fn()
         # reuse input bucket but allocate a fresh output shard
-        self.buffer[:, : self.offset].zero_()
+        self.buffer[:, :self.offset].zero_()
         self.offset = 0
         self.callbacks.clear()
         self.output_shard = torch.zeros_like(self.buffer[0])
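Both branches in the hunk above issue the same collective: every rank contributes group.size() chunks and gets back only the sum of its own chunk. The following single-process sketch illustrates that semantics with plain tensors instead of an initialized process group; world_size, shard_size, and bucket_buffers are illustrative names, not part of the file.

import torch

# Each simulated rank holds a (world_size, shard_size) bucket buffer,
# mirroring Bucket.buffer in the file above.
world_size, shard_size = 4, 8
bucket_buffers = [torch.randn(world_size, shard_size) for _ in range(world_size)]

# dist.reduce_scatter(output_shard, list(buffer.unbind(0)), group=...) would leave
# on rank i the element-wise sum, over all ranks r, of row i of rank r's buffer.
reduced_shards = [
    torch.stack([bucket_buffers[r][i] for r in range(world_size)]).sum(dim=0)
    for i in range(world_size)
]
assert all(shard.shape == (shard_size,) for shard in reduced_shards)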
@@ -73,12 +74,12 @@ class Bucket:
         tensor_size = tensor_list[0].numel()
         stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
         offset = self.offset
-        self.buffer[:, offset: offset + tensor_size].copy_(stacked_input)
+        self.buffer[:, offset:offset + tensor_size].copy_(stacked_input)
         self.offset += tensor_size
 
         # callback will be given the reduced result
         if callback_fn is not None:
-            result_view = self.output_shard[offset: offset + tensor_size].view_as(tensor_list[0])
+            result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
             self.callbacks.append(functools.partial(callback_fn, result_view))
 
 
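The result_view created in the hunk above is a view into output_shard, so the callback registered at append time sees the reduced values once flush() writes them, with no extra copy. A minimal sketch of that aliasing, using illustrative values rather than code from the file:

import torch

output_shard = torch.zeros(8)
offset, tensor_size = 2, 3
grad = torch.ones(3)    # stands in for tensor_list[0]

# Slicing returns a view that shares storage with output_shard.
result_view = output_shard[offset:offset + tensor_size].view_as(grad)

# Writing into output_shard (as the reduce-scatter would) is visible through the view.
output_shard[offset:offset + tensor_size] = torch.tensor([5.0, 6.0, 7.0])
assert torch.equal(result_view, torch.tensor([5.0, 6.0, 7.0]))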
@@ -141,8 +142,7 @@ class ReduceScatterBucketer:
         """
         world_size = group.size()
 
-        assert (
-            len(input_list) == world_size
-        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
+        assert (len(input_list) == world_size
+                ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
 
         first_input = input_list[0]
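The assertion above requires exactly one equally shaped tensor per rank. Assuming the class keeps the reduce_scatter_async(input_list, group, callback_fn=None) interface of the fairscale bucketer it is adapted from (the signature is not shown in this diff), a call would look roughly like the hedged sketch below, intended to run inside an already initialized torch.distributed job.

import torch
import torch.distributed as dist

from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer


def reduce_gradient(grad: torch.Tensor, bucketer: ReduceScatterBucketer, group) -> None:
    # One equally sized chunk per rank, as the assertion above requires.
    world_size = dist.get_world_size(group)
    input_list = list(grad.view(world_size, -1).unbind(0))

    def on_reduced(shard: torch.Tensor) -> None:
        # Runs after the bucket holding this request has been flushed.
        print("reduced shard shape:", tuple(shard.shape))

    # Signature assumed from the fairscale original; verify against the file.
    bucketer.reduce_scatter_async(input_list, group=group, callback_fn=on_reduced)


# Guarded so the sketch is a no-op outside a launched distributed run.
if dist.is_available() and dist.is_initialized():
    bucketer = ReduceScatterBucketer()
    reduce_gradient(torch.randn(dist.get_world_size() * 256), bucketer, dist.group.WORLD)
    bucketer.flush()    # flush() assumed; forces any partially filled buckets to run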