[shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446)

* support DDP for HybridPlugin/add tp+dp tests

* add docstring for HybridParallelPlugin
This commit is contained in:
Baizhou Zhang
2023-08-16 16:11:57 +08:00
committed by GitHub
parent 424629fea0
commit 6ef33f75aa
10 changed files with 199 additions and 100 deletions

View File

@@ -1,5 +1,6 @@
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers
@@ -14,6 +15,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@@ -48,8 +50,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
t5 = org_model
sharded_t5 = sharded_model.unwrap()
t5 = unwrap_model(org_model)
sharded_t5 = unwrap_model(sharded_model)
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
@@ -99,17 +101,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}])
@clear_cache_before_run()
def run_t5_test(test_config):
# TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO(baizhou): add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():