mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[shardformer] support whisper (#4212)
* support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme
This commit is contained in:
@@ -7,3 +7,4 @@ from .opt import *
|
||||
from .sam import *
|
||||
from .t5 import *
|
||||
from .vit import *
|
||||
from .whisper import *
|
||||
|
91
tests/kit/model_zoo/transformers/whisper.py
Normal file
91
tests/kit/model_zoo/transformers/whisper.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
# ===============================
|
||||
# Register single-sentence Whisper
|
||||
# ===============================
|
||||
|
||||
|
||||
# define data gen function
|
||||
def data_gen():
|
||||
# Generated from following code snippet
|
||||
#
|
||||
# from transformers import AutoFeatureExtractor, WhisperModel
|
||||
# from datasets import load_dataset
|
||||
|
||||
# model = WhisperModel.from_pretrained("openai/whisper-base")
|
||||
# feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
|
||||
# ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
|
||||
# input_features = inputs.input_features
|
||||
# decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
|
||||
|
||||
input_features = torch.randn(1, 80, 3000)
|
||||
decoder_input_ids = torch.tensor([[1, 1]]) * 50258
|
||||
return dict(input_features=input_features, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
|
||||
def data_gen_for_conditional_generation():
|
||||
# labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
# Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
|
||||
# or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
|
||||
# only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([[0, 1]], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_audio_classification():
|
||||
# labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
# Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
# config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
# `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
# `WhisperForAudioClassification` does not need `decoder_input_ids`
|
||||
data = data_gen()
|
||||
data.pop('decoder_input_ids')
|
||||
data['labels'] = torch.tensor([1], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
# define output transform function
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss funciton
|
||||
loss_fn = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn_attr = lambda x: x.loss
|
||||
|
||||
config = transformers.WhisperConfig(
|
||||
classifier_proj_size=256,
|
||||
d_model=256,
|
||||
decoder_attention_heads=4,
|
||||
decoder_ffn_dim=1536,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=4,
|
||||
encoder_ffn_dim=1536,
|
||||
encoder_layers=2,
|
||||
vocab_size=51866,
|
||||
)
|
||||
|
||||
# register the Whisper variants
|
||||
model_zoo.register(name='transformers_whisper',
|
||||
model_fn=lambda: transformers.WhisperModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
||||
model_zoo.register(name='transformers_whisperForConditionalGeneration',
|
||||
model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_attr,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
||||
model_zoo.register(name='transformers_whisperWhisperForAudioClassification',
|
||||
model_fn=lambda: transformers.WhisperForAudioClassification(config),
|
||||
data_gen_fn=data_gen_for_audio_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_attr,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
101
tests/test_shardformer/test_model/test_shard_whisper.py
Normal file
101
tests/test_shardformer/test_model/test_shard_whisper.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values')
|
||||
|
||||
# do backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
|
||||
# check grad
|
||||
|
||||
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
|
||||
whisper = org_model.model
|
||||
sharded_whisper = sharded_model.model
|
||||
else:
|
||||
whisper = org_model
|
||||
sharded_whisper = sharded_model
|
||||
|
||||
# compare self attention grad
|
||||
org_grad = whisper.encoder.layers[0].self_attn.q_proj.weight.grad
|
||||
shard_grad = sharded_whisper.encoder.layers[0].self_attn.q_proj.weight.grad
|
||||
shard_weight = sharded_whisper.encoder.layers[0].self_attn.q_proj.weight
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
# WhisperForAudioClassification does not have decoder and embedding layer
|
||||
if org_model.__class__.__name__ == 'WhisperForAudioClassification':
|
||||
return
|
||||
|
||||
# compare embedding grad
|
||||
org_grad = whisper.decoder.embed_tokens.weight.grad
|
||||
shard_grad = sharded_whisper.decoder.embed_tokens.weight.grad
|
||||
shard_weight = sharded_whisper.decoder.embed_tokens.weight
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn,
|
||||
enable_fused_normalization=enable_fused_normalization,
|
||||
enable_tensor_parallelism=enable_tensor_parallelism)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_whisper(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_whisper_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_whisper():
|
||||
spawn(check_whisper, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_whisper()
|
Reference in New Issue
Block a user