Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm
* implement a communication hook for the FSDP params all-gather
* add unit tests for the fp8 operators
* support fp8 communication in GeminiPlugin
* update the training scripts to support FSDP and fp8 communication
* fix minor bugs observed in the unit tests
* add all_gather_into_tensor_flat_fp8
* skip the tests if torch < 2.2.0
* add the fp8_comm flag
* rebase onto the latest fp8 operators
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
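For the DDP gradient path, the idea is to compress each gradient bucket to FP8 before the collective and decompress afterwards. Below is a minimal sketch of that pattern using PyTorch's public DDP communication-hook API; the hook name, the per-bucket amax scaling, and the choice to emulate the all-reduce with an all-gather of byte-viewed FP8 buffers are illustrative assumptions, not the code this commit adds (which lives under colossalai.quantization.fp8).

# Hypothetical sketch: an FP8-compressed gradient hook for torch DDP, in the
# spirit of the fp8_communication support described above. Assumes torch >= 2.2
# (float8_e4m3fn dtype) and an initialized process group.
import torch
import torch.distributed as dist

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def fp8_allreduce_hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    group = state if state is not None else dist.group.WORLD
    world_size = dist.get_world_size(group)
    grad = bucket.buffer()

    # Quantize the whole bucket with one per-bucket scale chosen from its amax.
    amax = grad.abs().max().float().clamp(min=1e-12)
    fp8_grad = (grad.float() * (E4M3_MAX / amax)).to(torch.float8_e4m3fn)

    # NCCL cannot reduce FP8 directly, so emulate the all-reduce by
    # all-gathering the byte-viewed FP8 buckets plus their scales,
    # then dequantizing and averaging locally.
    byte_shards = [torch.empty_like(fp8_grad.view(torch.uint8)) for _ in range(world_size)]
    dist.all_gather(byte_shards, fp8_grad.view(torch.uint8), group=group)
    amaxes = [torch.empty_like(amax.reshape(1)) for _ in range(world_size)]
    dist.all_gather(amaxes, amax.reshape(1), group=group)

    acc = torch.zeros_like(grad, dtype=torch.float32)
    for shard, shard_amax in zip(byte_shards, amaxes):
        acc += shard.view(torch.float8_e4m3fn).float() * (shard_amax / E4M3_MAX)
    acc /= world_size  # DDP averages gradients across ranks

    fut = torch.futures.Future()
    fut.set_result(acc.to(grad.dtype))
    return fut


# Usage (illustrative): ddp_model.register_comm_hook(state=None, hook=fp8_allreduce_hook)

The same cast-before-collective pattern carries over to the FSDP gradient and parameter all-gather hooks and to the GeminiPlugin flag listed above.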
tests/test_fp8/test_fp8_cast.py (new file, 26 lines added)
@@ -0,0 +1,26 @@
import torch
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize


@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
    # Round-trip a random tensor through the FP8 cast and back, then check
    # the dequantized values stay within a loose tolerance of the original.
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
    out = cast_from_fp8(ret, scale_inv, x.dtype)
    assert_close(out, x, rtol=0.1, atol=0.1)

    # The pipeline cast helpers are only exercised when the last dimension is even.
    if x.size(-1) % 2 == 0:
        inp_dict = {"hidden_states": x.clone()}
        cast_to_fp8_pipeline(inp_dict)
        cast_from_fp8_pipeline(inp_dict)
        assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)


if __name__ == "__main__":
    test_fp8_cast()
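For context on what the test checks: cast_to_fp8 / cast_from_fp8 perform a scaled cast, and a round trip through a format with only 2 or 3 mantissa bits is inherently lossy, hence the loose rtol=0.1 / atol=0.1 tolerances. A naive per-tensor version of such a cast could look like the following; this is an illustrative stand-in, not the colossalai.quantization.fp8 implementation the test imports.

# Naive per-tensor scaled FP8 cast, for illustration only.
import torch

FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}  # largest finite value of each format
FP8_DTYPE = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}


def naive_cast_to_fp8(x: torch.Tensor, fp8_format: str = "e4m3"):
    # Stretch the tensor so its largest magnitude lands near the FP8 max,
    # then round into the FP8 dtype; return the inverse scale for decoding.
    amax = x.abs().max().float().clamp(min=1e-12)
    scale = FP8_MAX[fp8_format] / amax
    x_fp8 = (x.float() * scale).to(FP8_DTYPE[fp8_format])
    return x_fp8, scale.reciprocal()


def naive_cast_from_fp8(x_fp8: torch.Tensor, scale_inv: torch.Tensor, dtype: torch.dtype):
    # Undo the scaling and return to the requested higher-precision dtype.
    return (x_fp8.float() * scale_inv).to(dtype)

The returned pair mirrors how the test consumes cast_to_fp8's output: the FP8 payload plus a scale_inv that cast_from_fp8 applies during dequantization.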