mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 22:10:37 +00:00
[shardformer] integrated linear 1D with dtensor (#3996)
* [shardformer] integrated linear 1D with dtensor * polish code
This commit is contained in:
@@ -54,10 +54,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
|
||||
ctx.save_for_backward(input_, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.parallel_mode = parallel_mode
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
|
||||
output = torch.matmul(input_, weight.t())
|
||||
@@ -74,12 +74,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
grad_input = grad_output.matmul(weight)
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
|
||||
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
|
||||
if len(grad_output.shape) > 2:
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
total_input = total_input.view(-1, total_input.shape[-1])
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
# Delay the start of weight gradient computation shortly (3us) to have
|
||||
# all-reduce scheduled first and have GPU resources allocated
|
||||
_ = torch.empty(1, device=grad_output.device) + 1
|
||||
@@ -93,5 +94,123 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""
|
||||
Split the input and keep only the corresponding chuck to the rank.
|
||||
|
||||
Args:
|
||||
input_ (`torch.Tensor`): input matrix.
|
||||
dim (int): the dimension to perform split and gather
|
||||
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, process_group):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
return _split(input_, dim, process_group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather(grad_output, ctx.dim, ctx.process_group), None, None
|
||||
|
||||
|
||||
class _ReduceInput(torch.autograd.Function):
|
||||
"""
|
||||
All-reduce the input from the model parallel region.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
parallel_mode: parallel mode.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group):
|
||||
return _reduce(input_, process_group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output, None
|
||||
|
||||
|
||||
def _reduce(input_, process_group):
|
||||
# skip if only one rank involved
|
||||
if dist.get_world_size(process_group) == 1:
|
||||
return input_
|
||||
else:
|
||||
dist.all_reduce(input_, group=process_group)
|
||||
return input_
|
||||
|
||||
|
||||
def _split(input_, dim=-1, process_group=None):
|
||||
# skip if only one rank involved
|
||||
world_size = dist.get_world_size(process_group)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along last dimension.
|
||||
dim_size = input_.size(dim)
|
||||
assert dim_size % world_size == 0, \
|
||||
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
|
||||
f'cannot split tensor evenly'
|
||||
|
||||
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
|
||||
rank = dist.get_rank(process_group)
|
||||
output = tensor_list[rank].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather(input_, dim=-1, process_group=None):
|
||||
# skip if only one rank involved
|
||||
world_size = dist.get_world_size(process_group)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
rank = dist.get_rank(process_group)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
torch.distributed.all_gather(tensor_list, input_, group=process_group)
|
||||
|
||||
# concat
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
"""Gather the input from model parallel region and concatenate.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
parallel_mode: parallel mode.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, process_group):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
return _gather(input_, dim, process_group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split(grad_output, ctx.dim, ctx.process_group), None, None
|
||||
|
||||
|
||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||
|
||||
|
||||
def gather_forward_split_backward(input_, dim, process_group):
|
||||
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
|
||||
|
||||
|
||||
def split_forward_gather_backward(input_, dim, process_group):
|
||||
return _SplitForwardGatherBackward.apply(input_, dim, process_group)
|
||||
|
||||
|
||||
def reduce_input(input_, process_group):
|
||||
return _ReduceInput.apply(input_, process_group)
|
||||
|
Reference in New Issue
Block a user