From ec73f1b5e21b7a6f62d397a38eade2091779633a Mon Sep 17 00:00:00 2001 From: Wenxuan Tan Date: Tue, 11 Feb 2025 23:42:34 -0600 Subject: [PATCH] [CI] Cleanup Dist Optim tests with shared helper funcs (#6125) * Refractor and cleanup using common helper funcs. Tests passed * Update comments * Fix relative import * Fix param fetching bug --- colossalai/shardformer/layer/linear.py | 6 +- tests/kit/model_zoo/custom/simple_mlp.py | 4 +- tests/test_optimizer/_utils.py | 85 ++++++++++++ tests/test_optimizer/test_adam_optim.py | 21 +-- tests/test_optimizer/test_dist_adafactor.py | 118 ++++------------ tests/test_optimizer/test_dist_came.py | 141 +++----------------- tests/test_optimizer/test_dist_galore.py | 40 +----- tests/test_optimizer/test_dist_lamb.py | 25 +--- 8 files changed, 142 insertions(+), 298 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index fe195d698..84ec3f6f0 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -384,7 +384,7 @@ class Linear1D_Row(ParallelModule): out_features (int): size of each output sample. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. - parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + parallel_input (bool): If set to ``True``, it's assumed that the input is already split/copied across each rank, defaults to False. process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None. seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence. @@ -544,14 +544,14 @@ class Linear1D_Row(ParallelModule): if self.parallel_input: assert ( input_.shape[-1] == self.weight.shape[-1] - ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[-1] ) input_ = input_ else: assert ( divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1] - ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions ) input_ = split_forward_gather_backward( diff --git a/tests/kit/model_zoo/custom/simple_mlp.py b/tests/kit/model_zoo/custom/simple_mlp.py index e62369c33..aa85d3ae5 100644 --- a/tests/kit/model_zoo/custom/simple_mlp.py +++ b/tests/kit/model_zoo/custom/simple_mlp.py @@ -13,7 +13,7 @@ _HID_DIM = 128 class Net(nn.Module): - def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=False, dtype=torch.float32): + def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=torch.float32): super().__init__() if identity: self.fc0 = nn.Identity() @@ -30,7 +30,7 @@ class Net(nn.Module): class TPNet(nn.Module): def __init__( self, - fc0=nn.Linear(_IN_DIM, _IN_DIM), + fc0=nn.Identity(), fc1=nn.Linear(_IN_DIM, _HID_DIM), fc2=nn.Linear(_HID_DIM, _IN_DIM), tp_group=None, diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 4046e4118..c780625a1 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -1,10 +1,13 @@ import torch import torch.distributed as dist +import torch.nn as nn from torch.testing import assert_close import colossalai from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor import get_layout, get_sharding_spec, is_distributed_tensor from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.tensor.d_tensor.sharding_spec import DimSpec from colossalai.testing import parameterize, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( @@ -15,6 +18,88 @@ from tests.test_shardformer.test_model._utils import ( ) +def force_assign_grad(p, g_dtype, grad=None): + """Bypass inconsistent grad and param dtype error when assigning grad""" + orig_p = p.data + p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad == None else grad.clone().to(g_dtype) + p.grad = p.data + p.data = orig_p + + +def setup_param_groups(model: nn.Module) -> list: + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": 0.1, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +# setup flatten param groups, sharding spec and shape; (For dist Adafactor and CAME) +def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: + flatten_optimizer_grouped_parameters = [] + sharding_spec = {} # {id(flatten param): get_layout(p).global_shape} + param_shape = {} # {id(flatten param): get_sharding_spec(p)} + for n, p in model.named_parameters(): + # flatten_p = copy.deepcopy(p).flatten() + flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True)) + flatten_optimizer_grouped_parameters.append(flatten_p) + if is_distributed_tensor(p): + sharding_spec[id(flatten_p)] = get_sharding_spec(p) + param_shape[id(flatten_p)] = get_layout(p).global_shape + else: + sharding_spec[id(flatten_p)] = None + param_shape[id(flatten_p)] = p.shape + return flatten_optimizer_grouped_parameters, sharding_spec, param_shape + + +def set_master_param_to_shard_param(master_param_list) -> dict: + master_param_to_shard_param = {id(p): p for p in master_param_list} + return master_param_to_shard_param + + +def set_dist_grad( + dist_module: nn.Module, + torch_model: nn.Module, + g_dtype: torch.dtype, + group: dist.ProcessGroup, + tp_spec: DimSpec, +) -> None: + """ + Set split grads for Tensor Parallel or ZeRO DP. + We do not need a separate treatment for ZeRO, + as the wrapper takes care of reduce-scattering grads. + """ + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()): + if torch_p.grad is None: + torch_p.grad = torch.zeros_like(torch_p) + + is_distributed = hasattr(p, "dist_layout") + if is_distributed: + sharding = p.dist_layout.sharding_spec.sharding_sequence + split_dim = sharding.index(tp_spec) + shape = torch_p.split(world_size, dim=split_dim)[rank].shape + + indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1)) + # Generate grads only for the correctly split chunk + torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype)) + + else: + shape = torch_p.shape + torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype) + + force_assign_grad(p, g_dtype, grad=torch_p.grad) + + def check_optim_states(org_optim, sharded_optim): for group in org_optim.param_groups: for p in group["params"]: diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py index 68d71e3c4..147f1f7b9 100644 --- a/tests/test_optimizer/test_adam_optim.py +++ b/tests/test_optimizer/test_adam_optim.py @@ -8,6 +8,7 @@ from torch.optim import Adam, AdamW from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam from tests.kit.model_zoo import model_zoo +from tests.test_optimizer._utils import force_assign_grad, setup_param_groups _ALLOWED_OPTIM_DEVICES = [ (FusedAdam, torch.device("cuda:0")), @@ -26,29 +27,11 @@ _ALLOWED_P_G_TYPES = [ N_STEPS = 3 -def setup_param_groups(bert_model: nn.Module) -> list: - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": 0.1, - }, - { - "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - return optimizer_grouped_parameters - - def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None: for p, torch_p in zip(model.parameters(), torch_model.parameters()): torch_p.grad = torch.rand_like(torch_p) # avoid inconsistent grad and param dtype error - orig_p = p.data - p.data = torch_p.grad.clone().to(g_dtype) - p.grad = p.data - p.data = orig_p + force_assign_grad(p, g_dtype, torch_p.grad) @pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 2da679d7d..ad93b5310 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -1,5 +1,3 @@ -import copy - import pytest import torch import torch.distributed as dist @@ -16,7 +14,6 @@ from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor import ( distribute_tensor, get_device_mesh, - get_layout, get_sharding_spec, is_distributed_tensor, shard_colwise, @@ -28,7 +25,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import check_dist_optim_state, check_dist_param, check_optim_states +from tests.test_optimizer._utils import ( + check_dist_optim_state, + check_dist_param, + check_optim_states, + set_master_param_to_shard_param, + setup_param_groups, +) from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, build_model_from_low_level_zero_plugin, @@ -38,10 +41,13 @@ from tests.test_shardformer.test_model._utils import ( unwrap_model, ) -HEIGHT = 4 -WIDTH = 4 +IN_DIM = 4 +HID_DIM = 4 _TP_SPEC = DimSpec([0]) +Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values())) +TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values())) + def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32): rtol = None @@ -59,92 +65,11 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc assert_close(tensor1, tensor2, rtol=rtol, atol=atol) -# setup param groups; (For zero test optim) -def setup_param_groups_zero(model: nn.Module) -> list: - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": 0.1, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - return optimizer_grouped_parameters - - -# setup param groups; (For base optim) -def setup_param_groups(model: nn.Module) -> list: - optimizer_grouped_parameters = [p for n, p in model.named_parameters()] - return optimizer_grouped_parameters - - -# setup flatten param groups, sharding spec and shape; (For dist optim) -def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: - flatten_optimizer_grouped_parameters = [] - sharding_spec = {} # {id(flatten param): get_layout(p).global_shape} - param_shape = {} # {id(flatten param): get_sharding_spec(p)} - for n, p in model.named_parameters(): - # flatten_p = copy.deepcopy(p).flatten() - flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True)) - flatten_optimizer_grouped_parameters.append(flatten_p) - if is_distributed_tensor(p): - sharding_spec[id(flatten_p)] = get_sharding_spec(p) - param_shape[id(flatten_p)] = get_layout(p).global_shape - else: - sharding_spec[id(flatten_p)] = None - param_shape[id(flatten_p)] = p.shape - return flatten_optimizer_grouped_parameters, sharding_spec, param_shape - - -def set_dist_grad( - dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup -) -> None: - """ - Set split grads for Tensor Parallel or ZeRO DP. - We do not need a separate treatment for ZeRO, - as the wrapper takes care of reduce-scattering grads. - """ - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - - for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()): - if torch_p.grad is None: - torch_p.grad = torch.zeros_like(torch_p) - - is_distributed = hasattr(p, "dist_layout") - if is_distributed: - sharding = p.dist_layout.sharding_spec.sharding_sequence - split_dim = sharding.index(_TP_SPEC) - shape = torch_p.split(world_size, dim=split_dim)[rank].shape - - indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1)) - # Generate grads only for the correctly split chunk - torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype)) - - else: - shape = torch_p.shape - torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype) - - # avoid inconsistent grad and param dtype error - orig_p = p.data - p.data = torch_p.grad.clone().to(g_dtype) - p.grad = p.data - p.data = orig_p - - -def set_master_param_to_shard_param(master_param_list) -> dict: - master_param_to_shard_param = {id(p): p for p in master_param_list} - return master_param_to_shard_param - - class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() - self.linear1 = nn.Linear(HEIGHT, WIDTH) - self.linear2 = nn.Linear(WIDTH, HEIGHT) + self.linear1 = nn.Linear(IN_DIM, HID_DIM) + self.linear2 = nn.Linear(HID_DIM, IN_DIM) def forward(self, x): x = self.linear1(x) @@ -182,7 +107,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # ============================== # Base Case # ============================== - H, W = HEIGHT, WIDTH + H, W = IN_DIM, HID_DIM model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight weight, bias = model_col.weight, model_col.bias @@ -284,8 +209,11 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # ============================== # Model Init # ============================== - base_model = MlpModel().to(local_rank) - tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + # base_model = MlpModel().to(local_rank) + # tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + base_model = Net(in_dim=IN_DIM, hid_dim=HID_DIM, dtype=dtype).to(local_rank) + # Must specify dtype; TPNet init seem to run out of set_default_dtype scope + tp_model = TPNet(fc1=base_model.fc1, fc2=base_model.fc2, tp_group=tp_group, dtype=dtype) base_param_group = setup_param_groups(base_model) tp_param_group = setup_param_groups(tp_model) @@ -335,7 +263,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # ============================== # Correctness Verify # ============================== - x = torch.randn(HEIGHT, WIDTH, device=local_rank) + x = torch.randn(IN_DIM, HID_DIM, device=local_rank) out = base_model(x) out_tp = tp_model(x) @@ -353,7 +281,9 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): base_optim.zero_grad() dist_optim.zero_grad() - for p, tp_p in zip(base_param_group, tp_param_group): + base_params = base_model.parameters() + tp_params = tp_model.parameters() + for p, tp_p in zip(base_params, tp_params): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: shard_spec = get_sharding_spec(tp_p) diff --git a/tests/test_optimizer/test_dist_came.py b/tests/test_optimizer/test_dist_came.py index 45fe687b7..d662bc674 100644 --- a/tests/test_optimizer/test_dist_came.py +++ b/tests/test_optimizer/test_dist_came.py @@ -1,9 +1,6 @@ -import copy - import pytest import torch import torch.distributed as dist -from torch import nn from torch.testing import assert_close import colossalai @@ -11,17 +8,23 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer.came import CAME from colossalai.nn.optimizer.distributed_came import DistributedCAME -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.shardformer.layer._operation import _gather from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.d_tensor import get_layout, get_sharding_spec, is_distributed_tensor +from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.tensor.d_tensor.sharding_spec import DimSpec from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import check_dist_grad, check_dist_optim_state, check_dist_param, check_optim_states +from tests.test_optimizer._utils import ( + check_dist_grad, + check_dist_optim_state, + check_dist_param, + check_optim_states, + set_master_param_to_shard_param, + setup_param_groups, +) from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, build_model_from_low_level_zero_plugin, @@ -30,10 +33,12 @@ from tests.test_shardformer.test_model._utils import ( unwrap_model, ) -HEIGHT = 128 -WIDTH = 128 +IN_DIM = 128 +HID_DIM = 128 _TP_SPEC = DimSpec([0]) _SEED = 0 +Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values())) +TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values())) def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32): @@ -53,112 +58,6 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc assert_close(tensor1, tensor2, rtol=rtol, atol=atol) -# setup param groups; (For zero test optim) -def setup_param_groups_zero(model: nn.Module) -> list: - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": 0.1, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - return optimizer_grouped_parameters - - -# setup param groups; (For base optim) -def setup_param_groups(model: nn.Module) -> list: - optimizer_grouped_parameters = [p for n, p in model.named_parameters()] - return optimizer_grouped_parameters - - -# setup flatten param groups, sharding spec and shape; (For dist optim) -def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: - flatten_optimizer_grouped_parameters = [] - sharding_spec = {} # {id(flatten param): get_layout(p).global_shape} - param_shape = {} # {id(flatten param): get_sharding_spec(p)} - for n, p in model.named_parameters(): - flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True)) - flatten_optimizer_grouped_parameters.append(flatten_p) - if is_distributed_tensor(p): - sharding_spec[id(flatten_p)] = get_sharding_spec(p) - param_shape[id(flatten_p)] = get_layout(p).global_shape - else: - sharding_spec[id(flatten_p)] = None - param_shape[id(flatten_p)] = p.shape - return flatten_optimizer_grouped_parameters, sharding_spec, param_shape - - -def set_dist_grad( - dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup -) -> None: - """ - Set split grads for Tensor Parallel or ZeRO DP. - We do not need a separate treatment for ZeRO, - as the wrapper takes care of reduce-scattering grads. - """ - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - - for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()): - if torch_p.grad is None: - torch_p.grad = torch.zeros_like(torch_p) - - is_distributed = hasattr(p, "dist_layout") - if is_distributed: - sharding = p.dist_layout.sharding_spec.sharding_sequence - split_dim = sharding.index(_TP_SPEC) - shape = torch_p.split(world_size, dim=split_dim)[rank].shape - - indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1)) - # Generate grads only for the correctly split chunk - torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype)) - - else: - shape = torch_p.shape - torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype) - - # avoid inconsistent grad and param dtype error - orig_p = p.data - p.data = torch_p.grad.clone().to(g_dtype) - p.grad = p.data - p.data = orig_p - - -def set_master_param_to_shard_param(master_param_list) -> dict: - master_param_to_shard_param = {id(p): p for p in master_param_list} - return master_param_to_shard_param - - -class MlpModel(nn.Module): - def __init__(self): - super(MlpModel, self).__init__() - self.linear1 = nn.Linear(HEIGHT, WIDTH) - self.linear2 = nn.Linear(WIDTH, HEIGHT) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - - -class TPModel(nn.Module): - def __init__(self, linear1, linear2, tp_group=None): - super().__init__() - self.linear1 = Linear1D_Col.from_native_module( - linear1, process_group=tp_group, gather_output=False, overlap=True - ) - self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - - @parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (4, 1), (1, 4) def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -177,12 +76,13 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # ============================== # Model Init # ============================== - base_model = MlpModel().to(local_rank) - tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + base_model = Net(in_dim=IN_DIM, hid_dim=HID_DIM, dtype=dtype).to(local_rank) + # tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + tp_model = TPNet(fc1=base_model.fc1, fc2=base_model.fc2, tp_group=tp_group, dtype=dtype) base_param_group = setup_param_groups(base_model) tp_param_group = setup_param_groups(tp_model) - tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + # tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) # ============================== # Optimizer Init @@ -220,7 +120,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # Correctness Verify # ============================== seed_all(1024) - x = torch.randn(WIDTH, HEIGHT, device=local_rank) + x = torch.randn(HID_DIM, IN_DIM, device=local_rank) out = base_model(x) out_tp = tp_model(x) @@ -238,7 +138,9 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): base_optim.zero_grad() dist_optim.zero_grad() - for p, tp_p in zip(base_param_group, tp_param_group): + base_params = base_model.parameters() + tp_params = tp_model.parameters() + for p, tp_p in zip(base_params, tp_params): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: shard_spec = get_sharding_spec(tp_p) @@ -256,6 +158,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # No TP bias pass correctness_verify(p.data, tp_p.data, dtype) + clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_optimizer/test_dist_galore.py b/tests/test_optimizer/test_dist_galore.py index 95193accb..8a388eaa7 100644 --- a/tests/test_optimizer/test_dist_galore.py +++ b/tests/test_optimizer/test_dist_galore.py @@ -3,7 +3,6 @@ import pytest import torch import torch.distributed as dist -import torch.nn as nn from torch.testing import assert_close import colossalai @@ -17,7 +16,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import check_optim_states, run_bert_test +from tests.test_optimizer._utils import check_optim_states, run_bert_test, set_dist_grad _ALLOWED_P_G_TYPES = [ (torch.float, torch.float), # pure fp32 @@ -109,39 +108,6 @@ def force_assign_grad(p, g_dtype, grad=None): p.data = orig_p -def set_dist_grad( - dist_module: nn.Module, - torch_model: nn.Module, - g_dtype: torch.dtype, - group: dist.ProcessGroup, -) -> None: - """ - Set grads chunks for Tensor Parallel or ZeRO DP. - We do not need a separate treatment for ZeRO, - as the LowLevelOptimizer takes care of reduce-scattering grads. - """ - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - - for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()): - if torch_p.grad is None: - # avoid inconsistent grad and param dtype error - force_assign_grad(torch_p, g_dtype) - else: - torch_p.grad += torch.randn_like(torch_p, device=torch_p.device, dtype=g_dtype) - - if p.grad is None: - force_assign_grad(p, g_dtype) - - if is_distributed_tensor(p): - split_dim = get_shard_dim_1d(p) - # Add grads only to the correctly split chunk - force_assign_grad(p, g_dtype, torch_p.grad.chunk(world_size, dim=split_dim)[rank].contiguous()) - # assert_close(p.grad, torch_p.grad.chunk(world_size, dim=split_dim)[rank]) - else: - force_assign_grad(p, g_dtype, torch_p.grad) - - @parameterize("p_g_dtype", _ALLOWED_P_G_TYPES) @parameterize("tp_zero_size", [(4, 1), (1, 4), (2, 2)]) def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]) -> None: @@ -158,7 +124,7 @@ def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_si dist.get_rank(tp_group) seed_all(_SEED) # Fix model init - torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=p_dtype).to(rank) + torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, dtype=p_dtype).to(rank) tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank) assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group) @@ -222,7 +188,7 @@ def run_dist_galore_fwd_bwd(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_ seed_all(_SEED) clear_layout_converter() # Ensure correct sharding - torch_model = Net(_IN_DIM, _HID_DIM, identity=True, dtype=p_dtype).to(rank) + torch_model = Net(_IN_DIM, _HID_DIM, dtype=p_dtype).to(rank) tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank) assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group) diff --git a/tests/test_optimizer/test_dist_lamb.py b/tests/test_optimizer/test_dist_lamb.py index 615c5c33c..390eb9642 100644 --- a/tests/test_optimizer/test_dist_lamb.py +++ b/tests/test_optimizer/test_dist_lamb.py @@ -14,7 +14,7 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import check_optim_states, run_bert_test +from tests.test_optimizer._utils import check_optim_states, force_assign_grad, run_bert_test, setup_param_groups _ALLOWED_P_G_TYPES = [ (torch.float, torch.float), # pure fp32 @@ -49,29 +49,6 @@ def assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group): raise e -def setup_param_groups(bert_model: nn.Module) -> list: - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": 0.1, - }, - { - "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - return optimizer_grouped_parameters - - -def force_assign_grad(p, g_dtype, grad=None): - """avoid inconsistent grad and param dtype error""" - orig_p = p.data - p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad == None else grad - p.grad = p.data - p.data = orig_p - - def set_dist_grad( dist_module: nn.Module, torch_model: nn.Module,