[deepgemm] adapt deepgemm

hxwang 2025-03-12 14:47:24 +08:00
parent 44d4053fec
commit c832927d4e
No known key found for this signature in database
GPG Key ID: 0EC383D418F0B9F8
3 changed files with 169 additions and 6 deletions


@@ -0,0 +1,94 @@
# This file was modified from https://github.com/deepseek-ai/DeepGEMM
# as the utils are not included in the library
# Thanks for developing and open-sourcing the performant kernel
# Original LICENSE:
# MIT License
# Copyright (c) 2025 DeepSeek
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import warnings
from typing import Tuple

import torch

__WARNING_MSG = "Couldn't find the deep_gemm library, please install it from https://github.com/deepseek-ai/DeepGEMM to run the corresponding kernels and tests"

try:
    from deep_gemm import ceil_div, gemm_fp8_fp8_bf16_nt

    IS_DEEP_GEMM_AVAIL = True
except ImportError:
    IS_DEEP_GEMM_AVAIL = False
    warnings.warn(__WARNING_MSG)

    # stubs so the module still imports (and lint stays quiet) when deep_gemm is missing
    def ceil_div(*args, **kwargs):
        raise NotImplementedError(__WARNING_MSG)

    def gemm_fp8_fp8_bf16_nt(*args, **kwargs):
        raise NotImplementedError(__WARNING_MSG)


def deepgemm_fp8_gemm(
    lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor
) -> None:
    gemm_fp8_fp8_bf16_nt(lhs, rhs, out)


# TODO: There seems to be a better kernel implemented in triton
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Cast the input tensor to float8_e4m3fn precision and compute the corresponding scale in a token-wise manner.

    Args:
        x (`torch.Tensor`):
            Matmul x in x @ y.t(), where x.shape is (m, k)

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: x_float8_e4m3fn and its scale
    """
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
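
For intuition, a minimal round-trip sketch (illustration only, not part of the commit; assumes a CUDA device): the cast scales each 128-wide group so that its absolute max maps to 448, the e4m3 maximum, so multiplying the fp8 values back by the returned scale recovers the input up to fp8 rounding error.

x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")  # last dim must be a multiple of 128
x_fp8, x_scale = per_token_cast_to_fp8(x)  # x_fp8: (4, 256) float8_e4m3fn, x_scale: (4, 2) float32
x_deq = (x_fp8.float().view(4, -1, 128) * x_scale.unsqueeze(2)).view(4, 256)  # dequantize group-wise
torch.testing.assert_close(x_deq, x.float(), rtol=0.1, atol=0.05)  # within e4m3 precision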


# TODO: There seems to be a better kernel implemented in triton
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def per_block_cast_to_fp8(y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Cast the input tensor to float8_e4m3fn precision and compute the corresponding scale in a block-wise (128x128) manner.

    Args:
        y (`torch.Tensor`):
            Matmul y in x @ y.t(), where y.shape is (n, k)

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: y_float8_e4m3fn and its scale
    """
    assert y.dim() == 2
    m, n = y.shape
    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=y.dtype, device=y.device)
    x_padded[:m, :n] = y
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
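
Putting the two helpers together, a hedged sketch of the raw call path (illustration only, not part of the commit; assumes deep_gemm is installed and a Hopper/SM90 GPU): the left operand is cast token-wise, the right operand block-wise, and the kernel writes a bf16 result for x @ w.t(), mirroring how the _LinearFp8DeepGemm.forward added to fp8.py below uses it.

m, n, k = 128, 256, 384  # all multiples of 128, matching what the casts above expect
x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
out = torch.empty(m, n, dtype=torch.bfloat16, device="cuda")  # DeepGEMM only writes bf16
deepgemm_fp8_gemm(per_token_cast_to_fp8(x), per_block_cast_to_fp8(w), out)
torch.testing.assert_close(out, x @ w.t(), rtol=0.2, atol=0.1)  # tolerances mirror the test added below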


@@ -8,6 +8,7 @@ import torch.nn.functional as F
from packaging.version import Version
from torch.distributed import ReduceOp
from .deep_gemm_utils import deepgemm_fp8_gemm, per_block_cast_to_fp8, per_token_cast_to_fp8
from .fp8_config import dynamic_kernel
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
@@ -699,17 +700,11 @@ def all_gather_fp8_lagacy(
    ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
    ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
    dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
    for out, buf in zip(output_list, combined_buffers):
        scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
        output = buf[SCALE_BYTES:].view(fp8_type)
        cast_from_fp8(output.view(shape), scale, input_type, out=out)
    # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
    # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
    # output = output.float() * scales
    # for i, out in enumerate(output_list):
    #     out.copy_(output[i].view(shape))
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
@@ -834,6 +829,50 @@ class _LinearFp8(torch.autograd.Function):
        return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


class _LinearFp8DeepGemm(torch.autograd.Function):
    """
    Behaves similarly to torch.nn.functional.linear (without bias support).
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        if not (x.dim() == 2 and w.dim() == 2):
            raise ValueError("Batched fp8 deep_gemm is not supported")
        # x: (m, k), w: (n, k)
        # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
        (m, k), (n, _) = x.shape, w.shape
        x_per_tok, w_per_blk = per_token_cast_to_fp8(x), per_block_cast_to_fp8(w)

        out = torch.empty((m, n), dtype=torch.bfloat16, device=x.device)  # NOTE: DeepGemm only supports bf16 output
        deepgemm_fp8_gemm(x_per_tok, w_per_blk, out)

        ctx.w_t_per_blk = per_block_cast_to_fp8(w.t())
        ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
        ctx.mnk = (m, n, k)
        return out

    @staticmethod
    def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # o_grad: (m, n)
        # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
        # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
        m, n, k = ctx.mnk

        o_per_tok = per_token_cast_to_fp8(o_grad)
        x_grad = torch.empty((m, k), dtype=torch.bfloat16, device=o_grad.device)
        deepgemm_fp8_gemm(o_per_tok, ctx.w_t_per_blk, x_grad)

        o_grad_t_per_tok = per_token_cast_to_fp8(o_grad.t())
        w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
        deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)
        return x_grad, w_grad


def linear_fp8_deep_gemm(input: torch.Tensor, weight: torch.Tensor, bias: None = None) -> torch.Tensor:
    if bias is not None:
        raise ValueError("bias is not supported in deep_gemm")
    return _LinearFp8DeepGemm.apply(input, weight)


@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return _LinearFp8.apply(input, weight, bias)
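
A hedged usage sketch (illustration only, not part of the diff; the DeepGemmLinear wrapper name is hypothetical): linear_fp8_deep_gemm behaves like a bias-free F.linear(x, w), so a module can route its 2-D matmuls through it, provided the involved dimensions are multiples of 128 as the cast helpers above expect.

import torch
import torch.nn as nn

from colossalai.quantization.fp8 import linear_fp8_deep_gemm


class DeepGemmLinear(nn.Module):
    """Bias-free linear layer that routes its matmul through the DeepGEMM fp8 path."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # weights stay in bf16; they are re-cast to fp8 block-wise on every forward
        self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=torch.bfloat16) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x must be 2-D with its feature dim a multiple of 128 (see per_token_cast_to_fp8)
        return linear_fp8_deep_gemm(x, self.weight)


layer = DeepGemmLinear(384, 256).cuda()
y = layer(torch.randn(128, 384, dtype=torch.bfloat16, device="cuda"))  # -> (128, 256), bf16
y.sum().backward()  # gradients flow through _LinearFp8DeepGemm.backward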


@@ -0,0 +1,30 @@
import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8_deep_gemm
from colossalai.utils import get_current_device

m, k, n = 128, 384, 256
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_linear():
    # create tensors
    x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
    w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_w = w.clone().detach().requires_grad_()
    ref_x = x.clone().detach().requires_grad_()

    out = linear_fp8_deep_gemm(x, w)
    assert out.shape == x.shape[:-1] + (n,)
    out.sum().backward()

    ref_out = F.linear(ref_x, ref_w)
    ref_out.sum().backward()

    assert_close(out, ref_out, rtol=0.2, atol=0.1)
    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)