Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-18 07:31:19 +00:00)
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LLaMA model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer

---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megatron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several models
* remove useless code

---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
@@ -167,6 +167,97 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None


+def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
+    # currently only support one single tensor as output
+    group_size = dist.get_world_size(process_group)
+    cur_rank = dist.get_rank(process_group)
+
+    # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
+
+    # initialization of ring communication
+    recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
+    rank_map = list(dist.get_process_group_ranks(process_group))
+    recv_rank = rank_map[recv_rank]
+    send_rank = rank_map[send_rank]
+    recv_tensors = {}
+    send_tensors = {}
+    for k, v in input_to_gather.items():
+        recv_tensors[k] = torch.empty_like(v)
+        send_tensors[k] = v.clone()
+
+    def communicate_step():
+        comm_ops = []
+        for k in recv_tensors:
+            comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
+            comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
+        return dist.batch_isend_irecv(comm_ops)
+
+    def switch_step():
+        for k in recv_tensors:
+            send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
+
+    output_tensors = []
+
+    handles = communicate_step()
+    # first round: special case, retrieve from local tensor
+    output_tensors.append(func(**input_to_gather, **input_local))
+    for i in range(group_size - 2):
+        for handle in handles:
+            handle.wait()
+
+        switch_step()
+
+        handles = communicate_step()
+
+        # actual computation
+        output_tensors.append(func(**send_tensors, **input_local))
+
+    # final round: special case, no need to send/recv again
+    for handle in handles:
+        handle.wait()
+    output_tensors.append(func(**recv_tensors, **input_local))
+
+    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+
+
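Editorial note (not part of the commit): the ring gather only works because a position-wise op such as `F.linear` commutes with concatenation along the gathered (sequence) dimension, so each ring step can compute on the chunk it already holds while the next chunk is in flight, and the final `torch.cat` just restores the global order. A minimal single-process sketch of that identity, with illustrative shapes:

# Editorial sketch (not part of the commit): the identity _ring_as_gather relies on for F.linear.
import torch
import torch.nn.functional as F

group_size, batch, seq_per_rank, hidden = 4, 2, 8, 16
chunks = [torch.randn(batch, seq_per_rank, hidden) for _ in range(group_size)]  # one sequence chunk per rank
weight = torch.randn(32, hidden)

gathered = torch.cat(chunks, dim=1)                  # what an all_gather along the sequence dim would produce
out_full = F.linear(gathered, weight)                # compute after a full gather
out_ring = torch.cat([F.linear(c, weight) for c in chunks], dim=1)  # compute chunk by chunk, then concatenate

assert torch.allclose(out_full, out_ring, atol=1e-5)
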
+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whether to overlap the all_gather op and gradient calculation in backward.
+
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.process_group = process_group
+        ctx.dim = dim
+
+        return _gather(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dim = ctx.dim
+        process_group = ctx.process_group
+
+        # do reduce-scatter
+        new_shape = list(grad_output.shape)
+        assert (
+            new_shape[dim] % dist.get_world_size(process_group) == 0
+        ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+        new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+        grad_list = [
+            item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
+        ]
+        output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
+        dist.reduce_scatter(output, grad_list, group=process_group)
+
+        return output, None, None
+
+
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """Gather input from sequence parallel in forward and reduce-scatter gradient in backward

@@ -178,7 +269,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -186,12 +277,25 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.overlap = overlap

-        input_parallel = _gather(input_, dim, process_group)
-
-        if bias is not None:
-            output = F.linear(input_parallel, weight, bias)
-        else:
-            output = F.linear(input_parallel, weight)
+        if ring is True:
+            input_to_gather = {"input": input_}
+            input_local = {"weight": weight}
+
+            output = _ring_as_gather(
+                F.linear,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+            )
+
+            if bias is not None:
+                output += bias
+        else:
+            input_parallel = _gather(input_, dim, process_group)
+            if bias is not None:
+                output = F.linear(input_parallel, weight, bias)
+            else:
+                output = F.linear(input_parallel, weight)

         return output

@@ -294,11 +398,146 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         # wait until reduce-scatter finished
         reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


+def _ring_as_reducescatter(
+    func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1
+):
+    # currently only support one single tensor as output
+    group_size = dist.get_world_size(process_group)
+    cur_rank = dist.get_rank(process_group)
+
+    # initialization of ring communication
+    recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
+    send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    rank_map = list(dist.get_process_group_ranks(process_group))
+    recv_rank = rank_map[recv_rank]
+    send_rank = rank_map[send_rank]
+    input_tensors = []
+    for _ in range(group_size):
+        input_tensors.append({})
+    for k, v in input_to_reducescatter.items():
+        input_shape = v.shape
+        assert input_shape[reducescatter_dim] % group_size == 0
+        _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
+        for i in range(group_size):
+            input_tensors[i][k] = _input_tensors[i]
+    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
+    input_tensors.reverse()
+
+    output_tensor = func(**input_tensors[0], **input_local)
+    recv_tensor = torch.empty_like(output_tensor)
+    send_tensor = output_tensor.clone()
+
+    def communicate_step():
+        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
+        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
+        return dist.batch_isend_irecv([recv_op, send_op])
+
+    handles = communicate_step()
+    # first round: special case, retrieve from local tensor
+    for i in range(group_size - 2):
+        # actual computation
+        output_tensor = func(**input_tensors[i + 1], **input_local)
+
+        for handle in handles:
+            handle.wait()
+        output_tensor += recv_tensor
+
+        tmp_tensor = send_tensor
+        send_tensor = output_tensor
+        output_tensor = tmp_tensor
+
+        handles = communicate_step()
+
+    # final round: special case, no need to send/recv again
+    output_tensor = func(**input_tensors[-1], **input_local)
+    for handle in handles:
+        handle.wait()
+    output_tensor += recv_tensor
+    return output_tensor
+
+
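Editorial note (not part of the commit): the reduce-scatter direction rests on a complementary identity. For a row-parallel linear, reduce-scattering the full-sequence partial output along the sequence dimension gives, on each rank, the sum over ranks of that rank's sequence chunk; the ring accumulates exactly that sum one step at a time. A single-process sketch of the identity, with illustrative names and shapes:

# Editorial sketch (not part of the commit): the identity behind _ring_as_reducescatter for F.linear.
import torch
import torch.nn.functional as F

world_size, batch, seq, hidden, out_features = 4, 2, 8, 16, 32
full_input = torch.randn(batch, seq, hidden)
full_weight = torch.randn(out_features, hidden)

# per-rank shards of the hidden dimension (row-parallel linear)
inputs = torch.chunk(full_input, world_size, dim=-1)
weights = torch.chunk(full_weight, world_size, dim=-1)

# non-ring path: each rank computes a full-sequence partial output, then reduce-scatter along the sequence dim
partials = [F.linear(x, w) for x, w in zip(inputs, weights)]
reduce_scattered = [
    sum(torch.chunk(p, world_size, dim=1)[r] for p in partials)  # sequence chunk r, summed over ranks
    for r in range(world_size)
]

# ring path (algebraically): accumulate each rank's contribution to one sequence chunk, chunk by chunk
ring_result = []
for r in range(world_size):
    acc = torch.zeros(batch, seq // world_size, out_features)
    for x, w in zip(inputs, weights):
        acc = acc + F.linear(torch.chunk(x, world_size, dim=1)[r], w)
    ring_result.append(acc)

for a, b in zip(reduce_scattered, ring_result):
    assert torch.allclose(a, b, atol=1e-4)
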
 class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
-    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+    """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring

     Args:
         input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whether to overlap the all_gather op and gradient calculation in backward.
+
+    """
+
+    @staticmethod
+    def forward(ctx, input_, weight, bias, process_group, dim, ring):
+        ctx.save_for_backward(input_, weight, bias)
+        ctx.use_bias = bias is not None
+        ctx.process_group = process_group
+        ctx.dim = dim
+
+        if ring is True:
+            input_to_reducescatter = {"input": input_}
+            input_local = {"weight": weight}
+
+            if bias is not None:
+                input_to_reducescatter["bias"] = bias
+
+            output = _ring_as_reducescatter(
+                F.linear,
+                input_to_reducescatter=input_to_reducescatter,
+                input_local=input_local,
+                process_group=process_group,
+            )
+        else:
+            if bias is not None:
+                partial_output = F.linear(input_, weight, bias)
+            else:
+                partial_output = F.linear(input_, weight)
+
+            output_shape = list(partial_output.shape)
+            assert (
+                output_shape[dim] % dist.get_world_size(process_group) == 0
+            ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+            output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
+
+            output_list = [
+                item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
+            dist.reduce_scatter(output, output_list, group=process_group)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight, bias = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        dim = ctx.dim
+        process_group = ctx.process_group
+
+        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
+        if use_bias:
+            bias = bias.view(bias.shape)
+
+        grad_output = _gather(grad_output, dim, process_group)
+
+        # TODO Need to fully optimize
+        total_input = input_
+        grad_input = grad_output.matmul(weight)
+        grad_output = grad_output.contiguous()
+        # Convert the tensor shapes to 2D for execution compatibility
+        if len(grad_output.shape) > 2:
+            grad_output = grad_output.view(-1, grad_output.shape[-1])
+            total_input = total_input.view(-1, total_input.shape[-1])
+        grad_weight = grad_output.t().matmul(total_input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
+class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
+    """Reduce-scatter input from sequence parallel in forward and gather gradient in backward
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
@@ -343,7 +582,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -351,9 +590,24 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.overlap = overlap

-        input_parallel = _gather(input_, dim, process_group)
-
-        output = torch.matmul(input_parallel, weight)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather["input"] = input_
+            input_local["other"] = weight
+
+            output = _ring_as_gather(
+                torch.matmul,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+                gather_dim=dim,
+            )
+
+        else:
+            input_parallel = _gather(input_, dim, process_group)
+
+            output = torch.matmul(input_parallel, weight)

         if bias is not None:
             output = output + bias
@@ -433,7 +687,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         # wait until reduce-scatter finished
         reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -448,14 +702,17 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, dim, process_group):
+    def forward(ctx, input_, dim, process_group, grad_scale=None):
         ctx.process_group = process_group
         ctx.dim = dim
+        ctx.grad_scale = grad_scale
         return _split(input_, dim, process_group)

     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output, ctx.dim, ctx.process_group), None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None


 class _ReduceForward(torch.autograd.Function):
@@ -505,14 +762,50 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, dim, process_group):
+    def forward(ctx, input_, dim, process_group, grad_scale=None):
         ctx.process_group = process_group
         ctx.dim = dim
+        ctx.grad_scale = grad_scale
         return _gather(input_, dim, process_group)

     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output, ctx.dim, ctx.process_group), None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
+
+
+class _AllToAll(torch.autograd.Function):
+    """All-to-all communication.
+
+    Args:
+        input_: input matrix
+        process_group: communication group
+        scatter_dim: scatter dimension
+        gather_dim: gather dimension
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+        ctx.process_group = process_group
+        ctx.scatter_dim = scatter_dim
+        ctx.gather_dim = gather_dim
+        world_size = dist.get_world_size(process_group)
+        bsz, _, _ = input_.shape
+
+        # using all_to_all_single when batch size is 1
+        if bsz == 1:
+            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
+        else:
+            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        process_group = ctx.process_group
+        scatter_dim = ctx.gather_dim
+        gather_dim = ctx.scatter_dim
+        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
+        return (return_grad, None, None, None)


 class HookParameter(torch.autograd.Function):
@@ -608,6 +901,40 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output


+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
+    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+    dist.all_to_all(output_list, input_list, group=group)
+    return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
+    inp_shape = list(input_.shape)
+    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
+    if scatter_dim < 2:
+        input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
+    else:
+        input_t = (
+            input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
+            .transpose(0, 1)
+            .contiguous()
+        )
+
+    output = torch.empty_like(input_t)
+    dist.all_to_all_single(output, input_t, group=group)
+
+    if scatter_dim < 2:
+        output = output.transpose(0, 1).contiguous()
+
+    return output.reshape(
+        inp_shape[:gather_dim]
+        + [
+            inp_shape[gather_dim] * seq_world_size,
+        ]
+        + inp_shape[gather_dim + 1 :]
+    ).contiguous()
+
+
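Editorial note (not part of the commit): the two helpers above are easier to read once the intended shape transformation of ulysses-style sequence parallelism is explicit — each rank trades its sequence shard (full hidden size) for a hidden/head shard (full sequence). A single-process sketch, simulating the collective with chunk/cat and illustrative shapes:

# Editorial sketch (not part of the commit): the [B, S/P, H] -> [B, S, H/P] transformation of an
# all-to-all with scatter_dim=2 and gather_dim=1, simulated on one process.
import torch

P, B, S, H = 4, 2, 16, 8
ranks_in = [torch.randn(B, S // P, H) for _ in range(P)]   # per-rank inputs, split along the sequence

def simulated_all_to_all(inputs, scatter_dim, gather_dim):
    # rank j sends chunk i (split along scatter_dim) to rank i; every rank
    # concatenates what it receives along gather_dim
    chunks = [torch.chunk(t, P, dim=scatter_dim) for t in inputs]
    return [torch.cat([chunks[j][i] for j in range(P)], dim=gather_dim) for i in range(P)]

ranks_out = simulated_all_to_all(ranks_in, scatter_dim=2, gather_dim=1)
print(ranks_in[0].shape, "->", ranks_out[0].shape)   # torch.Size([2, 4, 8]) -> torch.Size([2, 16, 2])
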
 def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
     return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)

@@ -617,31 +944,39 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )


-def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
-    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+def gather_forward_reducescatter_backward(input_, process_group, dim):
+    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
+
+
+def reducescatter_forward_gather_backward(input_, process_group, dim):
+    return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+
+
+def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
+    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)


 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )


-def gather_forward_split_backward(input_, dim, process_group):
-    return _GatherForwardSplitBackward.apply(input_, dim, process_group)
+def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
+    return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)


-def split_forward_gather_backward(input_, dim, process_group):
-    return _SplitForwardGatherBackward.apply(input_, dim, process_group)
+def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
+    return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)


 def reduce_forward(input_, process_group):
@@ -650,3 +985,7 @@ def reduce_forward(input_, process_group):

 def reduce_backward(input_, process_group):
     return _ReduceBackward.apply(input_, process_group)
+
+
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
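Editorial note (not part of the commit): a plausible round-trip use of the new `all_to_all_comm` wrapper in a ulysses-style attention block — swap the sequence shard for a head/hidden shard before attention and swap back afterwards. This is a hedged sketch, not the library's documented API usage; the import path `colossalai.shardformer.layer._operation` and the shapes are assumptions, and it needs a distributed launch (e.g. `torchrun --nproc_per_node=4 sketch.py` with NCCL/GPUs).

# Editorial usage sketch (assumptions: module path, shapes, NCCL backend).
import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import all_to_all_comm

def ulysses_attention_io(x_seq_shard, sp_group):
    # [B, S/P, H] -> [B, S, H/P]: full sequence, sharded hidden/heads
    full_seq = all_to_all_comm(x_seq_shard, sp_group, scatter_dim=2, gather_dim=1)
    attn_out = full_seq  # ... attention over the full sequence would go here ...
    # [B, S, H/P] -> [B, S/P, H]: back to the sequence shard for the following MLP
    return all_to_all_comm(attn_out, sp_group, scatter_dim=1, gather_dim=2)

if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    sp_group = dist.group.WORLD
    x = torch.randn(2, 16 // dist.get_world_size(), 8, device="cuda")
    y = ulysses_attention_io(x, sp_group)
    assert y.shape == x.shape
    dist.destroy_process_group()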