mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446)
* support DDP for HybridPlugin/add tp+dp tests * add docstring for HybridParallelPlugin
This commit is contained in:
@@ -3,7 +3,6 @@ import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
@@ -15,6 +14,7 @@ from tests.test_shardformer.test_model._utils import (
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,16 +48,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
def unwrap(module):
|
||||
if isinstance(module, HybridParallelModule):
|
||||
module = module.unwrap()
|
||||
if module.__class__.__name__ == 'GPT2Model':
|
||||
return module
|
||||
return module.transformer
|
||||
|
||||
# unwrap model
|
||||
gpt2 = unwrap(org_model)
|
||||
sharded_gpt2 = unwrap(sharded_model)
|
||||
gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
|
||||
sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
|
||||
|
||||
col_layer_for_check = ['h[0].mlp.c_fc']
|
||||
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
|
||||
@@ -106,6 +99,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
@@ -117,8 +116,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_test(test_config):
|
||||
|
||||
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
Reference in New Issue
Block a user