mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[fp8] support fp8 amp for hybrid parallel plugin (#5975)
* [fp8] support fp8 amp for hybrid parallel plugin * [test] add fp8 hook test * [fp8] fix fp8 linear compatibility
This commit is contained in:
50
tests/test_fp8/test_fp8_hook.py
Normal file
50
tests/test_fp8/test_fp8_hook.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster.plugin.fp8_hook import FP8Hook
|
||||
from colossalai.quantization.fp8 import linear_fp8
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
REPLACED = False
|
||||
TRIGGERED = False
|
||||
|
||||
|
||||
def new_linear_fp8(x, w, bias=None):
|
||||
global TRIGGERED
|
||||
TRIGGERED = True
|
||||
return linear_fp8(x, w, bias)
|
||||
|
||||
|
||||
class FP8TestHook(FP8Hook):
|
||||
def rewrite_op(self, func):
|
||||
func = super().rewrite_op(func)
|
||||
if func is linear_fp8:
|
||||
global REPLACED
|
||||
REPLACED = True
|
||||
return new_linear_fp8
|
||||
return func
|
||||
|
||||
|
||||
D_IN, D_OUT = 16, 32
|
||||
B, S = 2, 64
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
|
||||
def test_fp8_hook():
|
||||
# create tensors
|
||||
w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
|
||||
x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
|
||||
w.__class__ = ColoParameter
|
||||
w.__init__(w, requires_grad=True)
|
||||
hook = FP8TestHook()
|
||||
with ColoParamOpHookManager.use_hooks(hook):
|
||||
o = F.linear(x, w)
|
||||
assert o.shape == (B, S, D_OUT)
|
||||
assert REPLACED
|
||||
assert TRIGGERED
|
Reference in New Issue
Block a user