[shardformer] Pipeline/whisper (#4456)

* add some base tests and policies

* finish whisper base model

* add conditional generation

* finish basic tests

* whisper

* finish whisper

* finish whisper

* del useless  whisper test

* fix

* add argmin to replace

* finish revision
This commit is contained in:
Jianghai
2023-08-18 21:29:25 +08:00
committed by GitHub
parent a27e0bb494
commit 8739aa7fa0
8 changed files with 1156 additions and 53 deletions

View File

@@ -0,0 +1,39 @@
from colossalai.shardformer.policies.t5 import T5BasePolicy
def test_t5_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_t5_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end

View File

@@ -0,0 +1,44 @@
from colossalai.shardformer.policies.whisper import WhisperPolicy
def test_whisper_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_whisper_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end
if __name__ == '__main__':
test_whisper_pipeline_distribution()
test_whisper_pipeline_layers()

View File

@@ -6,6 +6,7 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -143,6 +144,7 @@ def run_llama_test(test_config):
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()

View File

@@ -3,6 +3,8 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
@@ -11,55 +13,145 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
# do backward
org_loss.backward()
shard_loss.backward()
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'WhisperModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwarp the model
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
whisper = org_model.model
sharded_whisper = sharded_model.model
sharded_whisper = sharded_model.unwrap().model
else:
whisper = org_model
sharded_whisper = sharded_model
sharded_whisper = sharded_model.unwrap()
# check grad
if org_model.__class__.__name__ == 'WhisperForAudioClassification':
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
else:
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj']
check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
col_layer_for_check = [
'encoder.layers[0].self_attn.q_proj',
# 'decoder.layers[0].self_attn.q_proj'
]
row_layer_for_check = [
'encoder.layers[0].self_attn.out_proj',
#'decoder.layers[0].self_attn.out_proj'
]
# check weights and gradients
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(whisper,
sharded_whisper,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
check_weight(whisper,
sharded_whisper,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
# TODOjianghai) fix fp16
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
}])
def run_whisper_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn,
enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification':
continue
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -73,7 +165,7 @@ def check_whisper(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_whisper():
spawn(check_whisper, 2)
spawn(check_whisper, 4)
if __name__ == "__main__":