From 9ee4ebea83b483cf95c6c4924621e89860cc6fd5 Mon Sep 17 00:00:00 2001
From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Date: Mon, 17 Jul 2023 14:25:32 +0800
Subject: [PATCH] [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme
---
 colossalai/shardformer/README.md              |   2 +-
 colossalai/shardformer/layer/embedding.py     |  10 +-
 .../shardformer/policies/auto_policy.py       |   8 +
 colossalai/shardformer/policies/whisper.py    | 232 ++++++++++++++++++
 tests/kit/model_zoo/transformers/__init__.py  |   1 +
 tests/kit/model_zoo/transformers/whisper.py   |  91 +++++++
 .../test_model/test_shard_whisper.py          | 101 ++++++++
 7 files changed, 443 insertions(+), 2 deletions(-)
 create mode 100644 colossalai/shardformer/policies/whisper.py
 create mode 100644 tests/kit/model_zoo/transformers/whisper.py
 create mode 100644 tests/test_shardformer/test_model/test_shard_whisper.py

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index bf4215c52..3c322aabf 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -102,7 +102,7 @@ We will follow this roadmap to develop Shardformer:
   - [ ] SwinTransformer
   - [ ] SwinTransformer V2
 - [ ] Audio
-  - [ ] Whisper
+  - [x] Whisper
 - [ ] Multi-modal
   - [ ] To be added

diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index 09b22abb1..f07a93bd6 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -202,7 +202,6 @@ class VocabParallelEmbedding1D(ParallelModule):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
-        self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.process_group = process_group
@@ -276,6 +275,15 @@ class VocabParallelEmbedding1D(ParallelModule):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

+    def _select_padding_idx(self, padding_idx: int):
+        # select padding index according to the rank
+        if padding_idx is None:
+            return None
+        elif padding_idx < self.vocab_end_index and padding_idx >= self.vocab_start_index:
+            return padding_idx - self.vocab_start_index
+        else:
+            return None
+
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 63ec8398f..90347a984 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -105,6 +105,14 @@ _POLICY_LIST = { "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), + # Whisper + "transformers.models.whisper.modeling_whisper.WhisperModel": + PolicyLocation(file_name="whisper", class_name="WhisperModelPolicy"), + "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": + PolicyLocation(file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"), + "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": + PolicyLocation(file_name="whisper", class_name="WhisperForAudioClassificationPolicy"), + # Sam "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py new file mode 100644 index 000000000..7751bbb5d --- /dev/null +++ b/colossalai/shardformer/policies/whisper.py @@ -0,0 +1,232 @@ +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .._utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification' +] + + +class WhisperPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.whisper.modeling_whisper import ( + WhisperDecoder, + WhisperDecoderLayer, + WhisperEncoder, + WhisperEncoderLayer, + ) + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ + "self_attn.embed_dim": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ]) + + policy[WhisperDecoderLayer] = ModulePolicyDescription(attribute_replacement={ + "self_attn.embed_dim": + self.model.config.d_model // 
self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size, + "encoder_attn.embed_dim": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "encoder_attn.num_heads": + self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ]) + + policy[WhisperDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle encoder layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperEncoderLayer) + + # Handle decoder layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperDecoderLayer) + + # handle encoder layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperEncoder) + + # handle decoder layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperDecoder) + return policy + + def add_lm_head_policy(self, base_policy): + from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration + + # optimize for tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), + policy=base_policy, + target_key=WhisperForConditionalGeneration) + + return base_policy + + def postprocess(self): + return self.model + + +# WhisperModel +class 
WhisperModelPolicy(WhisperPolicy): + + def __init__(self) -> None: + super().__init__() + + +# WhisperForConditionalGeneration +class WhisperForConditionalGenerationPolicy(WhisperPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + module_policy = self.add_lm_head_policy(module_policy) + return module_policy + + def postprocess(self): + binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) + return self.model + + +# WhisperForAudioClassification +class WhisperForAudioClassificationPolicy(WhisperPolicy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index a1bcb78dd..39e5ef411 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -7,3 +7,4 @@ from .opt import * from .sam import * from .t5 import * from .vit import * +from .whisper import * diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py new file mode 100644 index 000000000..b58716217 --- /dev/null +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -0,0 +1,91 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Whisper +# =============================== + + +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoFeatureExtractor, WhisperModel + # from datasets import load_dataset + + # model = WhisperModel.from_pretrained("openai/whisper-base") + # feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + # ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + # input_features = inputs.input_features + # decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + + input_features = torch.randn(1, 80, 3000) + decoder_input_ids = torch.tensor([[1, 1]]) * 50258 + return dict(input_features=input_features, decoder_input_ids=decoder_input_ids) + + +def data_gen_for_conditional_generation(): + # labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + # Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + # or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + # only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + data = data_gen() + data['labels'] = torch.tensor([[0, 1]], dtype=torch.int64) + return data + + +def data_gen_for_audio_classification(): + # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + # Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + # config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + # `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ # `WhisperForAudioClassification` does not need `decoder_input_ids` + data = data_gen() + data.pop('decoder_input_ids') + data['labels'] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton +loss_fn = lambda x: x.last_hidden_state.mean() +loss_fn_attr = lambda x: x.loss + +config = transformers.WhisperConfig( + classifier_proj_size=256, + d_model=256, + decoder_attention_heads=4, + decoder_ffn_dim=1536, + decoder_layers=2, + encoder_attention_heads=4, + encoder_ffn_dim=1536, + encoder_layers=2, + vocab_size=51866, +) + +# register the Whisper variants +model_zoo.register(name='transformers_whisper', + model_fn=lambda: transformers.WhisperModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_whisperForConditionalGeneration', + model_fn=lambda: transformers.WhisperForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_whisperWhisperForAudioClassification', + model_fn=lambda: transformers.WhisperForAudioClassification(config), + data_gen_fn=data_gen_for_audio_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py new file mode 100644 index 000000000..8932a4ab9 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -0,0 +1,101 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values') + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + + if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': + whisper = org_model.model + sharded_whisper = sharded_model.model + else: + whisper = org_model + sharded_whisper = sharded_model + + # compare self attention grad + org_grad = whisper.encoder.layers[0].self_attn.q_proj.weight.grad + shard_grad = sharded_whisper.encoder.layers[0].self_attn.q_proj.weight.grad + shard_weight = sharded_whisper.encoder.layers[0].self_attn.q_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, 
shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
+
+    # WhisperForAudioClassification does not have decoder and embedding layer
+    if org_model.__class__.__name__ == 'WhisperForAudioClassification':
+        return
+
+    # compare embedding grad
+    org_grad = whisper.decoder.embed_tokens.weight.grad
+    shard_grad = sharded_whisper.decoder.embed_tokens.weight.grad
+    shard_weight = sharded_whisper.decoder.embed_tokens.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
+
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
+
+
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn,
+                                               enable_fused_normalization=enable_fused_normalization,
+                                               enable_tensor_parallelism=enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()
+
+
+def check_whisper(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_whisper_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_whisper():
+    spawn(check_whisper, 2)
+
+
+if __name__ == "__main__":
+    test_whisper()
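
The embedding change in this patch makes VocabParallelEmbedding1D keep the padding index only on the rank whose vocabulary shard actually contains it. The snippet below is a small standalone illustration of that selection rule; it mirrors _select_padding_idx, and the vocabulary split and padding index used in the example are hypothetical.

def select_padding_idx(padding_idx, vocab_start_index, vocab_end_index):
    # keep padding_idx only if it falls inside this rank's vocab slice,
    # shifted into the rank-local index space; otherwise drop it
    if padding_idx is None:
        return None
    if vocab_start_index <= padding_idx < vocab_end_index:
        return padding_idx - vocab_start_index
    return None

# hypothetical example: a vocabulary of 50304 ids split over 2 ranks (25152 each), padding_idx = 50257
assert select_padding_idx(50257, 0, 25152) is None        # rank 0 does not own it
assert select_padding_idx(50257, 25152, 50304) == 25105   # rank 1 owns it at local index 25105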
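
WhisperPolicy.preprocess pads the vocabulary so that it divides evenly across tensor-parallel ranks before VocabParallelEmbedding1D splits it. A worked illustration of that rounding, using the vocab_size from the test config (51866) and hypothetical world sizes:

def padded_vocab_size(vocab_size, world_size):
    # round vocab_size up to the next multiple of world_size, as done in preprocess()
    if vocab_size % world_size == 0:
        return vocab_size
    return vocab_size + world_size - vocab_size % world_size

assert padded_vocab_size(51866, 2) == 51866   # already divisible by 2, no resize
assert padded_vocab_size(51866, 4) == 51868   # 51866 % 4 == 2, so two padding ids are added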
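
For reference, a minimal sketch of how the new Whisper policies could be exercised end to end. It assumes the ShardConfig/ShardFormer API that build_model uses in the test above, a process group already initialized via colossalai.launch, and that ShardFormer.optimize returns the sharded module in this revision; the Whisper config values are illustrative only.

import transformers

from colossalai.shardformer import ShardConfig, ShardFormer

# tiny illustrative Whisper config (not the values registered in the model zoo)
config = transformers.WhisperConfig(d_model=256,
                                    encoder_layers=2,
                                    decoder_layers=2,
                                    encoder_attention_heads=4,
                                    decoder_attention_heads=4,
                                    encoder_ffn_dim=1024,
                                    decoder_ffn_dim=1024)
model = transformers.WhisperForConditionalGeneration(config).cuda()

# same switches the test toggles; tensor_parallel_size is derived from the launched world size
shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)

# auto_policy resolves WhisperForConditionalGenerationPolicy from the entries registered above
sharded_model = shard_former.optimize(model)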