Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 13:59:08 +00:00)
[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)
* Refactor and clean up using common helper funcs; tests passed
* Update comments
* Fix relative import
* Fix param fetching bug
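For context on the helper being shared here: set_dist_grad (its full definition is visible in the removed hunk below, and the added import line confirms it now lives in tests/test_optimizer/_utils) seeds gradients of a chosen dtype on a distributed module and a plain torch reference model, chunking the full grad onto TP-sharded params. Below is a minimal single-process sketch of that contract, assuming a ColossalAI dev checkout on PYTHONPATH and that the shared helper keeps the signature shown in the removed hunk; the nn.Linear stand-ins, the port, and the gloo setup are illustrative only (the real tests pass TP-sharded TPNet modules on GPUs).

# Hypothetical single-process driver for the shared helper (gloo, world_size=1).
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

from tests.test_optimizer._utils import set_dist_grad  # shared after this PR

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

torch_model = nn.Linear(8, 8)  # reference model
dist_model = nn.Linear(8, 8)  # stand-in; real tests pass a TP-sharded module
dist_model.load_state_dict(torch_model.state_dict())

# fp32 params with bf16 grads: one cell of the mixed-precision test matrix.
set_dist_grad(dist_model, torch_model, torch.bfloat16, dist.group.WORLD)

for p, torch_p in zip(dist_model.parameters(), torch_model.parameters()):
    assert p.grad.dtype == torch.bfloat16  # grad dtype is forced; param stays fp32
    assert_close(p.grad, torch_p.grad)  # unsharded params receive identical grads

dist.destroy_process_group()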
@@ -3,7 +3,6 @@
 import pytest
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
@@ -17,7 +16,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo
-from tests.test_optimizer._utils import check_optim_states, run_bert_test
+from tests.test_optimizer._utils import check_optim_states, run_bert_test, set_dist_grad
 
 _ALLOWED_P_G_TYPES = [
     (torch.float, torch.float),  # pure fp32
@@ -109,39 +108,6 @@ def force_assign_grad(p, g_dtype, grad=None):
     p.data = orig_p
 
 
-def set_dist_grad(
-    dist_module: nn.Module,
-    torch_model: nn.Module,
-    g_dtype: torch.dtype,
-    group: dist.ProcessGroup,
-) -> None:
-    """
-    Set grads chunks for Tensor Parallel or ZeRO DP.
-    We do not need a separate treatment for ZeRO,
-    as the LowLevelOptimizer takes care of reduce-scattering grads.
-    """
-    rank = dist.get_rank(group)
-    world_size = dist.get_world_size(group)
-
-    for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
-        if torch_p.grad is None:
-            # avoid inconsistent grad and param dtype error
-            force_assign_grad(torch_p, g_dtype)
-        else:
-            torch_p.grad += torch.randn_like(torch_p, device=torch_p.device, dtype=g_dtype)
-
-        if p.grad is None:
-            force_assign_grad(p, g_dtype)
-
-        if is_distributed_tensor(p):
-            split_dim = get_shard_dim_1d(p)
-            # Add grads only to the correctly split chunk
-            force_assign_grad(p, g_dtype, torch_p.grad.chunk(world_size, dim=split_dim)[rank].contiguous())
-            # assert_close(p.grad, torch_p.grad.chunk(world_size, dim=split_dim)[rank])
-        else:
-            force_assign_grad(p, g_dtype, torch_p.grad)
-
-
 @parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
 @parameterize("tp_zero_size", [(4, 1), (1, 4), (2, 2)])
 def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]) -> None:
@@ -158,7 +124,7 @@ def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_si
 
     dist.get_rank(tp_group)
     seed_all(_SEED)  # Fix model init
-    torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=p_dtype).to(rank)
+    torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, dtype=p_dtype).to(rank)
     tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)
     assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
 
@@ -222,7 +188,7 @@ def run_dist_galore_fwd_bwd(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_
 
     seed_all(_SEED)
     clear_layout_converter()  # Ensure correct sharding
-    torch_model = Net(_IN_DIM, _HID_DIM, identity=True, dtype=p_dtype).to(rank)
+    torch_model = Net(_IN_DIM, _HID_DIM, dtype=p_dtype).to(rank)
     tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)
     assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
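A note for readers of the hunk above: only the tail of force_assign_grad (p.data = orig_p) survives as context. Here is a plausible reconstruction of what that line closes over; everything above the final line is an assumption inferred from the call sites in the removed helper. PyTorch rejects p.grad = g when g.dtype differs from p.dtype, so the helper temporarily views the param in the grad dtype, attaches the grad, then restores the original data.

import torch

def force_assign_grad(p, g_dtype, grad=None):
    # Reconstruction (assumption): only the final line appears in the diff above.
    orig_p = p.data
    p.data = p.data.to(g_dtype)  # autograd demands grad.dtype == param.dtype at assignment
    p.grad = torch.randn_like(p) if grad is None else grad.to(g_dtype)
    p.data = orig_p  # restore; p.grad keeps g_dtype while p returns to its own dtype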
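Likewise, the "Add grads only to the correctly split chunk" line in the removed helper is plain 1-D shard arithmetic: chunk the full gradient along the param's shard dim and keep this rank's slice. A tiny standalone illustration with made-up sizes:

import torch

full_grad = torch.arange(12.0).reshape(4, 3)  # grad of the unsharded param
world_size, rank, split_dim = 2, 1, 0  # illustrative values, not from the tests
local_grad = full_grad.chunk(world_size, dim=split_dim)[rank].contiguous()
print(local_grad.shape)  # torch.Size([2, 3]): rank 1 owns rows 2 and 3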