ColossalAI/tests/test_fp8/test_fp8_fsdp_comm_hook.py
flybird11111 46ed5d856b
[ci] update ci (#6254)
* fix for async io

* test for upgrading transformers

* add ci machine

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_fp16_torch.py

* Update build_on_pr.yml

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-04-18 16:40:53 +08:00

109 lines
3.4 KiB
Python

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close
from colossalai import launch
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
def cleanup():
    """Destroy the default ``torch.distributed`` process group, if one exists.

    The call is a no-op when distributed support is unavailable or when no
    process group has been initialized, so teardown never raises just because
    setup did not complete.
    """
    # destroy_process_group() errors out when there is no default group,
    # so only call it once we know a group was actually created.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
class ToyModel(nn.Module):
    """Small two-layer perceptron used as the model under test.

    Layout: ``Linear(100, 100)`` -> ``ReLU`` -> ``Linear(100, 50)``.
    """

    def __init__(self):
        # Zero-argument super() is the Python 3 idiom and stays correct if the
        # class is ever renamed.
        super().__init__()
        self.net1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(100, 50)

    def forward(self, x):
        """Return ``net2(relu(net1(x)))`` for input ``x`` with 100 trailing features."""
        return self.net2(self.relu(self.net1(x)))
@clear_cache_before_run()
@parameterize("mode", ["grad", "params"])
def run_model(mode):
    """Compare gradients from a plain FSDP step with those from an FP8-hooked step.

    A freshly seeded ``ToyModel`` wrapped in FSDP is trained for one SGD step,
    first without any communication hook and then once per FP8 hook selected
    by ``mode``. On rank 0 the per-parameter gradients of both runs are
    checked with ``assert_close`` (``rtol=0.1``, ``atol=0.1``).

    Args:
        mode: ``"grad"`` registers ``fp8_compress_fsdp_grad_comm_hook`` through
            ``register_comm_hook``; ``"params"`` registers
            ``fp8_compress_fsdp_params_comm_hook`` through
            ``register_params_comm_hook``.

    Raises:
        NotImplementedError: If ``mode`` is neither ``"grad"`` nor ``"params"``.
    """
    rank = dist.get_rank()

    from colossalai.quantization.utils import patch_fsdp_params_comm_hook

    # Presumably this is what makes ``register_params_comm_hook`` available on
    # FSDP modules -- confirm against colossalai.quantization.utils.
    patch_fsdp_params_comm_hook()

    def grads_after_single_step(grad_hook=None, params_hook=None):
        """Run one forward/backward/SGD step and return ``{param_name: grad}``."""
        # Fixed seed so every call starts from the same RNG state.
        torch.manual_seed(0)

        # Build the model on the device indexed by this rank, then shard it.
        wrapped = FSDP(ToyModel().to(rank))
        if grad_hook is not None:
            wrapped.register_comm_hook(None, grad_hook)
        if params_hook is not None:
            wrapped.register_params_comm_hook(None, params_hook)

        criterion = nn.MSELoss()
        sgd = optim.SGD(wrapped.parameters(), lr=0.001)
        sgd.zero_grad()

        predictions = wrapped(torch.randn(20, 100))
        targets = torch.randn(20, 50).to(rank)
        criterion(predictions, targets).backward()
        sgd.step()

        torch.distributed.barrier()
        return {name: param.grad for name, param in wrapped.named_parameters()}

    from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook

    # Each entry is the keyword-argument set for one hooked run.
    if mode == "grad":
        hooked_runs = [{"grad_hook": fp8_compress_fsdp_grad_comm_hook}]
    elif mode == "params":
        hooked_runs = [{"params_hook": fp8_compress_fsdp_params_comm_hook}]
    else:
        raise NotImplementedError

    reference_grads = grads_after_single_step()
    for hook_kwargs in hooked_runs:
        hooked_grads = grads_after_single_step(**hook_kwargs)
        # Only rank 0 performs the comparison.
        if dist.get_rank() == 0:
            for name in reference_grads:
                assert_close(reference_grads[name], hooked_grads[name], rtol=0.1, atol=0.1)
def demo_basic(rank, world_size, port):
    """Per-process entry point: launch, run the hook comparison, then tear down.

    Args:
        rank: Rank of this process, forwarded to ``launch``.
        world_size: Total number of processes, forwarded to ``launch``.
        port: Port forwarded to ``launch``; the host is always ``"localhost"``.
    """
    print(f"Running basic FSDP example on rank {rank}.")

    launch_kwargs = {
        "rank": rank,
        "world_size": world_size,
        "port": port,
        "host": "localhost",
    }
    launch(**launch_kwargs)

    run_model()
    cleanup()
@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse("2.2.0"),
    reason="torch version < 2.2.0.",
)
@rerun_if_address_is_in_use()
def test_fsdp():
    """Spawn ``demo_basic`` on every visible CUDA device (at least two required)."""
    n_gpus = torch.cuda.device_count()
    # The test needs at least two devices; fail loudly with the count found.
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    spawn(demo_basic, n_gpus)
if __name__ == "__main__":
    # Allow running this file directly as a script, without the pytest runner.
    test_fsdp()