diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 5e7a285e3..b1b8c6156 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -25,17 +25,19 @@ class PolicyLocation: _POLICY_LIST = { # BERT "transformers.models.bert.modeling_bert.BertModel": - PolicyLocation(file_name="bert", class_name="BertPolicy"), + PolicyLocation(file_name="bert", class_name="BertModelPolicy"), "transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"), - "transformers.models.bert.modeling_bert.BertForMaskedLM": - PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"), "transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"), - "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": - PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), + "transformers.models.bert.modeling_bert.BertForMaskedLM": + PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"), "transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"), + "transformers.models.bert.modeling_bert.BertForTokenClassification": + PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"), + "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": + PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), @@ -58,6 +60,14 @@ _POLICY_LIST = { # GPT2 "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": + PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": + PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": + PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": + PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), } diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index d5e8e01cf..8649c0dbe 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -131,37 +131,6 @@ class BertForPretrainingPolicy(BertPolicy): return self.model -# BertForMaskedLM -class BertForMaskedLMPolicy(BertPolicy): - - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - module_policy = super().module_policy() - addon_module = { - BertLMPredictionHead: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="decoder", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}) - ]) - } - module_policy.update(addon_module) - return module_policy - - def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - 
param = nn.Parameter(param)
-        setattr_(self.model, k, param)
-        setattr_(self.model, v, param)
-        return self.model
-
-
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
 
@@ -193,15 +162,53 @@ class BertLMHeadModelPolicy(BertPolicy):
         return self.model
 
 
-# BertForNextSentencePrediction
-class BertForNextSentencePredictionPolicy(BertPolicy):
+# BertForMaskedLM
+class BertForMaskedLMPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+    def postprocess(self):
+        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
+
+
+# BertForSequenceClassification
+class BertForSequenceClassificationPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
 
-# BertForSequenceClassification
-class BertForSequenceClassificationPolicy(BertPolicy):
+# BertForTokenClassification
+class BertForTokenClassificationPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForNextSentencePrediction
+class BertForNextSentencePredictionPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index da9e6b7bd..54ea2f6e3 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -1,7 +1,9 @@
-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+import torch.nn as nn
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
 
 import colossalai.shardformer.layer as col_nn
 
+from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
@@ -82,7 +84,6 @@ class GPT2Policy(Policy):
         }
 
     def new_model_class(self):
-
         return self.model
 
     def postprocess(self):
@@ -94,3 +95,79 @@ class GPT2ModelPolicy(GPT2Policy):
 
     def __init__(self) -> None:
         super().__init__()
+
+
+# GPT2LMHeadModel
+class GPT2LMHeadModelPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            GPT2LMHeadModel:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+    def postprocess(self):
+        binding_map = {"transformer.wte.weight": "lm_head.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
+
+
+# GPT2DoubleHeadsModel
+class GPT2DoubleHeadsModelPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            
GPT2DoubleHeadsModel: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + return module_policy + + def postprocess(self): + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + + +# GPT2ForTokenClassification +class GPT2ForTokenClassificationPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + +# GPT2ForSequenceClassification +class GPT2ForSequenceClassificationPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 99135704d..d2d3de7b7 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -6,83 +6,147 @@ from ..registry import ModelAttribute, model_zoo # =============================== # Register single-sentence BERT # =============================== -BATCH_SIZE = 2 -SEQ_LENGTH = 16 -def data_gen_fn(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from transformers import BertTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + # token_type_ids = tokenized_input['token_type_ids'] + input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) + token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data + + +def data_gen_for_pretraining(): + # pretraining data gen + # `next_sentence_label` is the label for next sentence prediction, 0 or 1 + data = data_gen_for_lm() + data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + # `labels` is the label for sequence classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([1], dtype=torch.int64) + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_mcq(): + # multiple choice question data gen + # Generated from following code snippet + # + # tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + # prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." 
+    # choice0 = "It is eaten with a fork and a knife."
+    # choice1 = "It is eaten while held in the hand."
+    # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+    # data = {k: v.unsqueeze(0) for k, v in data.items()}
+    # data['labels'] = torch.tensor([0], dtype=torch.int64)
+    input_ids = torch.tensor([[[
+        101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
+        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
+    ],
+                              [
+                                  101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
+                                  4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
+                                  2218, 1999, 1996, 2192, 1012, 102, 0
+                              ]]])
+    token_type_ids = torch.tensor(
+        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+    attention_mask = torch.tensor(
+        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+    labels = torch.tensor([0], dtype=torch.int64)
+
+    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
+
+
+# define output transform function
 output_transform_fn = lambda x: x
 
-config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
+# define loss function
+loss_fn_for_bert_model = lambda x: x.pooler_output.mean()
+loss_fn = lambda x: x.loss
+
+config = transformers.BertConfig(hidden_size=128,
+                                 num_hidden_layers=2,
+                                 num_attention_heads=4,
+                                 intermediate_size=256,
+                                 hidden_dropout_prob=0,
+                                 attention_probs_dropout_prob=0)
 
 # register the BERT variants
 model_zoo.register(name='transformers_bert',
                    model_fn=lambda: transformers.BertModel(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_bert_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_pretraining',
                    model_fn=lambda: transformers.BertForPreTraining(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_pretraining,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_lm_head_model',
                    model_fn=lambda: transformers.BertLMHeadModel(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_lm,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_masked_lm',
                    model_fn=lambda: transformers.BertForMaskedLM(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_lm,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_sequence_classification',
                    model_fn=lambda: transformers.BertForSequenceClassification(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_sequence_classification,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_token_classification',
                    model_fn=lambda: 
transformers.BertForTokenClassification(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_token_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) - - -# =============================== -# Register multi-sentence BERT -# =============================== -def data_gen_for_next_sentence(): - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - next_sentence = "The sky is blue due to the shorter wavelength of blue light." - encoding = tokenizer(prompt, next_sentence, return_tensors="pt") - return encoding - - -def data_gen_for_mcq(): - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - choice0 = "It is eaten with a fork and a knife." - choice1 = "It is eaten while held in the hand." - encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) - encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} - return encoding - - -# register the following models model_zoo.register(name='transformers_bert_for_next_sentence', model_fn=lambda: transformers.BertForNextSentencePrediction(config), - data_gen_fn=data_gen_for_next_sentence, + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_mcq', model_fn=lambda: transformers.BertForMultipleChoice(config), data_gen_fn=data_gen_for_mcq, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 5ed4fbe70..c598fa8f4 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -11,47 +11,86 @@ SEQ_LENGTH = 16 def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + # Generated from following code snippet + # + # from transformers import GPT2Tokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) -def seq_classification_data_gen(): - # batch sizes should be 1 if no padding token is defined. 
- input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data['labels'] = torch.tensor([0], dtype=torch.int64) + return data + + +# define output transform function output_transform_fn = lambda x: x -config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4) +# define loss function +loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean() +loss_fn = lambda x: x.loss + +config = transformers.GPT2Config(n_layer=2, + n_head=4, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification") # register the following models model_zoo.register(name='transformers_gpt', model_fn=lambda: transformers.GPT2Model(config), data_gen_fn=data_gen, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_gpt2_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_lm', model_fn=lambda: transformers.GPT2LMHeadModel(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_double_heads', model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_token_classification', model_fn=lambda: transformers.GPT2ForTokenClassification(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_token_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_sequence_classification', model_fn=lambda: transformers.GPT2ForSequenceClassification(config), - data_gen_fn=seq_classification_data_gen, + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 043ed1a74..ad98e3d07 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -1,86 +1,30 @@ -import copy -import os - import pytest import torch -from transformers import ( - AutoTokenizer, - BertConfig, - BertForMaskedLM, - BertForNextSentencePrediction, - BertForPreTraining, - BertForSequenceClassification, - BertLMHeadModel, - BertModel, -) import colossalai from colossalai.logging import disable_existing_loggers 
-from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import rerun_if_address_is_in_use, spawn - -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) -tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward -def build_model(world_size, model_fn): - config = BertConfig() - config.hidden_dropout_prob = 0 - config.attention_probs_dropout_prob = 0 +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output) - org_model = model_fn(config=config) - org_model_forshard = copy.deepcopy(org_model) - - org_model.to('cuda') - # TODO: no need to transfer to cuda - org_model_forshard.to('cuda') - shard_config = ShardConfig(tensor_parallel_size=world_size,) - shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') - - return org_model, sharded_model - - -def check_forward(org_model, sharded_model): - input = 'Hello, my dog is cute' - tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') - - #orgin model - org_model.eval() - org_out = org_model(**tokenized_input) - - #shard model - sharded_model.eval() - shard_out = sharded_model(**tokenized_input) - - assert torch.allclose( - org_out[0], shard_out[0], - atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" - - -def check_backward(org_model, sharded_model): - # prepare input - input = 'Hello, my dog is cute' - tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') - labels = tokenized_input['input_ids'].clone() - labels[labels == tokenizer.pad_token_id] = -100 - tokenized_input['labels'] = labels - - #orgin model - org_model.train() - org_out = org_model(**tokenized_input) - org_loss = org_out.loss + # do backward org_loss.backward() - org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad - - #shard model - sharded_model.train() - shard_out = sharded_model(**tokenized_input) - shard_loss = shard_out.loss shard_loss.backward() - shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad + + # check grad equality + if org_model.__class__.__name__ == 'BertModel': + org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad + shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad + else: + org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad + shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) @@ -89,36 +33,24 @@ def check_backward(org_model, sharded_model): assert torch.allclose(org_loss, shard_loss, atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model 
grad\n{org_grad}\n{shard_grad}"
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
 
 def check_bert(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    forward_list = [
-        BertForMaskedLM,
-        BertForPreTraining,
-        BertLMHeadModel,
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    # TODO: do not work yet
-    # BertModel,
-    # BertForSequenceClassification
-    # BertForNextSentencePrediction,
-    ]
-    backward_lsit = [BertForMaskedLM, BertLMHeadModel]
-
-    for model_fn in forward_list:
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward(org_model, sharded_model)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
-        if model_fn in backward_lsit:
-            check_backward(org_model, sharded_model)
-
-        torch.cuda.empty_cache()
+    torch.cuda.empty_cache()
 
 
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_bert():
     spawn(check_bert, 2)
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 2f679b83f..0c07f4440 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -1,117 +1,58 @@
-import copy
-import os
-
 import pytest
 import torch
-from transformers import AutoTokenizer, GPT2Config, GPT2Model
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
 
 
-def build_model(world_size, model_fn):
-    config = GPT2Config()
-    config.attn_pdrop = 0
-    config.embd_pdrop = 0
-    config.resid_pdrop = 0
-    config.summary_first_dropout
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
 
-    org_model = model_fn(config=config)
-    org_model_forshard = copy.deepcopy(org_model)
-
-    org_model.to('cuda')
-    # TODO: no need to transfer to cuda
-    org_model_forshard.to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=world_size,)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
-
-    return org_model, sharded_model
-
-
-def check_forward(org_model, sharded_model):
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-
-    #orgin model
-    org_model.eval()
-    org_out = org_model(**tokenized_input)
-
-    #shard model
-    
-    sharded_model.eval()
-    shard_out = sharded_model(**tokenized_input)
-
-    assert torch.allclose(
-        org_out[0], shard_out[0],
-        atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
-
-
-def check_backward(org_model, sharded_model):
-    # prepare input
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-    labels = tokenized_input['input_ids'].clone()
-    labels[labels == tokenizer.pad_token_id] = -100
-    # tokenized_input['labels'] = labels
-
-    #orgin model
-    org_model.train()
-    org_out = org_model(**tokenized_input)
-    org_loss = org_out.loss
+    # do backward
     org_loss.backward()
-    org_grad = org_model.h[0].attn.c_attn.weight.grad
-
-    #shard model
-    sharded_model.train()
-    shard_out = sharded_model(**tokenized_input)
-    shard_loss = shard_out.loss
     shard_loss.backward()
-    shard_grad = sharded_model.h[0].attn.c_attn.weight.grad
+
+    # check grad equality
+    if org_model.__class__.__name__ == 'GPT2Model':
+        org_grad = org_model.h[0].attn.c_attn.weight.grad
+        shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous()
+    else:
+        org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
+        shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous()
 
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    all_shard_grad = torch.cat(shard_grad_list, dim=1)
 
     assert torch.allclose(org_loss, shard_loss,
                           atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
     assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
 
-def check_bert(rank, world_size, port):
+def check_gpt2(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    forward_list = [
-        GPT2Model,
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    # TODO: do not work yet
-    # BertModel,
-    # BertForSequenceClassification
-    # BertForNextSentencePrediction,
-    ]
-    backward_lsit = []
-
-    for model_fn in forward_list:
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward(org_model, sharded_model)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
-        if model_fn in backward_lsit:
-            check_backward(org_model, sharded_model)
-
-        torch.cuda.empty_cache()
+    torch.cuda.empty_cache()
 
 
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_gpt2():
-    spawn(check_bert, 2)
+    spawn(check_gpt2, 2)
 
 
 if __name__ == "__main__":