fix rebase

This commit is contained in:
GuangyaoZhang
2024-07-17 02:56:07 +00:00
parent 457a0de79f
commit 5a310b9ee1
10 changed files with 194 additions and 232 deletions

View File

@@ -94,7 +94,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if fp8_communication and ctx.async_grad_allreduce:
if ctx.async_grad_allreduce and fp8_communication:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
@@ -117,12 +117,11 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
if bias is not None:
output = F.linear(input_, weight, bias)
else:
@@ -134,7 +133,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
@@ -150,10 +148,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
if fp8_communication:
all_reduce_fp8(grad_input, group=ctx.process_group)
else:
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
@@ -172,7 +167,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce and not fp8_communication:
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
@@ -243,18 +238,16 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group, dim, fp8_communication=False):
def forward(ctx, input_, process_group, dim):
ctx.process_group = process_group
ctx.dim = dim
ctx.fp8_communication = fp8_communication
return _gather(input_, dim, process_group, fp8_communication)
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
fp8_communication = ctx.fp8_communication
# do reduce-scatter
new_shape = list(grad_output.shape)
assert (
@@ -266,10 +259,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
]
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
if fp8_communication:
reduce_scatter_fp8(output, grad_list, group=process_group)
else:
dist.reduce_scatter(output, grad_list, group=process_group)
dist.reduce_scatter(output, grad_list, group=process_group)
return output, None, None, None
@@ -576,7 +566,6 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
if fp8_communication:
# if False:
reduce_scatter_fp8(output, input_list, group=process_group)
else:
dist.reduce_scatter(output, input_list, group=process_group)
@@ -588,8 +577,7 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
dim = ctx.dim
process_group = ctx.process_group
fp8_communication = ctx.fp8_communication
return _gather(grad_output, dim, process_group, fp8_communication), None, None, None
return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -793,12 +781,12 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_comm=False):
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _gather(input_, dim, process_group, fp8_comm=fp8_comm, fp8_format="e4m3")
return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
@@ -829,11 +817,23 @@ class _AllToAll(torch.autograd.Function):
# using all_to_all_single when batch size is 1
if bsz == 1:
return _all_to_all_single(
input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
input_,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
else:
return _all_to_all(
input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
input_,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
@staticmethod
@@ -841,17 +841,29 @@ class _AllToAll(torch.autograd.Function):
process_group = ctx.process_group
scatter_dim = ctx.gather_dim
gather_dim = ctx.scatter_dim
ctx.fp8_communication
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = grad_output.shape
if bsz == 1:
return_grad = _all_to_all_single(
grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
grad_output,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
else:
return_grad = _all_to_all(
grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
grad_output,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
return (return_grad, None, None, None, None)
@@ -912,10 +924,7 @@ def _split(input_, dim=-1, process_group=None):
return output
from colossalai.params import to_cast
def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3"):
def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
@@ -926,13 +935,12 @@ def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3
from colossalai.zero.low_level._utils import has_inf_or_nan
if fp8_comm:
if fp8_communication:
# if False:
if has_inf_or_nan(input_):
print("input has nan")
exit(0)
input_type = input_.dtype
to_cast.append(input_)
ret, scale = cast_to_fp8(input_, fp8_format="e5m2")
if has_inf_or_nan(ret):
import pdb
@@ -1012,8 +1020,8 @@ def _reduce_scatter(input_, dim=1, process_group=None):
return output
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"):
if fp8_comm:
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
if fp8_communication:
input_type = input_.dtype
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
fp8_type = ret.dtype
@@ -1036,7 +1044,9 @@ def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=Fal
return torch.cat(output_list, dim=gather_dim).contiguous()
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"):
def _all_to_all_single(
input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
):
inp_shape = list(input_.shape)
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
if scatter_dim < 2:
@@ -1048,7 +1058,7 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, f
.contiguous()
)
if fp8_comm:
if fp8_communication:
input_type = input_t.dtype
ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format)
fp8_type = ret.dtype
@@ -1085,10 +1095,8 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
)
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
return LinearWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
)
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_gather_forward_reducescatter_backward(
@@ -1099,8 +1107,8 @@ def linear_gather_forward_reducescatter_backward(
)
def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
def gather_forward_reducescatter_backward(input_, process_group, dim):
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
@@ -1132,8 +1140,8 @@ def reduce_forward(input_, process_group, fp8_communication=False):
def reduce_backward(input_, process_group, fp8_communication=False):
return _ReduceBackward.apply(input_, process_group, fp8_communication=fp8_communication)
return _ReduceBackward.apply(input_, process_group, fp8_communication)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_comm=False):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_comm)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)