Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[Feature] llama shardformer fp8 support (#5938)
* add llama shardformer fp8
* Llama Shardformer Parity
* fix typo
* fix all reduce
* fix pytest failure
* fix reduce op and move function to fp8.py
* fix typo
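The diff below threads an fp8_communication flag through the tensor-parallel communication primitives (symbols such as LinearWithAsyncCommunication and all_to_all_comm suggest colossalai/shardformer/layer/_operation.py, an inferred path). As context, a minimal usage sketch, assuming a process group launched via torchrun on a CUDA build of PyTorch with NCCL; shapes and the import path are illustrative assumptions:

    import torch
    import torch.distributed as dist

    # assumption: the helper lives in colossalai.shardformer.layer._operation
    from colossalai.shardformer.layer._operation import all_to_all_comm

    dist.init_process_group(backend="nccl")  # launched via torchrun
    x = torch.randn(4, 128, 512, device="cuda", dtype=torch.bfloat16)

    # With fp8_communication=True, payloads are cast to fp8 before the collective
    # (e4m3 in forward, e5m2 in backward per this diff) to cut communication volume.
    y = all_to_all_comm(x, dist.group.WORLD, scatter_dim=2, gather_dim=1, fp8_communication=True)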
@@ -14,7 +14,13 @@ try:
 except ImportError:
     _grad_accum_fusion_available = False
 
-from colossalai.quantization.fp8 import all_reduce_fp8, cast_from_fp8, cast_to_fp8, reduce_scatter_fp8
+from colossalai.quantization.fp8 import (
+    all_reduce_fp8,
+    all_to_all_fp8,
+    all_to_all_single_fp8,
+    gather_fp8,
+    reduce_scatter_fp8,
+)
 
 
 class FusedLayerNormAffineFunction1D(torch.autograd.Function):
@@ -117,11 +123,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.fp8_communication = fp8_communication
         if bias is not None:
             output = F.linear(input_, weight, bias)
         else:
@@ -133,6 +140,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
+        fp8_communication = ctx.fp8_communication
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
         if use_bias:
@@ -148,7 +156,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
 
         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
-            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
+            if fp8_communication:
+                all_reduce_fp8(grad_input, group=ctx.process_group)
+            else:
+                handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
             # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
@@ -167,10 +178,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
 
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
-        if ctx.async_grad_allreduce:
+        if ctx.async_grad_allreduce and not fp8_communication:
             handle.wait()
 
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None
 
 
 def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
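Note the guard change above: all_reduce_fp8 is called as a blocking collective, so no handle exists to wait on; handle is only created on the non-fp8 path. A generic sketch of the pattern being preserved, in plain torch.distributed (not this file's code):

    import torch.distributed as dist
    from colossalai.quantization.fp8 import all_reduce_fp8  # import added by this diff

    def reduce_grad(grad, group, use_fp8: bool):
        # fp8 path blocks; bf16/fp32 path overlaps compute with communication
        if use_fp8:
            all_reduce_fp8(grad, group=group)  # returns only once the reduce is done
            return None
        handle = dist.all_reduce(grad, group=group, async_op=True)  # kicked off async
        return handle  # caller must handle.wait() before reading grad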
@@ -238,16 +249,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, dim):
+    def forward(ctx, input_, process_group, dim, fp8_communication=False):
         ctx.process_group = process_group
         ctx.dim = dim
+        ctx.fp8_communication = fp8_communication
 
-        return _gather(input_, dim, process_group)
+        return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
         dim = ctx.dim
         process_group = ctx.process_group
+        fp8_communication = ctx.fp8_communication
         # do reduce-scatter
         new_shape = list(grad_output.shape)
         assert (
@@ -259,9 +272,12 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
         ]
         output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
 
-        dist.reduce_scatter(output, grad_list, group=process_group)
+        if fp8_communication:
+            reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2")
+        else:
+            dist.reduce_scatter(output, grad_list, group=process_group)
 
-        return output, None, None
+        return output, None, None, None
 
 
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -577,12 +593,8 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         dim = ctx.dim
         process_group = ctx.process_group
         fp8_communication = ctx.fp8_communication
-        return (
-            _gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"),
-            None,
-            None,
-            None,
-        )
+
+        return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None
 
 
 class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -816,26 +828,67 @@ class _AllToAll(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False):
         ctx.process_group = process_group
         ctx.scatter_dim = scatter_dim
         ctx.gather_dim = gather_dim
+        ctx.fp8_communication = fp8_communication
         world_size = dist.get_world_size(process_group)
         bsz, _, _ = input_.shape
 
         # using all_to_all_single when batch size is 1
         if bsz == 1:
-            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
+            return _all_to_all_single(
+                input_,
+                world_size,
+                process_group,
+                scatter_dim,
+                gather_dim,
+                fp8_communication=fp8_communication,
+                fp8_format="e4m3",
+            )
         else:
-            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
+            return _all_to_all(
+                input_,
+                world_size,
+                process_group,
+                scatter_dim,
+                gather_dim,
+                fp8_communication=fp8_communication,
+                fp8_format="e4m3",
+            )
 
     @staticmethod
-    def backward(ctx, *grad_output):
+    def backward(ctx, grad_output):
         process_group = ctx.process_group
         scatter_dim = ctx.gather_dim
         gather_dim = ctx.scatter_dim
-        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
-        return (return_grad, None, None, None)
+        fp8_communication = ctx.fp8_communication
+        world_size = dist.get_world_size(process_group)
+        bsz, _, _ = grad_output.shape
+
+        if bsz == 1:
+            return_grad = _all_to_all_single(
+                grad_output,
+                world_size,
+                process_group,
+                scatter_dim,
+                gather_dim,
+                fp8_communication=fp8_communication,
+                fp8_format="e5m2",
+            )
+        else:
+            return_grad = _all_to_all(
+                grad_output,
+                world_size,
+                process_group,
+                scatter_dim,
+                gather_dim,
+                fp8_communication=fp8_communication,
+                fp8_format="e5m2",
+            )
+
+        return (return_grad, None, None, None, None)
 
 
 class HookParameter(torch.autograd.Function):
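The format split in the hunk above follows common fp8 training practice: e4m3 (more mantissa bits, max normal ~448) for forward activations, e5m2 (more exponent bits, max normal ~57344) for gradients, whose magnitudes vary more widely. A quick self-contained look at the two dtypes, assuming PyTorch >= 2.1 with native fp8 support:

    import torch

    a = torch.tensor([0.1234, 1.5, 448.0]).to(torch.float8_e4m3fn)  # finer steps, narrow range
    g = torch.tensor([1e-4, 8.0, 57344.0]).to(torch.float8_e5m2)    # coarser steps, wide range
    print(a.float())  # values rounded to e4m3's 3-bit mantissa grid
    print(g.float())  # values rounded to e5m2's 2-bit mantissa grid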
@@ -899,33 +952,14 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
     if world_size == 1:
         return input_
 
+    input_ = input_.contiguous()
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     if fp8_communication:
-        input_type = input_.dtype
-        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
-        fp8_type = ret.dtype
-        input_ = ret.view(torch.uint8)
-        input_ = input_.contiguous()
-        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-        scale = torch.tensor(scale, dtype=torch.float32).to(input_.device)
-        scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)]
-
-        scale = torch.tensor(scale).to(input_.device)
-        torch.distributed.all_gather(tensor_list, input_, group=process_group)
-        torch.distributed.all_gather(scale_list, scale, group=process_group)
-
-        cast_tensor_list = []
-        for output, scale in zip(tensor_list, scale_list):
-            output = output.view(fp8_type)
-            output = cast_from_fp8(output, scale, input_type)
-            cast_tensor_list.append(output)
-
-        output = torch.cat(cast_tensor_list, dim=dim).contiguous()
-
+        gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
     else:
-        input_ = input_.contiguous()
-        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-        torch.distributed.all_gather(tensor_list, input_, group=process_group)
-        output = torch.cat(tensor_list, dim=dim).contiguous()
+        dist.all_gather(tensor_list, input_, group=process_group)
+
+    output = torch.cat(tensor_list, dim=dim).contiguous()
 
     return output
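The deleted inline path above shows the mechanics now encapsulated by gather_fp8: each rank quantizes with its own per-tensor scale, payloads and scales are all-gathered separately, and each shard is de-quantized with its source rank's scale. A stand-alone sketch of such a cast pair in plain torch (hypothetical naive helpers, not the actual fp8.py implementation):

    import torch

    def naive_cast_to_fp8(t: torch.Tensor, fp8_format: str = "e4m3"):
        # per-tensor scaling: map the tensor's absolute max onto the fp8 max normal
        fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
        fp8_max = 448.0 if fp8_format == "e4m3" else 57344.0
        scale = t.abs().max().float().clamp(min=1e-12) / fp8_max
        return (t / scale).to(fp8_dtype), scale  # both payload and scale get communicated

    def naive_cast_from_fp8(t_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
        # de-quantize with the sender's scale, then restore the original dtype
        return (t_fp8.to(torch.float32) * scale).to(dtype)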
@@ -954,14 +988,19 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output
 
 
-def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
     input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
     output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-    dist.all_to_all(output_list, input_list, group=group)
+    if fp8_communication:
+        all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format)
+    else:
+        dist.all_to_all(output_list, input_list, group=group)
     return torch.cat(output_list, dim=gather_dim).contiguous()
 
 
-def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
+def _all_to_all_single(
+    input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
+):
     inp_shape = list(input_.shape)
     inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
     if scatter_dim < 2:
@@ -974,7 +1013,11 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
     )
 
     output = torch.empty_like(input_t)
-    dist.all_to_all_single(output, input_t, group=group)
+    if fp8_communication:
+        all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format)
+    else:
+        dist.all_to_all_single(output, input_t, group=group)
 
     if scatter_dim < 2:
         output = output.transpose(0, 1).contiguous()
@@ -994,8 +1037,10 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
     )
 
 
-def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
-    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
+def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
+    return LinearWithAsyncCommunication.apply(
+        input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
+    )
 
 
 def linear_gather_forward_reducescatter_backward(
@@ -1006,8 +1051,8 @@ def linear_gather_forward_reducescatter_backward(
     )
 
 
-def gather_forward_reducescatter_backward(input_, process_group, dim):
-    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
+def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
+    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
 
 
 def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
@@ -1042,5 +1087,5 @@ def reduce_backward(input_, process_group, fp8_communication=False):
     return _ReduceBackward.apply(input_, process_group, fp8_communication)
 
 
-def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
-    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
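Finally, a sketch of the kind of check the "Llama Shardformer Parity" commit message implies: gather the same input with and without fp8 and confirm the results stay close. This is a hypothetical test helper, not a test from the PR; the tolerances are assumptions, since fp8 e4m3 carries roughly two decimal digits of precision:

    import torch
    import torch.distributed as dist
    from colossalai.shardformer.layer._operation import _gather  # path inferred from the diff

    def check_gather_parity(process_group=None):
        x = torch.randn(2, 4, 8, device="cuda", dtype=torch.bfloat16)
        ref = _gather(x.clone(), dim=1, process_group=process_group)
        fp8 = _gather(x.clone(), dim=1, process_group=process_group,
                      fp8_communication=True, fp8_format="e4m3")
        # loose tolerance: quantization error is expected, divergence is not
        assert torch.allclose(ref.float(), fp8.float(), rtol=0.1, atol=0.1)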