From c832927d4ef87743ba328387f6ccde97cff90a5b Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 12 Mar 2025 14:47:24 +0800
Subject: [PATCH] [deepgemm] adapt deepgemm

---
 colossalai/quantization/deep_gemm_utils.py | 94 ++++++++++++++++++++++
 colossalai/quantization/fp8.py             | 51 ++++++++++--
 tests/test_fp8/test_fp8_deepgemm.py        | 30 +++++++
 3 files changed, 169 insertions(+), 6 deletions(-)
 create mode 100644 colossalai/quantization/deep_gemm_utils.py
 create mode 100644 tests/test_fp8/test_fp8_deepgemm.py

diff --git a/colossalai/quantization/deep_gemm_utils.py b/colossalai/quantization/deep_gemm_utils.py
new file mode 100644
index 000000000..72a103175
--- /dev/null
+++ b/colossalai/quantization/deep_gemm_utils.py
@@ -0,0 +1,94 @@
+# This file was modified from https://github.com/deepseek-ai/DeepGEMM,
+# as these utils are not included in the library.
+# Thanks for developing and open-sourcing the performant kernel.
+
+# Original LICENSE:
+
+# MIT License
+
+# Copyright (c) 2025 DeepSeek
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import warnings
+from typing import Tuple
+
+import torch
+
+__WARNING_MSG = "Couldn't find the deep_gemm library; please install it from https://github.com/deepseek-ai/DeepGEMM to run the corresponding tests"
+try:
+    from deep_gemm import ceil_div, gemm_fp8_fp8_bf16_nt
+
+    IS_DEEP_GEMM_AVAIL = True
+except ImportError:
+    IS_DEEP_GEMM_AVAIL = False
+    warnings.warn(__WARNING_MSG)
+
+    def ceil_div(*args, **kwargs):  # stub to satisfy code lint when deep_gemm is missing
+        raise NotImplementedError(__WARNING_MSG)
+
+    def gemm_fp8_fp8_bf16_nt(*args, **kwargs):
+        raise NotImplementedError(__WARNING_MSG)
+
+
+def deepgemm_fp8_gemm(
+    lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor
+) -> None:
+    gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
+
+
+# TODO: There seem to be better kernels implemented in Triton
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Cast the input tensor to float8_e4m3fn precision and compute the corresponding scale, token-wise.
+    Args:
+        x (`torch.Tensor`):
+            Matmul x in x @ y.t(), where x.shape is (m, k)
+
+    Returns:
+        `Tuple[torch.Tensor, torch.Tensor]`: x_float8_e4m3fn and its scale
+    """
+    assert x.dim() == 2 and x.size(1) % 128 == 0
+    m, n = x.shape
+    x_view = x.view(m, -1, 128)
+    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
+    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+
+
+# TODO: There seem to be better kernels implemented in Triton
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def per_block_cast_to_fp8(y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Cast the input tensor to float8_e4m3fn precision and compute the corresponding scale, block-wise.
+    Args:
+        y (`torch.Tensor`):
+            Matmul y in x @ y.t(), where y.shape is (n, k)
+
+    Returns:
+        `Tuple[torch.Tensor, torch.Tensor]`: y_float8_e4m3fn and its scale
+    """
+    assert y.dim() == 2
+    m, n = y.shape
+    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=y.dtype, device=y.device)
+    x_padded[:m, :n] = y
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index e23da5ccc..d899964e5 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+from .deep_gemm_utils import deepgemm_fp8_gemm, per_block_cast_to_fp8, per_token_cast_to_fp8
 from .fp8_config import dynamic_kernel
 
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
@@ -699,17 +700,11 @@ def all_gather_fp8_lagacy(
     ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
     ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
     cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
-    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
     dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
     for out, buf in zip(output_list, combined_buffers):
         scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
         output = buf[SCALE_BYTES:].view(fp8_type)
         cast_from_fp8(output.view(shape), scale, input_type, out=out)
-    # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
-    # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
-    # output = output.float() * scales
-    # for i, out in enumerate(output_list):
-    #     out.copy_(output[i].view(shape))
 
 
 @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
@@ -834,6 +829,50 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
+class _LinearFp8DeepGemm(torch.autograd.Function):
+    """
+    Behaves similarly to torch.nn.functional.linear (bias is not supported).
+    """
+
+    def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        if not (x.dim() == 2 and w.dim() == 2):
+            raise ValueError("Batched fp8 deep_gemm is not supported")
+        # x: (m, k), w: (n, k)
+        # x @ w.t() -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
+        (m, k), (n, _) = x.shape, w.shape
+        x_per_tok, w_per_blk = per_token_cast_to_fp8(x), per_block_cast_to_fp8(w)
+
+        out = torch.empty((m, n), dtype=torch.bfloat16, device=x.device)  # NOTE: DeepGEMM only supports bf16 output
+        deepgemm_fp8_gemm(x_per_tok, w_per_blk, out)
+
+        ctx.w_t_per_blk = per_block_cast_to_fp8(w.t())
+        ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
+        ctx.mnk = (m, n, k)
+        return out
+
+    def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # o_grad: (m, n)
+        # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
+        # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        m, n, k = ctx.mnk
+        o_per_tok = per_token_cast_to_fp8(o_grad)
+
+        x_grad = torch.empty((m, k), dtype=torch.bfloat16, device=o_grad.device)
+        deepgemm_fp8_gemm(o_per_tok, ctx.w_t_per_blk, x_grad)
+
+        o_grad_t_per_tok = per_token_cast_to_fp8(o_grad.t())
+        w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
+        deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)
+
+        return x_grad, w_grad
+
+
+def linear_fp8_deep_gemm(input: torch.Tensor, weight: torch.Tensor, bias: None = None) -> torch.Tensor:
+    if bias is not None:
+        raise ValueError("bias is not supported in deep_gemm")
+    return _LinearFp8DeepGemm.apply(input, weight)
+
+
 @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
diff --git a/tests/test_fp8/test_fp8_deepgemm.py b/tests/test_fp8/test_fp8_deepgemm.py
new file mode 100644
index 000000000..47aed247a
--- /dev/null
+++ b/tests/test_fp8/test_fp8_deepgemm.py
@@ -0,0 +1,30 @@
+import pytest
+import torch
+import torch.nn.functional as F
+from torch.testing import assert_close
+
+from colossalai.accelerator import get_accelerator
+from colossalai.quantization.fp8 import linear_fp8_deep_gemm
+from colossalai.utils import get_current_device
+
+m, k, n = 128, 384, 256
+DTYPE = torch.bfloat16
+
+
+@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
+def test_fp8_linear():
+    # create tensors
+    x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    ref_w = w.clone().detach().requires_grad_()
+    ref_x = x.clone().detach().requires_grad_()
+
+    out = linear_fp8_deep_gemm(x, w)
+    assert out.shape == x.shape[:-1] + (n,)
+    out.sum().backward()
+    ref_out = F.linear(ref_x, ref_w)
+    ref_out.sum().backward()
+
+    assert_close(out, ref_out, rtol=0.2, atol=0.1)
+    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
+    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
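
A minimal usage sketch of the new linear_fp8_deep_gemm entry point, assuming deep_gemm is installed and a CUDA device with compute capability >= 9.0 is available (the shapes below simply mirror the unit test). Constraints visible in the patch: k must be a multiple of 128 for the forward per-token cast, m and n must be multiples of 128 for the backward per-token casts, the output is always bfloat16, and bias is not supported.

    import torch

    from colossalai.quantization.fp8 import linear_fp8_deep_gemm

    # x is (m, k), w is (n, k); "cuda" is assumed here, the test uses get_current_device()
    x = torch.rand((128, 384), dtype=torch.bfloat16, device="cuda", requires_grad=True)
    w = torch.rand((256, 384), dtype=torch.bfloat16, device="cuda", requires_grad=True)

    out = linear_fp8_deep_gemm(x, w)  # (128, 256) in bfloat16; passing a bias raises ValueError
    out.sum().backward()              # x.grad and w.grad come from two more FP8 GEMMs in backward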