diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py
index b1b8c6156..9cc583d58 100644
--- a/colossalai/shardformer/policies/autopolicy.py
+++ b/colossalai/shardformer/policies/autopolicy.py
@@ -68,6 +68,16 @@ _POLICY_LIST = {
         PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
     "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
         PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
+
+    # OPT
+    "transformers.models.opt.modeling_opt.OPTModel":
+        PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
+    "transformers.models.opt.modeling_opt.OPTForCausalLM":
+        PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"),
+    "transformers.models.opt.modeling_opt.OPTForSequenceClassification":
+        PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"),
+    "transformers.models.opt.modeling_opt.OPTForQuestionAnswering":
+        PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"),
 }
 
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
new file mode 100644
index 000000000..f467726e5
--- /dev/null
+++ b/colossalai/shardformer/policies/opt.py
@@ -0,0 +1,133 @@
+from transformers.models.opt.modeling_opt import (
+    OPTAttention,
+    OPTDecoder,
+    OPTDecoderLayer,
+    OPTForCausalLM,
+    OPTForSequenceClassification,
+)
+
+from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
+
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+
+class OPTPolicy(Policy):
+
+    def preprocess(self):
+        # reshape the embedding layer
+        r"""
+        Resize the token embedding so that the vocabulary size is divisible by the tensor parallel size.
+        """
+        vocab_size = self.model.config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+        return self.model
+
+    def module_policy(self):
+        base_policy = {
+            OPTDecoder:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="embed_tokens",
+                                                target_module=Embedding1D,
+                                            )
+                                        ]),
+            OPTDecoderLayer:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="fc1",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="fc2",
+                                                target_module=Linear1D_Row,
+                                            )
+                                        ]),
+            OPTAttention:
+                ModulePolicyDescription(attribute_replacement={
+                    "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                    "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
+                },
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="q_proj",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="k_proj",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="v_proj",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="out_proj",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                        ]),
+        }
+        if self.shard_config.fused_layernorm:
+            base_policy[OPTDecoder].sub_module_replacement.append(
+                SubModuleReplacementDescription(suffix="final_layer_norm",
+                                                target_module=FusedLayerNorm,
+                                                ignore_if_not_exist=True))
+            base_policy[OPTDecoderLayer].sub_module_replacement.extend([
+                SubModuleReplacementDescription(suffix="self_attn_layer_norm",
+                                                target_module=FusedLayerNorm,
+                                                ignore_if_not_exist=True),
+                SubModuleReplacementDescription(suffix="final_layer_norm",
+                                                target_module=FusedLayerNorm,
+                                                ignore_if_not_exist=True)
+            ])
+        return base_policy
+
+    def new_model_class(self):
+        return None
+
+    def postprocess(self):
+        return self.model
+
+
+class OPTModelPolicy(OPTPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+class OPTForCausalLMPolicy(OPTPolicy):
+
+    def module_policy(self):
+        policy = super().module_policy()
+        new_item = {
+            OPTForCausalLM:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True))
+                                        ])
+        }
+
+        policy.update(new_item)
+        return policy
+
+
+class OPTForSequenceClassificationPolicy(OPTPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+class OPTForQuestionAnsweringPolicy(OPTPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py
index d9c4a0b3c..4463ae12b 100644
--- a/tests/kit/model_zoo/transformers/opt.py
+++ b/tests/kit/model_zoo/transformers/opt.py
@@ -11,14 +11,47 @@ SEQ_LENGTH = 16
 
 
 def data_gen():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
+    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
     return dict(input_ids=input_ids, attention_mask=attention_mask)
 
 
-output_transform_fn = lambda x: x
+def data_gen_for_causal_lm():
+    # causal LM data gen
+    # there is no padding, so `input_ids` can be reused directly as `labels`
+    data = data_gen()
+    labels = data['input_ids'].clone()
+    data['labels'] = labels
+    return data
 
-config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
+
+def data_gen_for_sequence_classification():
+    # sequence classification data gen
+    # unlike causal LM, `labels` is a single class index for the whole sequence
+    # rather than a token-level target
+    data = data_gen()
+    data['labels'] = torch.tensor([1])
+    return data
+
+
+def data_gen_for_question_answering():
+    # question answering data gen
+    # `start_positions` and `end_positions` mark the answer span in the input
+    data = data_gen()
+    data['start_positions'] = torch.tensor([0])
+    data['end_positions'] = torch.tensor([1])
+    return data
+
+
+output_transform_fn = lambda x: x
+loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean()
+loss_fn_for_lm = lambda x: x.loss
+config = transformers.OPTConfig(
+    hidden_size=128,
+    num_hidden_layers=2,
+    num_attention_heads=4,
+    dropout=0,
+)
 
 # register the following models
 # transformers.OPTModel,
@@ -27,9 +60,23 @@ model_zoo.register(name='transformers_opt',
                    model_fn=lambda: transformers.OPTModel(config),
                    data_gen_fn=data_gen,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_opt_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_opt_for_causal_lm',
                    model_fn=lambda: transformers.OPTForCausalLM(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_causal_lm,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_lm,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_opt_for_question_answering',
+                   model_fn=lambda: transformers.OPTForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_question_answering,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_lm,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_opt_for_sequence_classification',
+                   model_fn=lambda: transformers.OPTForSequenceClassification(config),
+                   data_gen_fn=data_gen_for_sequence_classification,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_lm,
                    model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
index f528db6a6..c68b89e82 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -11,10 +11,9 @@ from tests.kit.model_zoo import model_zoo
 @clear_cache_before_run()
 def test_opt():
     sub_registry = model_zoo.get_sub_registry('transformers_opt')
-
     for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
         model = model_fn()
-        trace_model_and_compare_output(model, data_gen_fn)
+        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions'])
 
 
 if __name__ == '__main__':
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index a282e0bb9..ad7c408ae 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -25,7 +25,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     # switch to train mode
     original_model.train()
     sharded_model.train()
-
     # run forward
     org_output = original_model(**data)
     org_output = output_transform_fn(org_output)
@@ -34,5 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     shard_output = sharded_model(**data)
     shard_output = output_transform_fn(shard_output)
     shard_loss = loss_fn(shard_output)
-
-    return org_output, org_loss, shard_output, shard_loss
+    return org_output, org_loss, shard_output, shard_loss
\ No newline at end of file
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
new file mode 100644
index 000000000..4d4c55770
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -0,0 +1,67 @@
+import copy
+import os
+
+import pytest
+import torch
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import (
+    assert_hf_output_close,
+    check_state_dict_equal,
+    clear_cache_before_run,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)
+
+    # run backward
+    org_loss.backward()
+    shard_loss.backward()
+
+    # check grad
+    if hasattr(org_model, 'model'):
+        opt_model = org_model.model
+        shard_opt_model = sharded_model.model
+    else:
+        opt_model = org_model
+        shard_opt_model = sharded_model
+
+    org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
+    shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
+
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+    torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
+
+
+def check_OPTModel(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_OPTModel():
+    spawn(check_OPTModel, 4)
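
Note on the autopolicy additions above: each new `_POLICY_LIST` entry is keyed by the fully qualified class name of the Hugging Face OPT module, and its `PolicyLocation` records which policy file and class to load. The sketch below illustrates that lookup in isolation; `resolve_policy` is a hypothetical helper written for this note (not the function exposed by `autopolicy.py`), and the table is trimmed to one entry and simplified to plain (file_name, class_name) tuples instead of `PolicyLocation` objects.

    import importlib

    # Trimmed, illustrative copy of the mapping added above; the real table lives in
    # colossalai/shardformer/policies/autopolicy.py and covers all four OPT variants.
    _POLICY_LIST = {
        "transformers.models.opt.modeling_opt.OPTForCausalLM": ("opt", "OPTForCausalLMPolicy"),
    }


    def resolve_policy(model):
        # Hypothetical helper: build the same "module.ClassName" key used by the table,
        # then lazily import the policy module and instantiate the policy class.
        qualname = f"{model.__class__.__module__}.{model.__class__.__name__}"
        file_name, class_name = _POLICY_LIST[qualname]
        policy_module = importlib.import_module(f"colossalai.shardformer.policies.{file_name}")
        return getattr(policy_module, class_name)()

Once a policy is resolved, its sub-module replacement descriptions drive the sharding: `q_proj`/`k_proj`/`v_proj` and `fc1` are split column-wise (`Linear1D_Col`), while `out_proj` and `fc2` are split row-wise (`Linear1D_Row`). This is also why the gradient check in `test_shard_opt.py` all-gathers the `q_proj` shards and concatenates them along dim 0 before comparing against the unsharded gradient.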