[pipeline] rewrite t5 tests & support multi-tensor transmitting in pipeline (#4388)

* fix remaining t5 bugs/rewrite t5 tests

* fix multi-tensor communication in pipeline

* rearrange test_config

* fix KeyError in sync_shared_params

* fix get_held_layers & Randomizer, complete t5 tests

* remove leftover debug printing

* fix get_held_layers by modifying _release_unheld_layers

* fix _get_recursive_held_layers bug
Author: Baizhou Zhang
Date: 2023-08-08 17:46:44 +08:00
Committed by: Hongxin Liu
Parent: 906426cb44
Commit: ed4c448488
11 changed files with 196 additions and 246 deletions


@@ -68,16 +68,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     torch.cuda.empty_cache()


 @parameterize('test_config', [{
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
     'use_lazy_init': True
 }, {
     'tp_size': 2,
     'pp_size': 2,
     'num_microbatches': 4,
-    'enable_fused_normalization': False,
+    'enable_fused_normalization': True,
     'use_lazy_init': True
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
     'use_lazy_init': False
 }, {
     'tp_size': 4,


@@ -1,60 +1,110 @@
 import os

 import pytest
 import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_grad,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+)


-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # check forward
-    # the value "past_key_values" is sharded, so we ignore
-    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
-                                                                 output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], atol=1e-5)
-
-    # do backward
-    org_loss.backward()
-    shard_loss.backward()
-
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
-
-    # check grad
-    col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
-    row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
-    check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
-    check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
-
-    # check weights are tied
-    if hasattr(org_model, 'lm_head'):
-        assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr()
-        assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+    org_loss, org_output, sharded_loss, sharded_output = \
+        run_forward_backward_with_hybrid_plugin(
+            org_model,
+            sharded_model,
+            sharded_optimizer,
+            data_gen_fn,
+            output_transform_fn,
+            criterion,
+            booster)
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+
+    # unwrap model
+    t5 = org_model
+    sharded_t5 = sharded_model.unwrap()
+
+    row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
+
+    # check weights and gradients
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0)
+
+    # check weights after optimizer.step()
+    org_optimizer.step()
+    sharded_optimizer.step()
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+
+    torch.cuda.empty_cache()


-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('use_lazy_init', [False, True])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
-                enable_jit_fused):
+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_fused_normalization': True,
+    'use_lazy_init': True
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'use_lazy_init': False
+}, {
+    'tp_size': 4,
+    'pp_size': 1,
+    'enable_fused_normalization': True,
+    'use_lazy_init': False
+}, {
+    'tp_size': 1,
+    'pp_size': 4,
+    'num_microbatches': 4,
+    'use_lazy_init': False
+}])
+@clear_cache_before_run()
+def run_t5_test(test_config):
+
+    # TODO: add plugin_config for TP+DP after supporting & debugging it
+    # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
+
+    # TODO: add test_config for flash attention & jit operator after supporting
+
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+    test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
-        check_state_dict(org_model, sharded_model, name=name)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+        # skip 4-stage pp test for t5_encoder
+        if test_config['pp_size'] > 2 and name == 'transformers_t5_encoder_model':
+            continue
+
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
@@ -68,7 +118,7 @@ def check_t5(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_t5():
-    spawn(check_t5, 2)
+    spawn(check_t5, 4)


 if __name__ == "__main__":
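The rewritten test drives everything through a test_config dict instead of a stack of boolean @parameterize flags. Roughly, build_model_from_hybrid_plugin turns such a dict into a HybridParallelPlugin plus a Booster; the snippet below is a hedged sketch of that mapping (HybridParallelPlugin and Booster are real colossalai APIs, but the helper's exact handling of keys such as use_lazy_init and precision is an assumption here).

# Sketch, assuming colossalai.launch() has already set up the distributed
# environment (as check_t5 does before calling run_t5_test).
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

test_config = {
    'tp_size': 2,                        # tensor-parallel group size
    'pp_size': 2,                        # pipeline-parallel group size
    'num_microbatches': 2,               # microbatches per pipeline step
    'enable_fused_normalization': True,  # use fused normalization kernels
}

plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, *_ = booster.boost(model, optimizer, criterion=loss_fn)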


@@ -1,101 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.t5 import T5BasePolicy
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_pipeline_model
-
-
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # TODO: add tests for forward/backward later
-    pass
-
-
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('enable_fused_normalization', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_t5.py
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        inputs = data_gen_fn()
-        inputs = {k: v.cuda() for k, v in inputs.items()}
-        input_ids = inputs['input_ids']
-
-        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                enable_tensor_parallelism, use_lazy_init)
-
-        batch_size, seq_len = input_ids.shape
-        hidden_size = sharded_model.config.d_model
-        num_heads = sharded_model.config.num_heads
-        hidden_state_shape = (batch_size, seq_len, hidden_size)
-        position_bias_shape = (batch_size, num_heads, seq_len, seq_len)
-
-        num_encoder_layers = len(sharded_model.encoder.block)
-        decoder = sharded_model.__dict__.get('decoder', None)
-        num_decoder_layers = len(decoder.block) if decoder else 0
-
-        _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE)
-        stage = stage_manager.stage
-        at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
-        at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
-        in_decoder = stage >= decoder_starting_stage
-
-        if not at_first_stage:
-            # change inputs if not the first stage
-            hidden_states = torch.zeros(*hidden_state_shape).cuda()
-            position_bias = torch.zeros(*position_bias_shape).cuda()
-            encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda()
-            inputs['input_ids'] = None
-            inputs['hidden_states'] = hidden_states
-            inputs['position_bias'] = position_bias
-            inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
-        if in_decoder:
-            encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
-            inputs['encoder_outputs'] = (encoder_output_states,)
-
-        sharded_model.train()
-        output = sharded_model(**inputs)
-        if at_last_stage:
-            if name == 'transformers_t5_for_conditional_generation' and in_decoder:
-                assert output.loss is not None
-            else:
-                if name != 'transformers_t5_encoder_model' and not in_decoder:
-                    output = output['encoder_outputs']
-                assert output[0].shape == hidden_state_shape
-        else:
-            assert output['hidden_states'].shape == hidden_state_shape
-            # position_bias information should be passed in T5
-            assert output['position_bias'].shape == position_bias_shape
-            if in_decoder:
-                assert output['encoder_decoder_position_bias'].shape == position_bias_shape
-
-        torch.cuda.empty_cache()
-
-
-def check_t5(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_t5_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_t5():
-    spawn(check_t5, 4)
-
-
-if __name__ == "__main__":
-    test_t5()
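For reference, the deleted test relied on T5BasePolicy.distribute_t5_layers to work out which pipeline stage the decoder starts on (decoder_starting_stage), which drives its at_first_stage/at_last_stage/in_decoder checks. A toy stand-in is sketched below to show what that return value means; it is a simple proportional split, not necessarily the policy's actual algorithm.

def toy_distribute_t5_layers(num_encoder_layers, num_decoder_layers, num_stages):
    # Toy stand-in for T5BasePolicy.distribute_t5_layers: split encoder and
    # decoder blocks over the pipeline stages and report where the decoder starts.
    total = num_encoder_layers + num_decoder_layers
    if num_decoder_layers == 0:
        encoder_stages = num_stages                        # encoder-only model
    else:
        encoder_stages = max(1, min(num_stages - 1,
                                    round(num_stages * num_encoder_layers / total)))
    decoder_stages = num_stages - encoder_stages

    def split(n_layers, n_stages):
        base, rem = divmod(n_layers, n_stages)
        return [base + (i < rem) for i in range(n_stages)]

    layer_distribution = split(num_encoder_layers, encoder_stages)
    if num_decoder_layers:
        layer_distribution += split(num_decoder_layers, decoder_stages)
    return layer_distribution, encoder_stages    # second value == decoder_starting_stage


# e.g. 12 encoder + 12 decoder blocks over 4 stages:
#   ([6, 6, 6, 6], 2)  -> decoder starts at stage 2
print(toy_distribute_t5_layers(12, 12, 4))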