diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 203b7439d..2fd135d54 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -64,7 +64,10 @@ def _broadcast_object_list(object_list: List[Any], my_rank = dist.get_rank() # Serialize object_list elements to tensors on src rank. if my_rank == src: - tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) + if torch.__version__ >= "1.13.0": + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list]) + else: + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) object_sizes_tensor = torch.cat(size_list) else: object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index a8933bfbb..6ed3055d6 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -205,7 +205,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): # the backward pass. input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) if last_iteration: diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 0ad9a3e95..ccdb33b2e 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -42,6 +42,8 @@ _POLICY_LIST = { PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), + "transformers.models.bert.modeling_bert.BertForQuestionAnswering": + PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"), # LLaMA "transformers.models.llama.modeling_llama.LlamaModel": diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 2b2c003ff..1af26f504 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,22 +1,31 @@ from functools import partial from types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from torch import Tensor -from torch.nn import CrossEntropyLoss, Module +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss from transformers.modeling_outputs import ( - BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, ) from transformers.models.bert.modeling_bert import ( BertForMaskedLM, + BertForMultipleChoice, BertForNextSentencePrediction, BertForPreTraining, BertForPreTrainingOutput, + BertForQuestionAnswering, + BertForSequenceClassification, + BertForTokenClassification, BertLMHeadModel, BertModel, ) @@ -31,9 +39,9 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe logger = logging.get_logger(__name__) __all__ = [ - 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', + 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', - 'BertForMultipleChoicePolicy' + 'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy' ] @@ -172,6 +180,25 @@ class BertPolicy(Policy): def postprocess(self): return self.model + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replace the original forward method of the huggingface model + with a customized forward method, and register this change in the policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BertModel": + module = self.model + else: + module = self.model.bert + + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + return + # BertModel class BertModelPolicy(BertPolicy): @@ -180,13 +207,10 @@ class BertModelPolicy(BertPolicy): super().__init__() def module_policy(self): - module_policy = super().module_policy() + policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel - if self.pipeline_stage_manager: - # set None as default - module_policy[BertModel] = ModulePolicyDescription( - method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)}) - return module_policy + self.set_pipeline_forward(model_cls=BertModel, new_forward=bert_model_forward, policy=policy) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -214,15 +238,17 @@ class BertForPreTrainingPolicy(BertPolicy): super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertForPreTraining + self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage""" module = self.model stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) held_layers = [] if stage_manager.is_first_stage(): held_layers.append(module.bert.embeddings) @@ -237,11 +263,18 @@ class BertForPreTrainingPolicy(BertPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''No shared params in bertmodel''' + model = self.model + if self.pipeline_stage_manager: + if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: model.bert.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight + }] return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map =
{"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -256,9 +289,11 @@ class BertLMHeadModelPolicy(BertPolicy): super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertLMHeadModel + self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy) + return policy def get_held_layers(self) -> List[Module]: """ @@ -267,7 +302,7 @@ class BertLMHeadModelPolicy(BertPolicy): module = self.model held_layers = [] stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.bert.embeddings) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) @@ -278,11 +313,18 @@ class BertLMHeadModelPolicy(BertPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''No shared params in bertmodel''' + bert_model = self.model.bert + if self.pipeline_stage_manager: + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight + }] return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -297,12 +339,42 @@ class BertForMaskedLMPolicy(BertPolicy): super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertForMaskedLM + self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy) + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + bert_model = self.model.bert + if self.pipeline_stage_manager: + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: 
self.model.cls.predictions.decoder.weight + }] + return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -319,7 +391,7 @@ class BertForSequenceClassificationPolicy(BertPolicy): def module_policy(self): from transformers.models.bert.modeling_bert import BertForSequenceClassification - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -331,8 +403,35 @@ class BertForSequenceClassificationPolicy(BertPolicy): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + self.set_pipeline_forward(model_cls=BertForSequenceClassification, + new_forward=bert_for_sequence_classification_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] # BertForTokenClassification @@ -344,7 +443,7 @@ class BertForTokenClassificationPolicy(BertPolicy): def module_policy(self): from transformers.models.bert.modeling_bert import BertForTokenClassification - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -356,8 +455,35 @@ class BertForTokenClassificationPolicy(BertPolicy): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + self.set_pipeline_forward(model_cls=BertForTokenClassification, + new_forward=bert_for_token_classification_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] # BertForNextSentencePrediction @@ -366,6 +492,36 @@ class BertForNextSentencePredictionPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + policy = 
super().module_policy() + from transformers.models.bert.modeling_bert import BertForNextSentencePrediction + self.set_pipeline_forward(model_cls=BertForNextSentencePrediction, + new_forward=bert_for_next_sentence_prediction_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + # BertForMultipleChoice class BertForMultipleChoicePolicy(BertPolicy): @@ -376,7 +532,7 @@ class BertForMultipleChoicePolicy(BertPolicy): def module_policy(self): from transformers.models.bert.modeling_bert import BertForMultipleChoice - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -388,28 +544,91 @@ class BertForMultipleChoicePolicy(BertPolicy): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + self.set_pipeline_forward(model_cls=BertForMultipleChoice, + new_forward=bert_for_multiple_choice_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + + +class BertForQuestionAnsweringPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForQuestionAnswering + policy = super().module_policy() + self.set_pipeline_forward(model_cls=BertForQuestionAnswering, + new_forward=bert_for_question_answering_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + 
held_layers.append(module.bert.pooler) + held_layers.append(module.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + stage_index: Optional[List[int]] = None, ): # TODO: add explaination of the output here. 
r""" @@ -528,14 +747,10 @@ def bert_model_forward( use_cache = False next_decoder_cache = () if use_cache else None - # calculate the num_layers - num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - + start_idx, end_idx = stage_index[0], stage_index[1] # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None - for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: encoder_attention_mask = encoder_extended_attention_mask @@ -593,8 +808,9 @@ def bert_model_forward( return (sequence_output, pooled_output) + layer_outputs[1:] # return dict is not supported at this moment else: - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, @@ -624,6 +840,7 @@ def bert_for_pretraining_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. @@ -637,18 +854,21 @@ def bert_for_pretraining_forward( logger.warning_once('return_dict is not supported for pipeline models at the moment') return_dict = False - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index, + ) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -684,23 +904,26 @@ def bert_for_pretraining_forward( } -def bert_lmhead_forward(self: BertLMHeadModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - 
hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None): +def bert_lm_head_model_forward( + self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -754,7 +977,8 @@ def bert_lmhead_forward(self: BertLMHeadModel, output_hidden_states=output_hidden_states, return_dict=return_dict, stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -806,15 +1030,66 @@ def bert_for_masked_lm_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, ): - #-> Union[Tuple[torch.Tensor], MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ - pass + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} def bert_for_next_sentence_prediction_forward( @@ -831,6 +1106,7 @@ def bert_for_next_sentence_prediction_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, **kwargs, ): #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: @@ -881,18 +1157,20 @@ def bert_for_next_sentence_prediction_forward( return_dict = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index) + if stage_manager.is_last_stage(): pooled_output = outputs[1] seq_relationship_scores = self.cls(pooled_output) @@ -916,3 +1194,355 @@ def bert_for_next_sentence_prediction_forward( hidden_states = outputs.get('hidden_states') # intermediate stage always return dict return 
{'hidden_states': hidden_states} + + +def bert_for_sequence_classification_forward( + self: BertForSequenceClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index) + + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bert_for_token_classification_forward( + self: 
BertForTokenClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bert_for_multiple_choice_forward( + self: BertForMultipleChoice, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # in our pipeline design,input ids are copied for every stage and shouldn't be none + # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length] + if stage_manager.is_last_stage(): + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None else None) + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bert_for_question_answering_forward( + self: BertForQuestionAnswering, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + # NOTE: the arg start_position and end_position are used only for the last stage + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of 
the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index a3ea80726..b3757452c 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -212,11 +212,13 @@ class LlamaForCausalLMPolicy(LlamaPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama model""" llama_model = self.model.model if id(llama_model.embed_tokens.weight) == 
id(self.model.lm_head.weight): # tie weights - return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}] + return [{ + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight + }] return [] diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e699..4a19f2449 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index d2d3de7b7..1993af51a 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -87,6 +87,17 @@ def data_gen_for_mcq(): return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) +def data_gen_for_qa(): + # generating data for question answering + # no need for labels and use start and end position instead + data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data['start_positions'] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data['end_positions'] = end_positions + return data + + # define output transform function output_transform_fn = lambda x: x @@ -150,3 +161,9 @@ model_zoo.register(name='transformers_bert_for_mcq', output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_question_answering', + model_fn=lambda: transformers.BertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index 97d7d2fa5..6a8d7b636 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -7,6 +7,7 @@ from transformers.models.bert.modeling_bert import BertForPreTraining import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -35,16 +36,20 @@ def check_bert_for_pretraining_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) + layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_for_pretraining_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output['hidden_states'].shape) + output = bert_for_pretraining_forward( + self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index, + ) assert output['hidden_states'].shape == (2, 
3, 768) else: @@ -52,8 +57,8 @@ def check_bert_for_pretraining_forward(): output = bert_for_pretraining_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, - stage_manager=stage_manager) - print(output[0].shape) + stage_manager=stage_manager, + stage_index=stage_index) assert output[0].shape == (2, 3, 30522) # assert output[1].shape == (2, 768) diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py similarity index 73% rename from tests/test_pipeline/test_policy/test_bert_lmhead_model.py rename to tests/test_pipeline/test_policy/test_bert_lm_head_model.py index b14dadf29..cd47f7a33 100644 --- a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py +++ b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py @@ -7,12 +7,13 @@ from transformers.models.bert.modeling_bert import BertLMHeadModel import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn -def check_bert_lmhead_forward(): +def check_bert_lm_head_model_forward(): configuration = BertConfig() model = BertLMHeadModel(configuration) DP_DIM, PP_DIM = 0, 1 @@ -35,24 +36,28 @@ def check_bert_lmhead_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) - + layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_lmhead_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) + + output = bert_lm_head_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 768) else: attention_mask = torch.ones((2, 3)) - output = bert_lmhead_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) + output = bert_lm_head_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) print(output[0].shape) assert output[0].shape == (2, 3, 30522) @@ -93,7 +98,7 @@ def check_bert_lmhead_policy(): def run_dist_model(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_lmhead_forward() + check_bert_lm_head_model_forward() def run_dist_policy(rank, world_size, port): @@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() -def test_bert_lmhead_forward(): +def test_bert_lm_head_model_forward(): spawn(run_dist_model, 4) @@ -115,5 +120,5 @@ def test_bert_lmhead_policy(): if __name__ == "__main__": """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_lmhead_forward() + 
test_bert_lm_head_model_forward() test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index f5a443309..f116bc761 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -6,12 +6,14 @@ from transformers.models.bert.modeling_bert import BertModel import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn def check_bert_model_forward(): + # this test may crash for internet reasons model = BertModel.from_pretrained('bert-base-uncased') DP_DIM, PP_DIM = 0, 1 DP_SIZE, PP_SIZE = 2, 2 @@ -34,20 +36,25 @@ def check_bert_model_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) - + layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output['hidden_states'].shape) + output = bert_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) assert output['hidden_states'].shape == (2, 3, 768) else: attention_mask = torch.ones((2, 3)) output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, - stage_manager=stage_manager) + stage_manager=stage_manager, + stage_index=stage_index) print(output[0].shape) assert output[0].shape == (2, 3, 768) @@ -112,4 +119,3 @@ if __name__ == "__main__": """test the bert model forward and bert model policy""" #test_bert_model_forward() test_bert_model_policy() - # this test need config to run diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index f26c6622d..825d6df6b 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -49,7 +49,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, # prepare input data = data_gen_fn() data = {k: v.cuda() for k, v in data.items()} - # switch to train mode original_model.train() sharded_model.train() diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py new file mode 100644 index 000000000..24cda193a --- /dev/null +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -0,0 +1,164 @@ +import random +from contextlib import nullcontext +from typing import Any, Callable, Iterator, List, Optional, Tuple + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +import colossalai +from colossalai.cluster 
import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +class PipelineOptimizer(OptimizerWrapper): + + def __init__(self, optim: Optimizer, model: Module): + super().__init__(optim) + params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group['params'] if p in params] + new_param_groups.append({**group, 'params': params}) + optim.__setstate__({'param_groups': new_param_groups}) + # TODO: support amp + + +class PipelinedModel(ModelWrapper): + + def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + shardformer = ShardFormer(shard_config) + module, self.shared_params = shardformer.optimize(module) + self.shared_param_process_groups = [] + super().__init__(module) + + +def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0): + sampler = DistributedSampler( + dataset, + #rank=self.pg_mesh.coordinate(DP_AXIS), + shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + ) + + +def execute_pipeline( + data_iter: Iterator, + model: PipelinedModel, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: PipelineOptimizer, + return_loss: bool = True, + return_outputs: bool = False, + schedule: OneForwardOneBackwardSchedule = None, +) -> dict: + # return loss or outputs if needed + outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs) + return outputs + + +class data_iter(): + + def __getitem__(self, x): + return torch.randint(0, 100, (4, 128)).cuda() + + +def loss(x, y): + return (x[0].float().mean() - y[0].float().mean()) + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + PP_DIM = 0 + PP_SIZE = 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + from datasets import load_dataset + + #dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi") + pg_mesh = ProcessGroupMesh(PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + num_microbatches = 2 + org_model = model_fn().cuda() + optimizer = 
torch.optim.AdamW(org_model.parameters(), lr=1e-3) + #dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4) + schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + pipeline_stage_manager=stage_manager) + pipelined_model = PipelinedModel(org_model, shard_config, stage_manager) + pp_optimizer = PipelineOptimizer(optimizer, pipelined_model) + data_it = iter(data_iter()) + results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule) + if stage_manager.is_last_stage(): + assert results['loss'] is not None + assert results['outputs'] is None + torch.cuda.empty_cache() + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, 2) + + +if __name__ == "__main__": + test_llama() diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index 9cca5ec8b..4feaf982a 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -45,25 +45,37 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') - x = torch.randint(0, 1000, (2, 3)).cuda() - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == 'transformers_bert': - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if name == 'transformers_bert_for_mcq': + x = torch.randint(0, 1000, (2, 3, 3)).cuda() + attention_mask = torch.ones_like(x).cuda() + if stage_manager.stage == 0: + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + assert output['hidden_states'].shape == (6, 3, 128) + else: + hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda() + output = sharded_model(input_ids=x, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + assert output[0].shape == (2, 3) + else: + x = torch.randint(0, 1000, (2, 3)).cuda() + # one batch, 2 single sentences, each sentence has 3 tokens + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() if stage_manager.stage == 0: attention_mask = torch.ones_like(x).cuda() output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - # print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 128) else: attention_mask = torch.ones((2, 3)).cuda() output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask, stage_manager=stage_manager) - # print(output[0].shape) - assert output[0].shape == (2, 3, 128) + assert output[0].shape[0] == 2 torch.cuda.empty_cache()
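
A few review notes follow; none of the code below is part of the patch above. First, the p2p.py hunk gates the device-aware _object_to_tensor call on torch.__version__ >= "1.13.0", which is a lexicographic string comparison and mis-orders releases such as 1.9 relative to 1.13. A minimal sketch of a parsed check, assuming the third-party packaging library is available (the constant name is illustrative):

# Sketch only: semantic version gate instead of string comparison.
# Assumes `packaging` is installed; TORCH_GE_1_13 is an illustrative name, not part of the patch.
import torch
from packaging.version import Version

TORCH_GE_1_13 = Version(torch.__version__) >= Version("1.13.0")

print("1.9.0" >= "1.13.0")                          # True  -- string order, wrong for versions
print(Version("1.9.0") >= Version("1.13.0"))        # False -- parsed order, correct
print(Version("2.0.1+cu118") >= Version("1.13.0"))  # True  -- CUDA-tagged builds parse fine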
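Second, every Bert policy above derives its per-stage encoder slice with Policy.distribute_layers(...) and Policy.get_stage_index(...) before binding it into the replaced forward as stage_index. The standalone sketch below illustrates what that [start, end) pair means; the even split with the remainder given to the earliest stages is an assumption for illustration, not necessarily the library's exact partitioning rule:

# Illustrative helpers mirroring how a stage's layer slice is derived; the names are made up.
from typing import List, Tuple


def distribute_layers_evenly(num_layers: int, num_stages: int) -> List[int]:
    # Base share per stage, with the remainder spread over the first stages.
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage < remainder else 0) for stage in range(num_stages)]


def stage_slice(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # [start, end) range of encoder.layer owned by this pipeline stage.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


layers_per_stage = distribute_layers_evenly(num_layers=12, num_stages=4)    # BERT-base over 4 stages
assert layers_per_stage == [3, 3, 3, 3]
assert stage_slice(layers_per_stage, 2) == (6, 9)    # stage 2 runs encoder.layer[6:9]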
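Third, set_pipeline_forward registers a method_replacement whose 'forward' is a functools.partial that pre-binds stage_manager and stage_index, so the staged forward receives them without any change at the call site. A toy version of the same binding trick on a plain nn.Module (class and function names here are invented for the example):

# Toy sketch of replacing a module's forward with a partial that injects pipeline bookkeeping.
from functools import partial
from types import MethodType

import torch
import torch.nn as nn


class TinyBlock(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


def staged_forward(self: TinyBlock, x, *, stage_index):
    # A real staged forward would only run the layers in [start, end); here we just report them.
    start, end = stage_index
    return self.proj(x), (start, end)


block = TinyBlock()
# Pre-bind the stage bookkeeping, then attach the partial as the instance's bound forward.
block.forward = MethodType(partial(staged_forward, stage_index=[0, 2]), block)
out, owned = block(torch.randn(1, 4))
assert owned == (0, 2)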
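Finally, the get_shared_params overrides now return a stage-to-parameter mapping for tied weights, e.g. {0: bert.embeddings.word_embeddings.weight, num_stages - 1: cls.predictions.decoder.weight}. The mapping itself carries no synchronization logic; the sketch below only illustrates the kind of gradient synchronization such a mapping enables and is not ColossalAI's implementation:

# Sketch: after backward, sum gradients of a tied parameter across the stages that hold a copy.
# Assumes one process per listed stage and a process group spanning exactly those ranks.
import torch.distributed as dist


def sync_tied_grad(shared_param_map, my_stage, group=None):
    # shared_param_map: {stage_id: parameter}, as returned by the get_shared_params entries above.
    param = shared_param_map.get(my_stage)
    if param is not None and param.grad is not None:
        # Tied copies are the same logical weight, so their gradient contributions add up.
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)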