1
0
mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-02 13:45:36 +00:00
ColossalAI/tests/test_fp8/test_fp8_hook.py
2025-03-17 10:47:25 +08:00

88 lines
2.6 KiB
Python

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm
from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
# Module-level flags inspected by the tests and reset to False at the start of each one.
# REPLACED: set by a test hook's rewrite_op when it swaps in a tracing wrapper.
# TRIGGERED: set by a tracing wrapper when it is actually executed.
REPLACED, TRIGGERED = False, False
def new_linear_fp8(x, w, bias=None):
    """Record that this wrapper ran, then delegate to ``linear_fp8``.

    ``FP8TestHook.rewrite_op`` returns this function in place of
    ``linear_fp8`` so the test can observe that the rewritten op executed.

    Args:
        x: input activation, forwarded unchanged.
        w: weight, forwarded unchanged.
        bias: optional bias, forwarded unchanged.

    Returns:
        The result of ``linear_fp8(x, w, bias)``.
    """
    global TRIGGERED
    # Flip the flag before delegating, so it is set even if linear_fp8 raises.
    TRIGGERED = True
    result = linear_fp8(x, w, bias)
    return result
def new_deepgemm_fp8_gemm(lhs, rhs, out=None):
    """Record that this wrapper ran, then delegate to ``linear_fp8_deep_gemm``.

    ``DeepGemmTestHook.rewrite_op`` returns this function in place of
    ``linear_fp8_deep_gemm`` so the test can observe that the rewritten op
    executed.

    Args:
        lhs: left operand (the activation in ``test_fp8_deep_gemm_hook``).
        rhs: right operand (the weight in ``test_fp8_deep_gemm_hook``).
        out: optional third argument, forwarded positionally.

    Returns:
        The result of ``linear_fp8_deep_gemm(lhs, rhs, out)``.
    """
    global TRIGGERED
    # Flip the flag before delegating, so it is set even if the GEMM raises.
    TRIGGERED = True
    # NOTE(review): this wrapper presumably receives F.linear's arguments, so
    # ``out`` occupies the ``bias`` slot; a caller passing ``bias=`` by keyword
    # would fail here. Confirm against linear_fp8_deep_gemm's signature.
    result = linear_fp8_deep_gemm(lhs, rhs, out)
    return result
class FP8TestHook(FP8Hook):
    """``FP8Hook`` that swaps ``linear_fp8`` for the tracing ``new_linear_fp8``."""

    def rewrite_op(self, func):
        """Apply the parent rewrite, then redirect ``linear_fp8`` to ``new_linear_fp8``.

        Sets the module flag ``REPLACED`` when the redirect happens. Any other
        function produced by the parent rewrite is returned untouched.
        """
        global REPLACED
        rewritten = super().rewrite_op(func)
        if rewritten is not linear_fp8:
            return rewritten
        REPLACED = True
        return new_linear_fp8
class DeepGemmTestHook(FP8DeepGemmHook):
    """``FP8DeepGemmHook`` that swaps ``linear_fp8_deep_gemm`` for ``new_deepgemm_fp8_gemm``."""

    def rewrite_op(self, func):
        """Apply the parent rewrite, then redirect ``linear_fp8_deep_gemm``.

        Sets the module flag ``REPLACED`` when the redirect to
        ``new_deepgemm_fp8_gemm`` happens. Any other function produced by the
        parent rewrite is returned untouched.
        """
        global REPLACED
        rewritten = super().rewrite_op(func)
        if rewritten is not linear_fp8_deep_gemm:
            return rewritten
        REPLACED = True
        return new_deepgemm_fp8_gemm
# Problem sizes shared by both tests: the weight is (D_OUT, D_IN), and every
# activation's last dimension is D_IN.
D_IN = 128
D_OUT = 128
# Leading dimensions of the 3-D activation in test_fp8_hook.
# test_fp8_deep_gemm_hook uses only S, as the row count of its 2-D activation.
B = 2
S = 64
# Both the weight and the activations are created in this dtype.
DTYPE = torch.bfloat16
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
    """Check that F.linear on a ColoParameter under FP8TestHook is routed through new_linear_fp8.

    Verifies the output shape, that the hook performed the rewrite (REPLACED)
    and that the rewritten function actually ran (TRIGGERED).
    """
    global REPLACED, TRIGGERED
    REPLACED = TRIGGERED = False

    device = get_current_device()
    weight = nn.Parameter(torch.rand(D_OUT, D_IN, device=device, dtype=DTYPE))
    inputs = torch.rand(B, S, D_IN, device=device, dtype=DTYPE, requires_grad=True)
    # Re-class the parameter in place as a ColoParameter and re-run its
    # initializer; presumably this is what makes ops on it visible to
    # ColoParamOpHookManager hooks.
    weight.__class__ = ColoParameter
    weight.__init__(weight, requires_grad=True)

    with ColoParamOpHookManager.use_hooks(FP8TestHook()):
        out = F.linear(inputs, weight)

    assert out.shape == (B, S, D_OUT)
    assert REPLACED
    assert TRIGGERED
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_deep_gemm_hook():
    """Check that F.linear on a ColoParameter under DeepGemmTestHook is routed through new_deepgemm_fp8_gemm.

    Verifies the output shape, that the hook performed the rewrite (REPLACED)
    and that the rewritten function actually ran (TRIGGERED).
    """
    global REPLACED, TRIGGERED
    REPLACED = TRIGGERED = False

    device = get_current_device()
    weight = nn.Parameter(torch.rand(D_OUT, D_IN, device=device, dtype=DTYPE))
    # Unlike test_fp8_hook, the activation here is 2-D (S, D_IN); presumably
    # the DeepGEMM path expects a 2-D left operand. TODO confirm.
    inputs = torch.rand(S, D_IN, device=device, dtype=DTYPE, requires_grad=True)
    # Re-class the parameter in place as a ColoParameter and re-run its
    # initializer; presumably this is what makes ops on it visible to
    # ColoParamOpHookManager hooks.
    weight.__class__ = ColoParameter
    weight.__init__(weight, requires_grad=True)

    with ColoParamOpHookManager.use_hooks(DeepGemmTestHook()):
        out = F.linear(inputs, weight)

    assert out.shape == (S, D_OUT)
    assert REPLACED
    assert TRIGGERED