[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)

* Refactor and clean up using common helper funcs. Tests passed

* Update comments

* Fix relative import

* Fix param fetching bug
Authored by Wenxuan Tan on 2025-02-11 23:42:34 -06:00, committed by GitHub
parent 5c09d726a6
commit ec73f1b5e2
8 changed files with 142 additions and 298 deletions


@@ -3,7 +3,6 @@
import pytest
import torch
import torch.distributed as dist
-import torch.nn as nn
from torch.testing import assert_close

import colossalai
@@ -17,7 +16,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
-from tests.test_optimizer._utils import check_optim_states, run_bert_test
+from tests.test_optimizer._utils import check_optim_states, run_bert_test, set_dist_grad

_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
@@ -109,39 +108,6 @@ def force_assign_grad(p, g_dtype, grad=None):
    p.data = orig_p

-def set_dist_grad(
-    dist_module: nn.Module,
-    torch_model: nn.Module,
-    g_dtype: torch.dtype,
-    group: dist.ProcessGroup,
-) -> None:
-    """
-    Set grads chunks for Tensor Parallel or ZeRO DP.
-    We do not need a separate treatment for ZeRO,
-    as the LowLevelOptimizer takes care of reduce-scattering grads.
-    """
-    rank = dist.get_rank(group)
-    world_size = dist.get_world_size(group)
-    for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
-        if torch_p.grad is None:
-            # avoid inconsistent grad and param dtype error
-            force_assign_grad(torch_p, g_dtype)
-        else:
-            torch_p.grad += torch.randn_like(torch_p, device=torch_p.device, dtype=g_dtype)
-
-        if p.grad is None:
-            force_assign_grad(p, g_dtype)
-
-        if is_distributed_tensor(p):
-            split_dim = get_shard_dim_1d(p)
-            # Add grads only to the correctly split chunk
-            force_assign_grad(p, g_dtype, torch_p.grad.chunk(world_size, dim=split_dim)[rank].contiguous())
-            # assert_close(p.grad, torch_p.grad.chunk(world_size, dim=split_dim)[rank])
-        else:
-            force_assign_grad(p, g_dtype, torch_p.grad)

@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("tp_zero_size", [(4, 1), (1, 4), (2, 2)])
def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]) -> None:
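For orientation, here is a minimal, hypothetical sketch of how a test body can use the now-shared helper from tests.test_optimizer._utils after this cleanup. The set_dist_grad signature is taken from the removed definition above; the function name sync_grads_and_step and the model/optimizer objects (tp_model, torch_model, dist_optim, torch_optim) are placeholder names, not code from this commit.

from tests.test_optimizer._utils import set_dist_grad  # shared helper after this refactor


def sync_grads_and_step(tp_model, torch_model, dist_optim, torch_optim, g_dtype, tp_group):
    # Hypothetical sketch: assign matching gradients to both models (sharded chunks
    # for the TP model, full tensors for the plain torch model), then step both
    # optimizers so their parameters can be compared afterwards.
    set_dist_grad(tp_model, torch_model, g_dtype, tp_group)
    dist_optim.step()
    torch_optim.step()
    dist_optim.zero_grad()
    torch_optim.zero_grad()
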
@@ -158,7 +124,7 @@ def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_si
    dist.get_rank(tp_group)
    seed_all(_SEED)  # Fix model init
-    torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=p_dtype).to(rank)
+    torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, dtype=p_dtype).to(rank)
    tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)

    assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
@@ -222,7 +188,7 @@ def run_dist_galore_fwd_bwd(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_
    seed_all(_SEED)
    clear_layout_converter()  # Ensure correct sharding
-    torch_model = Net(_IN_DIM, _HID_DIM, identity=True, dtype=p_dtype).to(rank)
+    torch_model = Net(_IN_DIM, _HID_DIM, dtype=p_dtype).to(rank)
    tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)

    assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
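The real tests then run forward/backward and rely on the shared assert_distributed_close and check_optim_states helpers. As a rough illustration of what such a closeness check amounts to, here is a hypothetical sketch: the function name check_param_close is an assumption, and the d_tensor import path is assumed to match what the test file imports near its top (outside this hunk).

import torch.distributed as dist
from torch.testing import assert_close

# Assumed import path for the d_tensor helpers used in the removed code above.
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor


def check_param_close(tp_model, torch_model, tp_group, rtol=1e-4, atol=1e-4):
    # Hypothetical sketch: compare each TP-sharded parameter against the matching
    # chunk of the unsharded reference parameter; replicated parameters are compared whole.
    rank = dist.get_rank(tp_group)
    world_size = dist.get_world_size(tp_group)
    for p, torch_p in zip(tp_model.parameters(), torch_model.parameters()):
        if is_distributed_tensor(p):
            split_dim = get_shard_dim_1d(p)
            assert_close(p, torch_p.chunk(world_size, dim=split_dim)[rank].contiguous(), rtol=rtol, atol=atol)
        else:
            assert_close(p, torch_p, rtol=rtol, atol=atol)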