diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 35a88d1e8..a22bdb719 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -37,7 +37,8 @@ class HybridParallelModule(ModelWrapper):
         self.shared_param_process_groups = []
         for shared_param in self.shared_params:
             if len(shared_param) > 0:
-                self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
+                self.shared_param_process_groups.append(
+                    self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
         if precision == 'fp16':
             module = module.half().cuda()
         elif precision == 'bf16':
diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py
index a704310e1..73c210221 100644
--- a/tests/kit/model_zoo/transformers/gpt.py
+++ b/tests/kit/model_zoo/transformers/gpt.py
@@ -72,7 +72,9 @@ config = transformers.GPT2Config(n_layer=2,
                                  embd_pdrop=0,
                                  resid_pdrop=0,
                                  summary_first_dropout=0,
-                                 hidden_dropout=0)
+                                 hidden_dropout=0,
+                                 problem_type="single_label_classification",
+                                 pad_token_id=50256)
 
 config_for_token_classification = copy.deepcopy(config)
 config_for_token_classification.num_labels = 2
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index e15295bc9..46b262d0a 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -1,11 +1,19 @@
 import copy
 from contextlib import nullcontext
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.distributed as dist
+from torch import Tensor
+from torch import distributed as dist
+from torch.distributed import ProcessGroup
 from torch.nn import Module
+from torch.optim import Adam, Optimizer
 
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.lazy import LazyInitContext
+from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer._utils import getattr_
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
@@ -79,20 +87,151 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
         assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
 
 
-def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
+def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
+
+    use_lazy_init = False
+    if 'use_lazy_init' in test_config:
+        use_lazy_init = test_config.pop('use_lazy_init')
+
+    if use_lazy_init:
+        ctx = LazyInitContext()
+    else:
+        ctx = nullcontext()
+
+    plugin = HybridParallelPlugin(**test_config)
+    booster = Booster(plugin=plugin)
+
+    with ctx:
+        org_model = model_fn().cuda()
+        sharded_model = copy.deepcopy(org_model)
+
+    if use_lazy_init:
+        org_model = ctx.materialize(org_model)
+
+    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
+    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
+    criterion = loss_fn
+
+    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
+
+    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
+
+
+def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
+                                            data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
+                                            booster: Booster):
+
+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+
+    data = data_gen_fn()
+    sharded_model.train()
+    if booster.plugin.stage_manager is not None:
+        data = {
+            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
+            for k, v in data.items()
+        }
+        data_iter = iter([data])
+        sharded_output = booster.execute_pipeline(data_iter,
+                                                  sharded_model,
+                                                  _criterion,
+                                                  sharded_optimizer,
+                                                  return_loss=True,
+                                                  return_outputs=True)
+        sharded_loss = sharded_output['loss']
+    else:
+        data = {k: v.cuda() for k, v in data.items()}
+        sharded_output = sharded_model(**data)
+        sharded_loss = criterion(sharded_output)
+        sharded_loss.backward()
+
+    org_model.train()
+    org_output = org_model(**data)
+    org_loss = criterion(org_output)
+    org_loss.backward()
+
+    return org_loss, org_output, sharded_loss, sharded_output
+
+
+def check_output_hidden_state(org_output: Tensor,
+                              sharded_output: Tensor,
+                              stage_manager: Optional[PipelineStageManager] = None,
+                              atol: float = 1e-5,
+                              rtol: float = 1e-3):
+
+    org_hidden_state = org_output.last_hidden_state
+
+    if stage_manager is None:
+        sharded_hidden_state = sharded_output.last_hidden_state
+
+    if stage_manager and stage_manager.is_last_stage():
+        sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
+
+    assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
+        f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
+
+
+def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
+    assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
+        f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
+
+
+def check_weight(org_model: Module,
+                 sharded_model: Module,
+                 layer_suffix: List[str],
+                 tp_group: Optional[ProcessGroup] = None,
+                 dim: int = 0,
+                 atol: float = 1e-5,
+                 rtol: float = 1e-3,
+                 verbose: bool = False):
 
     for suffix in layer_suffix:
-        org_grad = getattr_(original_model, suffix).weight.grad
+        org_weight = getattr_(org_model, suffix).weight
+        sharded_weight = getattr_(sharded_model, suffix).weight
+
+        if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
+            sharded_weight_list = [
+                torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+            ]
+            dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
+            sharded_weight = torch.cat(sharded_weight_list, dim=dim)
+
+        if verbose and dist.get_rank() == 0:
+            print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
+
+        assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
+            f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
+
+
+def check_grad(org_model: Module,
+               sharded_model: Module,
+               layer_suffix: List[str],
+               tp_group: ProcessGroup = None,
+               dim: int = 0,
+               atol: float = 1e-5,
+               rtol: float = 1e-3,
+               verbose: bool = False):
+
+    for suffix in layer_suffix:
+        org_grad = getattr_(org_model, suffix).weight.grad
         shard_grad = getattr_(sharded_model, suffix).weight.grad
         shard_weight = getattr_(sharded_model, suffix).weight
 
         if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-            shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())]
-            shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-            all_shard_grad = torch.cat(shard_grad_list, dim=dim)
-        else:
-            all_shard_grad = shard_grad
+            shard_grad_list = [
+                torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+            ]
+            dist.all_gather(shard_grad_list, shard_grad, tp_group)
+            shard_grad = torch.cat(shard_grad_list, dim=dim)
+
+        # embedding may be resized when using tensor parallel
+        if shard_grad.shape[0] > org_grad.shape[0]:
+            shard_grad = shard_grad[:org_grad.shape[0], :]
+
         if verbose and dist.get_rank() == 0:
-            print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
+            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
 
         assert torch.allclose(
-            org_grad, all_shard_grad, rtol=rtol, atol=atol
-        ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"
+            org_grad, shard_grad, rtol=rtol, atol=atol
+        ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index d1ab352f6..cebb40bd1 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -1,107 +1,48 @@
-import copy
-from contextlib import nullcontext
-
 import pytest
 import torch
 from torch import distributed as dist
-from torch.optim import Adam
 
 import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import HybridParallelPlugin
-from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.d_tensor.api import (
-    clear_layout_converter,
-    is_customized_distributed_tensor,
-    is_distributed_tensor,
-)
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_grad,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+)
 
 
 def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
-    use_lazy_init = False
-    if 'use_lazy_init' in test_config:
-        use_lazy_init = test_config.pop('use_lazy_init')
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
 
-    if use_lazy_init:
-        ctx = LazyInitContext()
-    else:
-        ctx = nullcontext()
-    # prepare booster
-    plugin = HybridParallelPlugin(**test_config)
-    booster = Booster(plugin=plugin)
-    stage_manager = plugin.stage_manager
+    org_loss, org_output, sharded_loss, sharded_output = \
+        run_forward_backward_with_hybrid_plugin(
+            org_model,
+            sharded_model,
+            sharded_optimizer,
+            data_gen_fn,
+            output_transform_fn,
+            criterion,
+            booster)
 
-    # prepare models and optimizers
-    with ctx:
-        org_model = model_fn().cuda()
-        sharded_model = copy.deepcopy(org_model)
-
-    if use_lazy_init:
-        org_model = ctx.materialize(org_model)
-
-    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
-    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
-    criterion = loss_fn
-
-    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
-
-    def _criterion(outputs, inputs):
-        outputs = output_transform_fn(outputs)
-        loss = criterion(outputs)
-        return loss
-
-    # do forward and backward
-    data = data_gen_fn()
-    sharded_model.train()
-    if stage_manager:
-        data = {
-            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
-            for k, v in data.items()
-        }
-        data_iter = iter([data])
-        sharded_output = booster.execute_pipeline(data_iter,
-                                                  sharded_model,
-                                                  _criterion,
-                                                  sharded_optimizer,
-                                                  return_loss=True,
-                                                  return_outputs=True)
-        sharded_loss = sharded_output['loss']
-    else:
-        data = {k: v.cuda() for k, v in data.items()}
-        sharded_output = sharded_model(**data)
-        sharded_loss = criterion(sharded_output)
-        sharded_loss.backward()
-
-    org_model.train()
-    org_output = org_model(**data)
-    org_loss = criterion(org_output)
-    org_loss.backward()
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
 
+    # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
-        # check last hidden state
         if org_model.__class__.__name__ == 'GPT2Model':
-            org_hidden_state = org_output.last_hidden_state
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
 
-            if stage_manager is None:
-                sharded_hidden_state = sharded_output.last_hidden_state
-
-            if stage_manager and stage_manager.is_last_stage():
-                sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
-                                                 dim=0)
-
-            assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
-                f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
-
-        # check loss
-        assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
-            f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
+        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
 
     # unwrap model
     if org_model.__class__.__name__ == 'GPT2Model':
@@ -111,27 +52,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         gpt2 = org_model.transformer
         sharded_gpt2 = sharded_model.unwrap().transformer
 
-    # check grad
     col_layer_for_check = ['h[0].mlp.c_fc']
-    row_layer_for_check = ['h[0].mlp.c_proj']
-    check_grad(gpt2, sharded_gpt2, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
-    check_grad(gpt2, sharded_gpt2, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
+    row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
+
+    # check grad
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
+        check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
 
     # check weights after optimizer.step()
     org_optimizer.step()
     sharded_optimizer.step()
     if stage_manager is None or stage_manager.is_first_stage():
-
-        org_weight = org_model.h[0].mlp.c_fc.weight
-        shard_weight = sharded_model.h[0].mlp.c_fc.weight
-
-        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-            shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
-            dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
-            shard_weight = torch.cat(shard_weight_list, dim=1)
-
-        assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
-            f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"
+        check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
 
     torch.cuda.empty_cache()
 
@@ -156,9 +89,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @clear_cache_before_run()
 def run_gpt2_test(test_config):
 
-    # TODO: add plugin_config for TP+DP after supporting & debugging it
+    # TODO: add test_config for TP+DP after supporting & debugging it
     # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
 
+    # TODO: add test_config for flash attention & jit operator after supporting
+
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing
 
@@ -175,7 +110,6 @@ def check_gpt2(rank, world_size, port):
     run_gpt2_test()
 
 
-@pytest.mark.skip('Have some bug caused by merge')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
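For reference, the refactored helpers in tests/test_shardformer/test_model/_utils.py are intended to be composed per model roughly as sketched below. This is a minimal, illustrative sketch based on the check_forward_backward function in test_shard_gpt2.py from the patch above, not part of the patch itself; the plugin settings are taken from the TODO comment in run_gpt2_test, and an actual run additionally needs colossalai.launch inside a spawned distributed worker, which the unchanged parts of the test file handle.

# Illustrative only -- mirrors check_forward_backward() from test_shard_gpt2.py above.
# Assumes torch.distributed has already been initialized via colossalai.launch in a spawned
# worker, and that model_fn/data_gen_fn/output_transform_fn/loss_fn come from a model-zoo
# registry entry such as 'transformers_gpt'.
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_grad,
    check_loss,
    check_weight,
    run_forward_backward_with_hybrid_plugin,
)


def run_one_model(model_fn, data_gen_fn, output_transform_fn, loss_fn):
    # plugin config mirroring the TODO comment in run_gpt2_test; 'precision': 'float'
    # matches the test's choice to keep fp16/bf16 noise out of the comparison
    test_config = {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True, 'precision': 'float'}

    # original model plus a booster-wrapped (sharded) deep copy, each with its own Adam
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)

    # one forward/backward pass on both models (pipeline-aware when pp_size > 1)
    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster)

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # the loss is only available on the last pipeline stage (or everywhere when pp_size == 1)
    if stage_manager is None or stage_manager.is_last_stage():
        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)

    # layer suffixes below assume a plain GPT2Model; LM variants drill into .transformer
    # first, exactly as check_forward_backward does above
    unwrapped = sharded_model.unwrap()
    if stage_manager is None or stage_manager.is_first_stage():
        # 'wte' exercises the resized-embedding path handled inside check_grad
        check_grad(org_model, unwrapped, ['wte'], tp_group, dim=0, atol=1e-4, rtol=1e-3)

    # weights are compared only after stepping both optimizers
    org_optimizer.step()
    sharded_optimizer.step()
    if stage_manager is None or stage_manager.is_first_stage():
        check_weight(org_model, unwrapped, ['h[0].mlp.c_fc'], tp_group, dim=1, atol=5e-3, rtol=1e-3)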