[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm
* implement communication hook for FSDP params all-gather
* add unit tests for fp8 operators
* support fp8 communication in GeminiPlugin
* update training scripts to support FSDP and fp8 communication
* fix some minor bugs observed in unit tests
* add all_gather_into_tensor_flat_fp8
* skip the tests if torch < 2.2.0
* add fp8_comm flag
* rebase latest fp8 operators
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
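Presumably the new flag is surfaced to users through the booster plugins. A minimal usage sketch, assuming the constructor argument is named `fp8_communication` (the commit message only calls it an "fp8_comm flag", so the exact keyword and plugin surface here are assumptions, not verified against the merged API):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()
# `fp8_communication=True` is an assumed keyword based on the commit
# message; check the plugin signature in your ColossalAI version.
plugin = GeminiPlugin(fp8_communication=True)
booster = Booster(plugin=plugin)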
tests/test_fp8/test_fp8_cast.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import torch
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize


@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
    # round-trip through fp8 and check the result within quantization error
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
    out = cast_from_fp8(ret, scale_inv, x.dtype)
    assert_close(out, x, rtol=0.1, atol=0.1)

    # the in-place pipeline casts are only exercised for even trailing dims
    if x.size(-1) % 2 == 0:
        inp_dict = {"hidden_states": x.clone()}
        cast_to_fp8_pipeline(inp_dict)
        cast_from_fp8_pipeline(inp_dict)
        assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)


if __name__ == "__main__":
    test_fp8_cast()
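For intuition, cast_to_fp8/cast_from_fp8 presumably follow the usual per-tensor amax scaling recipe: pick a scale that maps the largest magnitude onto the fp8 dynamic range, quantize, and undo the scale on the way back. A minimal self-contained sketch of that recipe in plain PyTorch (torch >= 2.1 for the float8 dtypes; an illustration, not the ColossalAI kernel):

import torch

def naive_fp8_roundtrip(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Per-tensor amax scaling: map the largest magnitude onto the fp8
    # dynamic range (e4m3 tops out at 448, e5m2 at 57344).
    fp8_max = torch.finfo(fp8_dtype).max
    scale = fp8_max / x.float().abs().max().clamp(min=1e-12)
    x_fp8 = (x.float() * scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    # dequantize: widen back and undo the scale
    return (x_fp8.float() / scale).to(x.dtype)

x = torch.rand(4, 8, dtype=torch.float16)
torch.testing.assert_close(naive_fp8_roundtrip(x), x, rtol=0.1, atol=0.1)

The loose rtol=0.1/atol=0.1 in the test matches fp8 precision: e4m3 keeps 3 mantissa bits (worst-case relative rounding error around 2**-4, about 6%) and e5m2 keeps 2 (around 2**-3, about 12%).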
tests/test_fp8/test_fp8_ddp_comm_hook.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    def get_grads_after_one_iteration(hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)

        ddp_model = DDP(model, device_ids=[rank])

        if hook is not None:
            ddp_model.register_comm_hook(None, hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in ddp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync

    grad_dict = get_grads_after_one_iteration()
    for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
        grad_dict_w_hook = get_grads_after_one_iteration(hook)
        if dist.get_rank() == 0:
            for name in grad_dict:
                assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)
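For readers unfamiliar with the interface being tested: a DDP communication hook is any callable (state, bucket) -> Future[Tensor] registered via register_comm_hook, and DDP uses the future's result as the reduced flat gradient for that bucket. The sketch below shows the general shape an fp8 compression hook can take; it is a toy, not the sync/async hooks shipped in colossalai.quantization.fp8. Since NCCL cannot all-reduce fp8 values directly, one workaround is to all-gather the quantized bytes plus scales and average after dequantizing:

import torch
import torch.distributed as dist


def toy_fp8_compress_hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    # Toy sketch, not the ColossalAI implementation: quantize the flat
    # gradient bucket to e4m3 with one per-bucket scale, all-gather the
    # raw bytes and scales, then dequantize and average locally.
    group = dist.group.WORLD
    world_size = dist.get_world_size(group)

    grad = bucket.buffer()
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (grad.abs().max().clamp(min=1e-12) / fp8_max).reshape(1)
    payload = (grad / scale).to(torch.float8_e4m3fn).view(torch.uint8)

    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    avg = torch.zeros_like(grad)
    for p, s in zip(payloads, scales):
        avg += p.view(torch.float8_e4m3fn).to(grad.dtype) * s
    avg /= world_size

    fut = torch.futures.Future()
    fut.set_result(avg)
    return fut

# registered the same way as the hooks exercised above:
#   ddp_model.register_comm_hook(None, toy_fp8_compress_hook)

The payoff is bandwidth: one byte per gradient element on the wire instead of four, at the cost of extra casts and the quantization error the test bounds with rtol=0.1/atol=0.1.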
tests/test_fp8/test_fp8_fsdp_comm_hook.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close

from colossalai import launch
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(100, 50)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


@parameterize("mode", ["grad", "params"])
def run_model(mode):
    rank = dist.get_rank()

    from colossalai.quantization.utils import patch_fsdp_params_comm_hook

    patch_fsdp_params_comm_hook()

    def get_grads_after_one_iteration(grad_hook=None, params_hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        fsdp_model = FSDP(model)

        if grad_hook is not None:
            fsdp_model.register_comm_hook(None, grad_hook)

        if params_hook is not None:
            fsdp_model.register_params_comm_hook(None, params_hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = fsdp_model(torch.randn(20, 100))
        labels = torch.randn(20, 50).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in fsdp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook

    if mode == "grad":
        grad_dict = get_grads_after_one_iteration()
        for hook in [
            fp8_compress_fsdp_grad_comm_hook,
        ]:
            grad_dict_w_hook = get_grads_after_one_iteration(grad_hook=hook)
            if dist.get_rank() == 0:
                for name in grad_dict:
                    assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    elif mode == "params":
        grad_dict = get_grads_after_one_iteration()
        for hook in [
            fp8_compress_fsdp_params_comm_hook,
        ]:
            grad_dict_w_hook = get_grads_after_one_iteration(params_hook=hook)
            if dist.get_rank() == 0:
                for name in grad_dict:
                    assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    else:
        raise NotImplementedError


def demo_basic(rank, world_size, port):
    print(f"Running basic FSDP example on rank {rank}.")
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    run_model()
    cleanup()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("2.2.0"), reason="torch version < 2.2.0.")
@rerun_if_address_is_in_use()
def test_fsdp():
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    spawn(demo_basic, n_gpus)


if __name__ == "__main__":
    test_fsdp()
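One detail worth calling out in this file: register_comm_hook is stock PyTorch FSDP, but register_params_comm_hook is not a PyTorch API. It only becomes available because patch_fsdp_params_comm_hook() patches FSDP so the parameter all-gather can be intercepted as well, which is why the test applies the patch before wrapping the model and why the whole file is skipped on torch < 2.2.0. Condensed, the registration flow the test exercises looks like this (a sketch reusing the ToyModel defined above; the patch-before-wrap ordering mirrors the test and seems to be the safe order):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook
from colossalai.quantization.utils import patch_fsdp_params_comm_hook

patch_fsdp_params_comm_hook()  # patch first, then wrap
fsdp_model = FSDP(ToyModel().cuda())
fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)  # grad reduce path
fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)  # params all-gather path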