[Feature] llama shardformer fp8 support (#5938)
* add llama shardformer fp8
* Llama Shardformer Parity
* fix typo
* fix all reduce
* fix pytest failure
* fix reduce op and move function to fp8.py
* fix typo
@@ -3,9 +3,11 @@ from typing import Any
 import numpy as np
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
+from torch.distributed import ReduceOp
 
 
-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -23,7 +25,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     fp8_max = torch.finfo(fp8_type).max
 
-    if inp.dim() == 2:
+    if per_channel_scale:
         per_channel_max = inp.abs().max(dim=-1).values.float()
         per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
         scale = fp8_max / per_channel_max[:, None]
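
The per-channel branch above maps each row's largest magnitude onto the fp8 dynamic range (for e4m3 the maximum representable value is 448). A minimal standalone sketch of the same arithmetic, using only plain PyTorch with float8 support:

import torch

# Per-channel scaling as in the hunk above: one scale per row of a 2-D tensor.
x = torch.tensor([[0.5, -2.0, 1.0], [100.0, -300.0, 7.0]])
fp8_max = torch.finfo(torch.float8_e4m3fn).max        # 448.0 for e4m3
per_channel_max = x.abs().max(dim=-1).values.float()  # tensor([2., 300.])
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]            # row 0: 224.0, row 1: ~1.493
scale_inv = 1.0 / scale                               # kept for casting back
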
@@ -37,7 +39,9 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
     return ret, scale_inv
 
 
-def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
+def cast_from_fp8(
+    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
+) -> torch.Tensor:
     r"""
     Args:
         inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
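
Taken together with the cast_to_fp8 change above, the extended signatures give a quantize/dequantize round trip like the sketch below. This is a hedged usage sketch: the import path colossalai.quantization.fp8 is an assumption based on the commit message ("move function to fp8.py"), and the float8 dtypes require a recent PyTorch build.

import torch
from colossalai.quantization.fp8 import cast_to_fp8, cast_from_fp8  # assumed path

x = torch.randn(4, 8, dtype=torch.bfloat16)
fp8_x, scale_inv = cast_to_fp8(x, fp8_format="e4m3", per_channel_scale=True)
x_back = cast_from_fp8(fp8_x, scale_inv, torch.bfloat16, per_channel_scale=True)
print((x - x_back).abs().max())  # small but nonzero: e4m3 keeps ~3 mantissa bits
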
@@ -49,20 +53,23 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
     if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
         raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
 
-    if inp.dim() == 2:
+    if per_channel_scale:
         ret = scale_inv[:, None] * inp.float()
     else:
         ret = scale_inv * inp.float()
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
+def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
+    It works like dist.all_reduce but during communication the data is cast to fp8 format.
+
     Args:
         tensor: torch.Tensor in fp32, fp16, bf16 datatype.
         fp8_format: e4m3 or e5m2
+        op: ReduceOp.SUM or ReduceOp.AVG
 
     Returns:
         None
     """
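
Since the docstring now promises dist.all_reduce semantics with a configurable op, a quick parity check is natural. A hedged two-process sketch (it assumes a torchrun launch, an NCCL backend, fp8-capable GPUs, and the import path guessed above):

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.quantization.fp8 import all_reduce_fp8  # assumed path

# Launch with: torchrun --nproc_per_node=2 this_script.py
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
ref = x.clone()
all_reduce_fp8(x, fp8_format="e4m3", op=ReduceOp.AVG)  # in place, lossy
dist.all_reduce(ref, op=ReduceOp.AVG)                  # exact baseline
print((x - ref).abs().max())                           # small fp8 rounding error
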
@@ -72,18 +79,20 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
     input_shape = tensor.shape
     input_device = tensor.device
     input_size = tensor.numel()
-    tensor = tensor.flatten()
+    flat_padded_x = tensor.flatten()
+
+    assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG"
+
+    if flat_padded_x.size(0) % world_size != 0:
+        pad_size = world_size - flat_padded_x.size(0) % world_size
+        flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
 
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
-
-    ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format)
+    ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
 
     inp = ret.view(torch.uint8)
     input_chunks = list(torch.chunk(inp, world_size, dim=0))
-    if dist.get_rank() == world_size - 1:
-        output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
-    else:
-        output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
+    output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
     dist.all_to_all(output_chunks, input_chunks, group=group)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
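
The new padding exists because dist.all_to_all wants world_size equally sized chunks; the flattened tensor is padded up to the next multiple of world_size, and the pad is stripped again at the end of the function (see the copy_ in the next hunk). The arithmetic in isolation:

# Pure-Python check of the padding rule (applied only when numel % world_size != 0).
world_size = 4
numel = 10                                   # flattened length before padding
pad_size = world_size - numel % world_size   # 4 - 10 % 4 = 2
assert (numel + pad_size) % world_size == 0  # 12 elements -> 4 chunks of 3
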
@@ -92,15 +101,18 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
 
+    if op == ReduceOp.AVG:
+        summed_out.div_(world_size)
+
     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
     dist.all_gather(scale_list, scale, group=group)
 
-    tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
+    tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
     dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
     for i in range(world_size):
         tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
-    tensor_out = torch.cat(tensor_list, dim=0)
-    tensor.data = tensor_out.view(input_shape).to(input_type)
+    out = torch.cat(tensor_list, dim=0)
+    tensor.copy_(out[:input_size].view(input_shape).to(input_type))
 
 
 def cast_to_fp8_pipeline(inp: Any) -> None:
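
Two fixes land in the write-back above: the result is written with copy_() instead of rebinding .data, and out[:input_size] drops the padding added earlier before restoring the original shape. The slicing step in isolation:

import torch

input_size, input_shape = 10, (2, 5)
out = torch.arange(12.0)                       # padded flat result: 10 real + 2 pad
restored = out[:input_size].view(input_shape)  # pad dropped, original shape recovered
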
@@ -276,5 +288,74 @@ def all_gather_into_tensor_flat_fp8(
     dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
     numel = np.prod(output_shape)
     valid_buffer = buffer[:numel].reshape(output_shape)
-    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type)
+    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
     output_tensor[:numel].copy_(valid_buffer.view(-1))
+
+
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
+
+    world_size = dist.get_world_size(group)
+
+    input_type = input_list[0].dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+    scale_list = []
+    tensor_list = []
+
+    for i in range(world_size):
+        input_tensor = input_list[i]
+        ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
+        scale_list.append(scale)
+        ret = ret.view(torch.uint8)
+        tensor_list.append(ret)
+
+    output_scale_list = [torch.empty_like(x) for x in scale_list]
+    output_tensor_list = [torch.empty_like(x) for x in tensor_list]
+    dist.all_to_all(output_tensor_list, tensor_list, group=group)
+    dist.all_to_all(output_scale_list, scale_list, group=group)
+
+    for i in range(world_size):
+        scale = output_scale_list[i]
+        tensor = output_tensor_list[i]
+        tensor = tensor.view(fp8_type)
+        output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+
+
+def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
+
+    world_size = dist.get_world_size(group)
+
+    per_slice_len = input_tensor.size(0) // world_size
+    input_type = input_tensor.dtype
+    ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
+    fp8_type = ret.dtype
+    input_tensor = ret.view(torch.uint8)
+    tensor = torch.empty_like(input_tensor)
+    scale_list = [torch.empty_like(scale) for _ in range(world_size)]
+    dist.all_to_all_single(tensor, input_tensor, group=group)
+    dist.all_gather(scale_list, scale, group=group)
+    cast_tensor_list = []
+
+    for i in range(world_size):
+        output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type)
+        output_part = cast_from_fp8(output_part, scale_list[i], input_type)
+        cast_tensor_list.append(output_part)
+    output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0))
+
+
+def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
+
+    world_size = dist.get_world_size(group)
+
+    input_type = input_.dtype
+    ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
+    fp8_type = ret.dtype
+    input_ = ret.view(torch.uint8)
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
+    dist.all_gather(tensor_list, input_, group=group)
+    dist.all_gather(scale_list, scale, group=group)
+
+    for i in range(world_size):
+        output = tensor_list[i].view(fp8_type)
+        scale = scale_list[i]
+        output_list[i].copy_(cast_from_fp8(output, scale, input_type))
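
The three new collectives mirror their torch.distributed counterparts, casting to fp8 only for the wire transfer and shipping the per-rank scales alongside the payload. A hedged sketch of gather_fp8 as a drop-in for dist.all_gather (same assumptions as above: torchrun launch, NCCL backend, guessed import path):

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import gather_fp8  # assumed path

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
world_size = dist.get_world_size()
x = torch.randn(256, device="cuda", dtype=torch.float16)
outputs = [torch.empty_like(x) for _ in range(world_size)]
gather_fp8(outputs, x, fp8_format="e5m2")  # like dist.all_gather, but fp8 on the wire
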