[Feature] llama shardformer fp8 support (#5938)

* add llama shardformer fp8

* Llama Shardformer Parity

* fix typo

* fix all reduce

* fix pytest failure

* fix reduce op and move function to fp8.py

* fix typo
Author: Guangyao Zhang
Date: 2024-08-05 10:05:47 +08:00 (committed by GitHub)
Parent: c297e21bea
Commit: 53cb9606bd
11 changed files with 453 additions and 98 deletions


@@ -3,9 +3,11 @@ from typing import Any
 import numpy as np
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
+from torch.distributed import ReduceOp
 
 
-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -23,7 +25,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     fp8_max = torch.finfo(fp8_type).max
 
-    if inp.dim() == 2:
+    if per_channel_scale:
         per_channel_max = inp.abs().max(dim=-1).values.float()
         per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
         scale = fp8_max / per_channel_max[:, None]
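
For reference, a minimal round-trip sketch of the new per_channel_scale flag. This is an illustration, not part of the diff; it assumes the helpers live in colossalai.quantization.fp8, a CUDA device with float8 support, and a 2D tensor, and the flag must match on both casts:

import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8  # assumed module path

x = torch.randn(4, 8, dtype=torch.bfloat16, device="cuda")  # illustrative 2D tensor

# per-channel scaling: one scale per row, computed from the row-wise abs max
fp8_x, scale_inv = cast_to_fp8(x, fp8_format="e4m3", per_channel_scale=True)
# pass the same flag back in so scale_inv broadcasts as [:, None]
x_restored = cast_from_fp8(fp8_x, scale_inv, x.dtype, per_channel_scale=True)

assert x_restored.shape == x.shape  # round trip is lossy but shape and dtype are preserved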
@@ -37,7 +39,9 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
     return ret, scale_inv
 
 
-def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
+def cast_from_fp8(
+    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
+) -> torch.Tensor:
     r"""
     Args:
         inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
@@ -49,20 +53,23 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
     if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
         raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
-    if inp.dim() == 2:
+    if per_channel_scale:
         ret = scale_inv[:, None] * inp.float()
     else:
         ret = scale_inv * inp.float()
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
+def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
 
     Args:
         tensor: torch.Tensor in fp32, fp16, bf16 datatype.
         fp8_format: e4m3 or e5m2
+        op: ReduceOp.SUM or ReduceOp.AVG
 
     Returns:
         None
     """
@@ -72,18 +79,20 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
     input_shape = tensor.shape
     input_device = tensor.device
     input_size = tensor.numel()
-    tensor = tensor.flatten()
+    flat_padded_x = tensor.flatten()
+
+    assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG"
+
+    if flat_padded_x.size(0) % world_size != 0:
+        pad_size = world_size - flat_padded_x.size(0) % world_size
+        flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
 
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
-    ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format)
+    ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
 
     inp = ret.view(torch.uint8)
     input_chunks = list(torch.chunk(inp, world_size, dim=0))
-    if dist.get_rank() == world_size - 1:
-        output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
-    else:
-        output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
+    output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
     dist.all_to_all(output_chunks, input_chunks, group=group)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
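
The padding added above exists so that torch.chunk yields equal-sized slices for dist.all_to_all on every rank; a quick arithmetic sketch with illustrative numbers (world_size=4, 1026 elements):

world_size = 4          # assumed for illustration
numel = 1026            # not divisible by world_size

pad_size = world_size - numel % world_size   # 4 - 2 = 2
padded_numel = numel + pad_size              # 1028, now divisible by 4
chunk_size = padded_numel // world_size      # 257 elements handled per rank

# the padded tail is discarded at the end of all_reduce_fp8 via out[:input_size]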
@@ -92,15 +101,18 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
 
+    if op == ReduceOp.AVG:
+        summed_out.div_(world_size)
+
     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
     dist.all_gather(scale_list, scale, group=group)
 
-    tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
+    tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
     dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
     for i in range(world_size):
         tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
-    tensor_out = torch.cat(tensor_list, dim=0)
-    tensor.data = tensor_out.view(input_shape).to(input_type)
+    out = torch.cat(tensor_list, dim=0)
+    tensor.copy_(out[:input_size].view(input_shape).to(input_type))
 
 
 def cast_to_fp8_pipeline(inp: Any) -> None:
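
In the spirit of the "Llama Shardformer Parity" item in the commit message, a sketch of a parity check against the uncompressed collective. The tolerances and sizes are illustrative only; it assumes an initialized process group and the module path used above:

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

from colossalai.quantization.fp8 import all_reduce_fp8  # assumed module path

def check_all_reduce_fp8_parity(numel: int = 1025) -> None:
    x = torch.rand(numel, dtype=torch.bfloat16, device="cuda")
    ref = x.clone()
    dist.all_reduce(ref, op=ReduceOp.SUM)                  # exact baseline
    all_reduce_fp8(x, fp8_format="e4m3", op=ReduceOp.SUM)  # compressed path
    # fp8 communication is lossy, so only approximate agreement is expected
    assert torch.allclose(x, ref, rtol=0.1, atol=0.1)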
@@ -276,5 +288,74 @@ all_gather_into_tensor_flat_fp8(
     dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
     numel = np.prod(output_shape)
     valid_buffer = buffer[:numel].reshape(output_shape)
-    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type)
+    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
     output_tensor[:numel].copy_(valid_buffer.view(-1))
+
+
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
+    world_size = dist.get_world_size(group)
+    input_type = input_list[0].dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+    scale_list = []
+    tensor_list = []
+
+    for i in range(world_size):
+        input_tensor = input_list[i]
+        ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
+        scale_list.append(scale)
+        ret = ret.view(torch.uint8)
+        tensor_list.append(ret)
+
+    output_scale_list = [torch.empty_like(x) for x in scale_list]
+    output_tensor_list = [torch.empty_like(x) for x in tensor_list]
+    dist.all_to_all(output_tensor_list, tensor_list, group=group)
+    dist.all_to_all(output_scale_list, scale_list, group=group)
+
+    for i in range(world_size):
+        scale = output_scale_list[i]
+        tensor = output_tensor_list[i]
+        tensor = tensor.view(fp8_type)
+        output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+
+
+def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
+    world_size = dist.get_world_size(group)
+    per_slice_len = input_tensor.size(0) // world_size
+    input_type = input_tensor.dtype
+    ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
+    fp8_type = ret.dtype
+    input_tensor = ret.view(torch.uint8)
+    tensor = torch.empty_like(input_tensor)
+    scale_list = [torch.empty_like(scale) for _ in range(world_size)]
+    dist.all_to_all_single(tensor, input_tensor, group=group)
+    dist.all_gather(scale_list, scale, group=group)
+    cast_tensor_list = []
+
+    for i in range(world_size):
+        output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type)
+        output_part = cast_from_fp8(output_part, scale_list[i], input_type)
+        cast_tensor_list.append(output_part)
+
+    output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0))
+
+
+def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
+    world_size = dist.get_world_size(group)
+
+    input_type = input_.dtype
+    ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
+    fp8_type = ret.dtype
+    input_ = ret.view(torch.uint8)
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
+    dist.all_gather(tensor_list, input_, group=group)
+    dist.all_gather(scale_list, scale, group=group)
+
+    for i in range(world_size):
+        output = tensor_list[i].view(fp8_type)
+        scale = scale_list[i]
+        output_list[i].copy_(cast_from_fp8(output, scale, input_type))
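
The three new collectives mirror their torch.distributed counterparts while compressing payloads to fp8 on the wire. A hedged usage sketch for gather_fp8 and all_to_all_single_fp8, assuming an initialized group and the module path used above; note that all_to_all_single_fp8 expects dim 0 to be divisible by the world size, since it slices the buffer into equal per-rank segments:

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_single_fp8, gather_fp8  # assumed module path

world_size = dist.get_world_size()
x = torch.randn(8 * world_size, dtype=torch.float16, device="cuda")

# all-gather with fp8 payloads: every rank receives every rank's (dequantized) tensor
gathered = [torch.empty_like(x) for _ in range(world_size)]
gather_fp8(gathered, x, fp8_format="e5m2")

# all_to_all_single with fp8 payloads: rank i keeps slice i from every rank
out = torch.empty_like(x)
all_to_all_single_fp8(out, x, fp8_format="e5m2")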