mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 05:29:36 +00:00
[deepgemm] adapt deepgemm
This commit is contained in:
parent
44d4053fec
commit
c832927d4e
94
colossalai/quantization/deep_gemm_utils.py
Normal file
94
colossalai/quantization/deep_gemm_utils.py
Normal file
@ -0,0 +1,94 @@
|
||||
# This file was modifed from https://github.com/deepseek-ai/DeepGEMM
|
||||
# as the utils are not included in library
|
||||
# Thanks for developing and open-sourcing the performant kernel
|
||||
|
||||
# Original LICENSE:
|
||||
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2025 DeepSeek
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
__WARNING_MSG = "Couldn't find deep_gemm library, please install from https://github.com/deepseek-ai/DeepGEMM and run corresponding tests"
|
||||
try:
|
||||
from deep_gemm import ceil_div, gemm_fp8_fp8_bf16_nt
|
||||
|
||||
IS_DEEP_GEMM_AVAIL = True
|
||||
except ImportError:
|
||||
IS_DEEP_GEMM_AVAIL = False
|
||||
warnings.warn(__WARNING_MSG)
|
||||
|
||||
def ceil_dev(*args, **kwargs): # to surpass code lint
|
||||
raise NotImplementedError(__WARNING_MSG)
|
||||
|
||||
def gemm_fp8_fp8_bf16_nt(*args, **kwargs):
|
||||
raise NotImplementedError(__WARNING_MSG)
|
||||
|
||||
|
||||
def deepgemm_fp8_gemm(
|
||||
lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor
|
||||
) -> None:
|
||||
gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
|
||||
|
||||
|
||||
# TODO: There seems to be better kernel implemented in triton
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Casting input tensor to float8_e4m3fn percicision and cooresponding scaler in token-wise mannar
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Matmul x in x @ y.t(), where x.shape() is (m, k)
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`: x_float8_e4m3fn and scaler
|
||||
"""
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
# TODO: There seems to be better kernel implemented in triton
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
def per_block_cast_to_fp8(y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Casting input tensor to float8_e4m3fn percicision and cooresponding scaler in block-wise mannar
|
||||
Args:
|
||||
y (`torch.Tensor`):
|
||||
Matmul y in x @ y.t(), where y.shape() is (n, k)
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`: y_float8_e4m3fn and scaler
|
||||
"""
|
||||
assert y.dim() == 2
|
||||
m, n = y.shape
|
||||
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=y.dtype, device=y.device)
|
||||
x_padded[:m, :n] = y
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
||||
from packaging.version import Version
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from .deep_gemm_utils import deepgemm_fp8_gemm, per_block_cast_to_fp8, per_token_cast_to_fp8
|
||||
from .fp8_config import dynamic_kernel
|
||||
|
||||
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
|
||||
@ -699,17 +700,11 @@ def all_gather_fp8_lagacy(
|
||||
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
|
||||
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
|
||||
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
|
||||
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
|
||||
dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
|
||||
for out, buf in zip(output_list, combined_buffers):
|
||||
scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
|
||||
output = buf[SCALE_BYTES:].view(fp8_type)
|
||||
cast_from_fp8(output.view(shape), scale, input_type, out=out)
|
||||
# output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
|
||||
# scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
|
||||
# output = output.float() * scales
|
||||
# for i, out in enumerate(output_list):
|
||||
# out.copy_(output[i].view(shape))
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
||||
@ -834,6 +829,50 @@ class _LinearFp8(torch.autograd.Function):
|
||||
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
|
||||
|
||||
|
||||
class _LinearFp8DeepGemm(torch.autograd.Function):
|
||||
"""
|
||||
Behave similar to torch.nn.functional.linear
|
||||
"""
|
||||
|
||||
def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
|
||||
if not (x.dim() == 2 and w.dim() == 2):
|
||||
raise ValueError("Batched fp8 deep_gemm is not supported")
|
||||
# x: (m, k), w: (n, k)
|
||||
# x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
|
||||
(m, k), (n, _) = x.shape, w.shape
|
||||
x_per_tok, w_per_blk = per_token_cast_to_fp8(x), per_block_cast_to_fp8(w)
|
||||
|
||||
out = torch.empty((m, n), dtype=torch.bfloat16, device=x.device) # NOTE: DeepGemm only supports bf16 output
|
||||
deepgemm_fp8_gemm(x_per_tok, w_per_blk, out)
|
||||
|
||||
ctx.w_t_per_plk = per_block_cast_to_fp8(w.t())
|
||||
ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
|
||||
ctx.mnk = (m, n, k)
|
||||
return out
|
||||
|
||||
def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# o_grad: (m, n)
|
||||
# x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
|
||||
# w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
|
||||
m, n, k = ctx.mnk
|
||||
o_per_tok = per_token_cast_to_fp8(o_grad)
|
||||
|
||||
x_grad = torch.empty((m, k), dtype=torch.bfloat16, device=o_grad.device)
|
||||
deepgemm_fp8_gemm(o_per_tok, ctx.w_t_per_plk, x_grad)
|
||||
|
||||
o_grad_t_per_tok = per_token_cast_to_fp8(o_grad.t())
|
||||
w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
|
||||
deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)
|
||||
|
||||
return x_grad, w_grad
|
||||
|
||||
|
||||
def linear_fp8_deep_gemm(input: torch.Tensor, weight: torch.Tensor, bias: None = None) -> torch.Tensor:
|
||||
if bias is not None:
|
||||
raise ValueError("bias is not supported in deep_gemm")
|
||||
return _LinearFp8DeepGemm.apply(input, weight)
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
|
||||
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
|
30
tests/test_fp8/test_fp8_deepgemm.py
Normal file
30
tests/test_fp8/test_fp8_deepgemm.py
Normal file
@ -0,0 +1,30 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.quantization.fp8 import linear_fp8_deep_gemm
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
m, k, n = 128, 384, 256
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
|
||||
def test_fp8_linear():
|
||||
# create tensors
|
||||
x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
|
||||
w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
|
||||
ref_w = w.clone().detach().requires_grad_()
|
||||
ref_x = x.clone().detach().requires_grad_()
|
||||
|
||||
out = linear_fp8_deep_gemm(x, w)
|
||||
assert out.shape == x.shape[:-1] + (n,)
|
||||
out.sum().backward()
|
||||
ref_out = F.linear(ref_x, ref_w)
|
||||
ref_out.sum().backward()
|
||||
|
||||
assert_close(out, ref_out, rtol=0.2, atol=0.1)
|
||||
assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
|
||||
assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
|
Loading…
Reference in New Issue
Block a user