mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
add scatter/gather optim for pipeline (#123)
@@ -62,3 +62,31 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
         tensor_shape = torch.Size(recv_shape)
 
     return tensor_shape
+
+
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """Break a tensor into equal 1D chunks."""
+    partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """Opposite of above function, gather values from model parallel ranks."""
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    numel = torch.numel(tensor)
+    numel_gathered = world_size * numel
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+    return gathered
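Below is a minimal standalone sketch of the same scatter/gather round trip, using plain torch.distributed with the gloo backend on CPU instead of ColossalAI's gpc/ParallelMode plumbing. The helper and function names here (split_into_1d_equal_chunks, gather_chunks, demo_round_trip) are illustrative only and are not part of the patch. It demonstrates the property the optimization relies on: each rank keeps only its contiguous 1/world_size slice of a flattened tensor, and an all_gather over those slices rebuilds the full flat buffer.

# Sketch only: run with `torchrun --nproc_per_node=2 demo_scatter_gather.py`.
import torch
import torch.distributed as dist


def split_into_1d_equal_chunks(tensor, rank, world_size):
    # Keep only this rank's contiguous slice of the flattened tensor.
    partition_size = tensor.numel() // world_size
    start = partition_size * rank
    return tensor.view(-1)[start:start + partition_size]


def gather_chunks(chunk, world_size, group=None):
    # Reassemble the full flat tensor from every rank's slice.
    gathered = torch.empty(chunk.numel() * world_size, dtype=chunk.dtype)
    pieces = [gathered[i * chunk.numel():(i + 1) * chunk.numel()]
              for i in range(world_size)]
    dist.all_gather(pieces, chunk.contiguous(), group=group)
    return gathered


def demo_round_trip():
    dist.init_process_group(backend="gloo")
    rank, world_size = dist.get_rank(), dist.get_world_size()

    full = torch.arange(8, dtype=torch.float32)      # identical on every rank
    chunk = split_into_1d_equal_chunks(full, rank, world_size)
    rebuilt = gather_chunks(chunk, world_size)       # flat copy of `full`
    assert torch.equal(rebuilt, full.view(-1))

    dist.destroy_process_group()


if __name__ == '__main__':
    demo_round_trip()

Judging from the commit title, the intended use is presumably that the pipeline send path splits a tensor with split_tensor_into_1d_equal_chunks so each tensor-parallel rank only transmits its own slice across the stage boundary, and the receive path calls gather_split_1d_tensor to reconstruct the full tensor; the hunk shown here only adds the two helpers, not their call sites.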