[Inference/Feat] Add convert_fp8 op for fp8 test in the future (#5706)
* add convert_fp8 op for fp8 test in the future
* rerun ci
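The new op round-trips a cache tensor through FP8 (stored as uint8) and back, which the test below checks against an allclose tolerance. A minimal sketch of that intended usage, assuming a CUDA device and a successful build of the inference extensions (names are taken from the test below, not a stable public API):

    import torch

    from colossalai.kernel.kernel_loader import InferenceOpsLoader
    from colossalai.utils import get_current_device

    inference_ops = InferenceOpsLoader().load()

    # (num_blocks, num_heads, head_size, block_size) cache, filled with values in the range the test uses
    cache = torch.empty(1024, 8, 64, 8, dtype=torch.half, device=get_current_device()).uniform_(-224.0, 224.0)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)  # FP8 values stored as raw bytes
    inference_ops.convert_fp8(cache, cache_fp8)              # quantize to FP8

    roundtrip = torch.empty_like(cache)
    inference_ops.convert_fp8(cache_fp8, roundtrip)          # dequantize back to the original dtype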
tests/test_infer/test_kernels/cuda/test_convert_fp8.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import random

import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]


@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, so its related test is skipped for now!")
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [64, 80, 96, 112, 128, 256])
@pytest.mark.parametrize("block_size", [8, 16, 32])
@pytest.mark.parametrize("num_blocks", [1024, 10000])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
@pytest.mark.parametrize("seed", [0])
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = get_current_device()

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # Quantize the cache to FP8, stored as raw uint8 bytes.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    inference_ops.convert_fp8(cache, cache_fp8)

    # Dequantize back to the original dtype.
    converted_cache = torch.empty_like(cache)
    inference_ops.convert_fp8(cache_fp8, converted_cache)

    # The round trip is lossy but should stay close to the original values.
    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)


if __name__ == "__main__":
    test_fp8_conversion(8, 64, 8, 1024, torch.half, 0)