[pipeline] All bert models (#4233)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be reverted and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision

* add pure pipeline test

* finish some bert models

* finish all bert models

* finish bert tests

* fix bugs

* fix bugs

* fix test pipeline

* fix data gen for qa

* update the set pipeline forward

* shared params

* fix bugs
Jianghai 2023-07-17 16:12:20 +08:00 committed by Hongxin Liu
parent a14d352088
commit e7cc62d735
13 changed files with 988 additions and 144 deletions

View File

@ -64,6 +64,9 @@ def _broadcast_object_list(object_list: List[Any],
my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
if torch.__version__ >= "1.13.0":
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list])
else:
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
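A hedged sketch of the version gate above: on older torch builds torch.__version__ is a plain string, so a raw ">=" comparison can misorder releases ("1.9.0" >= "1.13.0" is lexicographically True). The helper below is illustrative only, not part of the diff, and assumes the private c10d._object_to_tensor gained its device argument in torch 1.13.0.

# Illustrative sketch (assumption: `device` keyword exists from torch 1.13.0 onward).
import torch
from packaging import version
from torch.distributed import distributed_c10d as c10d

def object_to_tensor_compat(obj, device=None):
    # parse the version instead of comparing raw strings
    if version.parse(torch.__version__) >= version.parse("1.13.0"):
        return c10d._object_to_tensor(obj, device=device)
    return c10d._object_to_tensor(obj)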

View File

@ -205,7 +205,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:

View File

@ -42,6 +42,8 @@ _POLICY_LIST = {
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForMultipleChoice": "transformers.models.bert.modeling_bert.BertForMultipleChoice":
PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
"transformers.models.bert.modeling_bert.BertForQuestionAnswering":
PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),
# LLaMA # LLaMA
"transformers.models.llama.modeling_llama.LlamaModel": "transformers.models.llama.modeling_llama.LlamaModel":

View File

@ -1,22 +1,30 @@
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.models.bert.modeling_bert import (
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForPreTrainingOutput,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertLMHeadModel,
BertModel,
)
@ -31,9 +39,9 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe
logger = logging.get_logger(__name__)
__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy'
]
@ -172,6 +180,25 @@ class BertPolicy(Policy):
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under a pipeline-parallel setting, replace the original HuggingFace forward method
with the customized forward method and record this replacement in the policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "BertModel":
module = self.model
else:
module = self.model.bert
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
return
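A minimal standalone sketch of what the two Policy helpers used above are assumed to do: split N encoder layers over the pipeline stages and return the [start, end) slice for one stage. The remainder placement is an assumption; the real Policy.distribute_layers may distribute leftover layers differently.

# Sketch only, assuming an as-even-as-possible split with the remainder on earlier stages.
def distribute_layers(num_layers: int, num_stages: int) -> list:
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if i < rem else 0) for i in range(num_stages)]

def get_stage_index(layers_per_stage: list, stage: int) -> list:
    start = sum(layers_per_stage[:stage])
    return [start, start + layers_per_stage[stage]]

# e.g. a 12-layer BERT on 2 stages: stage 0 runs layers [0, 6), stage 1 runs [6, 12)
assert distribute_layers(12, 2) == [6, 6]
assert get_stage_index([6, 6], 1) == [6, 12]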
# BertModel
class BertModelPolicy(BertPolicy):
@ -180,13 +207,10 @@ class BertModelPolicy(BertPolicy):
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertModel
self.set_pipeline_forward(model_cls=BertModel, new_forward=bert_model_forward, policy=policy)
return policy
module_policy[BertModel] = ModulePolicyDescription(
method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)})
return module_policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
@ -214,15 +238,17 @@ class BertForPreTrainingPolicy(BertPolicy):
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
from transformers.models.bert.modeling_bert import BertForPreTraining
self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage"""
module = self.model
stage_manager = self.pipeline_stage_manager
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
held_layers = []
if stage_manager.is_first_stage():
held_layers.append(module.bert.embeddings)
@ -237,11 +263,18 @@ class BertForPreTrainingPolicy(BertPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
model = self.model
if self.pipeline_stage_manager:
if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
#tie weights
return [{
0: model.bert.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
}]
return []
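The id() comparison above is how weight tying is detected. A minimal standalone illustration in plain PyTorch (not ColossalAI code): when HuggingFace ties the LM-head decoder to the word embeddings, both attributes point at the same Parameter, so the pair is reported as shared between the first and last pipeline stages.

import torch.nn as nn

emb = nn.Embedding(30522, 768)
decoder = nn.Linear(768, 30522, bias=False)
decoder.weight = emb.weight            # tie weights, as BertForPreTraining does by default
assert id(emb.weight) == id(decoder.weight)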
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
@ -256,9 +289,11 @@ class BertLMHeadModelPolicy(BertPolicy):
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
from transformers.models.bert.modeling_bert import BertLMHeadModel
self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""
@ -267,7 +302,7 @@ class BertLMHeadModelPolicy(BertPolicy):
module = self.model
held_layers = []
stage_manager = self.pipeline_stage_manager
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.bert.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
@ -278,11 +313,18 @@ class BertLMHeadModelPolicy(BertPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
bert_model = self.model.bert
if self.pipeline_stage_manager:
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
#tie weights
return [{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
@ -297,12 +339,42 @@ class BertForMaskedLMPolicy(BertPolicy):
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
from transformers.models.bert.modeling_bert import BertForMaskedLM
self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""
get pipeline layers for current stage
"""
module = self.model
held_layers = []
stage_manager = self.pipeline_stage_manager
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.bert.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.bert.pooler)
held_layers.append(module.cls)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
bert_model = self.model.bert
if self.pipeline_stage_manager:
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
#tie weights
return [{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
@ -319,7 +391,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
@ -331,8 +403,35 @@ class BertForSequenceClassificationPolicy(BertPolicy):
)
])
}
policy.update(addon_module)
return module_policy
self.set_pipeline_forward(model_cls=BertForSequenceClassification,
new_forward=bert_for_sequence_classification_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""
get pipeline layers for current stage
"""
module = self.model
held_layers = []
stage_manager = self.pipeline_stage_manager
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.bert.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.bert.pooler)
held_layers.append(module.dropout)
held_layers.append(module.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
# no shared params for sequence classification model
return []
# BertForTokenClassification
@ -344,7 +443,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
@ -356,8 +455,35 @@ class BertForTokenClassificationPolicy(BertPolicy):
)
])
}
policy.update(addon_module)
return module_policy
self.set_pipeline_forward(model_cls=BertForTokenClassification,
new_forward=bert_for_token_classification_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""
get pipeline layers for current stage
"""
module = self.model
held_layers = []
stage_manager = self.pipeline_stage_manager
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.bert.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.bert.pooler)
held_layers.append(module.dropout)
held_layers.append(module.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
# no shared params for token classification model
return []
# BertForNextSentencePrediction
@ -366,6 +492,36 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
self.set_pipeline_forward(model_cls=BertForNextSentencePrediction,
new_forward=bert_for_next_sentence_prediction_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""
get pipeline layers for current stage
"""
module = self.model
held_layers = []
stage_manager = self.pipeline_stage_manager
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.bert.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.bert.pooler)
held_layers.append(module.cls)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
# no shared params for next sentence prediction model
return []
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
@ -376,7 +532,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
@ -388,8 +544,71 @@ class BertForMultipleChoicePolicy(BertPolicy):
)
])
}
policy.update(addon_module)
return module_policy
self.set_pipeline_forward(model_cls=BertForMultipleChoice,
new_forward=bert_for_multiple_choice_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""
get pipeline layers for current stage
"""
module = self.model
held_layers = []
stage_manager = self.pipeline_stage_manager
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.bert.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.bert.pooler)
held_layers.append(module.dropout)
held_layers.append(module.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
# no shared params for multiple choice model
return []
class BertForQuestionAnsweringPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForQuestionAnswering
policy = super().module_policy()
self.set_pipeline_forward(model_cls=BertForQuestionAnswering,
new_forward=bert_for_question_answering_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""
get pipeline layers for current stage
"""
module = self.model
held_layers = []
stage_manager = self.pipeline_stage_manager
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.bert.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.bert.pooler)
held_layers.append(module.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
# no shared params for question answering model
return []
def bert_model_forward(
@ -403,13 +622,13 @@ def bert_model_forward(
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
# labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,    # this is from the previous stage
stage_index: Optional[List[int]] = None,
):
# TODO: add explanation of the output here.
r"""
@ -528,14 +747,10 @@ def bert_model_forward(
use_cache = False
next_decoder_cache = () if use_cache else None
start_idx, end_idx = stage_index[0], stage_index[1]
num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
start_layer = stage_manager.stage * num_layers_per_stage
end_layer = (stage_manager.stage + 1) * num_layers_per_stage
# layer_outputs
layer_outputs = hidden_states if hidden_states is not None else None
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
if stage_manager.is_first_stage() and idx == 0:
encoder_attention_mask = encoder_extended_attention_mask
@ -593,8 +808,9 @@ def bert_model_forward(
return (sequence_output, pooled_output) + layer_outputs[1:]
# return dict is not supported at this moment
else:
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
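A toy sketch of the stage_index slicing introduced above: every pipeline rank iterates the same encoder.layer list, but only over its own [start_idx, end_idx) slice. The values below are illustrative.

layers = [f"encoder.layer.{i}" for i in range(12)]    # stand-in for self.encoder.layer
stage_index = [6, 12]                                 # e.g. what get_stage_index returns for stage 1 of 2
start_idx, end_idx = stage_index[0], stage_index[1]
for idx, layer in enumerate(layers[start_idx:end_idx], start=start_idx):
    print(idx, layer)                                 # prints layers 6..11 only on this stage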
@ -624,6 +840,7 @@ def bert_for_pretraining_forward(
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO: the recorded kv-value tensors are left as () or None for now; this feature may be added in the future.
@ -637,7 +854,8 @@ def bert_for_pretraining_forward(
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@ -648,7 +866,9 @@ def bert_for_pretraining_forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None,
stage_index=stage_index,
)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
@ -684,7 +904,8 @@ def bert_for_pretraining_forward(
}
def bert_lm_head_model_forward(
self: BertLMHeadModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
@ -700,7 +921,9 @@ def bert_lmhead_forward(self: BertLMHeadModel,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@ -754,7 +977,8 @@ def bert_lmhead_forward(self: BertLMHeadModel,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None,
stage_index=stage_index)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
@ -806,15 +1030,66 @@ def bert_for_masked_lm_forward(
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
#-> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
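A small self-contained illustration of the -100 convention in the masked-LM loss above: CrossEntropyLoss ignores index -100 by default, so only positions given a real vocab id as a label contribute. Shapes and values below are toy data.

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
prediction_scores = torch.randn(2, 4, vocab_size)      # (batch, seq_len, vocab)
labels = torch.tensor([[-100, 3, -100, 7],
                       [1, -100, -100, -100]])          # only 3 positions are supervised
loss = CrossEntropyLoss()(prediction_scores.view(-1, vocab_size), labels.view(-1))
print(loss)                                              # averaged over the 3 labelled tokens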
def bert_for_next_sentence_prediction_forward(
@ -831,6 +1106,7 @@ def bert_for_next_sentence_prediction_forward(
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
**kwargs,
):
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
@ -881,8 +1157,7 @@ def bert_for_next_sentence_prediction_forward(
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = bert_model_forward(self.bert,
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@ -892,7 +1167,10 @@ def bert_for_next_sentence_prediction_forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
seq_relationship_scores = self.cls(pooled_output)
@ -916,3 +1194,355 @@ def bert_for_next_sentence_prediction_forward(
hidden_states = outputs.get('hidden_states')
# intermediate stages always return a dict
return {'hidden_states': hidden_states}
def bert_for_sequence_classification_forward(
self: BertForSequenceClassification,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = bert_model_forward(self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
def bert_for_token_classification_forward(
self: BertForTokenClassification,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
def bert_for_multiple_choice_forward(
self: BertForMultipleChoice,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# in our pipeline design, input ids are copied to every stage and shouldn't be None
# the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
if stage_manager.is_last_stage():
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None else None)
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
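A standalone sketch of the choice flattening above: (batch, num_choices, seq_len) inputs are flattened so the encoder sees batch * num_choices sequences, and the per-sequence scores are folded back to (batch, num_choices) before the loss. Shapes below are toy values.

import torch

batch, num_choices, seq_len = 2, 4, 8
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
flat_input_ids = input_ids.view(-1, input_ids.size(-1))   # (8, 8) fed to the encoder
logits = torch.randn(flat_input_ids.size(0), 1)           # one score per flattened sequence
reshaped_logits = logits.view(-1, num_choices)             # (2, 4) scores per example
assert reshaped_logits.shape == (batch, num_choices)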
def bert_for_question_answering_forward(
self: BertForQuestionAnswering,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
# NOTE: the args start_positions and end_positions are only used on the last stage
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
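A standalone sketch of the qa_outputs split in the last-stage branch above: a 2-unit linear head scores every token, and the two channels become start and end logits over the sequence. Dimensions below are toy values.

import torch

batch, seq_len, hidden = 2, 16, 4
sequence_output = torch.randn(batch, seq_len, hidden)
qa_outputs = torch.nn.Linear(hidden, 2)
logits = qa_outputs(sequence_output)                   # (2, 16, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()   # (2, 16)
end_logits = end_logits.squeeze(-1).contiguous()
assert start_logits.shape == (batch, seq_len)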

View File

@ -212,11 +212,13 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
llama_model = self.model.model
if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
# tie weights
return [{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight
}]
return []

View File

@ -1 +1 @@
#from .torchrec import *

View File

@ -87,6 +87,17 @@ def data_gen_for_mcq():
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
def data_gen_for_qa():
# generating data for question answering
# no need for labels and use start and end position instead
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
data['start_positions'] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
data['end_positions'] = end_positions
return data
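A hedged sketch of how the generated QA batch is consumed: start_positions/end_positions replace labels, and the model returns a span loss plus start/end logits. The tiny config values and token ids below are made up for a quick local check; the registered model uses the shared config from this file.

import torch
import transformers

tiny_config = transformers.BertConfig(hidden_size=64, num_hidden_layers=2,
                                      num_attention_heads=2, intermediate_size=128)
model = transformers.BertForQuestionAnswering(tiny_config)
data = dict(input_ids=torch.tensor([[101, 7592, 2088, 102]]),
            attention_mask=torch.ones(1, 4, dtype=torch.long),
            start_positions=torch.tensor([0], dtype=torch.int64),
            end_positions=torch.tensor([1], dtype=torch.int64))
outputs = model(**data)
print(outputs.loss, outputs.start_logits.shape)    # scalar loss and (1, 4) start logits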
# define output transform function
output_transform_fn = lambda x: x
@ -150,3 +161,9 @@ model_zoo.register(name='transformers_bert_for_mcq',
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_question_answering',
model_fn=lambda: transformers.BertForQuestionAnswering(config),
data_gen_fn=data_gen_for_qa,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))

View File

@ -7,6 +7,7 @@ from transformers.models.bert.modeling_bert import BertForPreTraining
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
@ -35,16 +36,20 @@ def check_bert_for_pretraining_forward():
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_for_pretraining_forward(
self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index,
)
assert output['hidden_states'].shape == (2, 3, 768)
else:
@ -52,8 +57,8 @@ def check_bert_for_pretraining_forward():
output = bert_for_pretraining_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index)
assert output[0].shape == (2, 3, 30522)
# assert output[1].shape == (2, 768)

View File

@ -7,12 +7,13 @@ from transformers.models.bert.modeling_bert import BertLMHeadModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_lm_head_model_forward():
configuration = BertConfig()
model = BertLMHeadModel(configuration)
DP_DIM, PP_DIM = 0, 1
@ -35,24 +36,28 @@ def check_bert_lmhead_forward():
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_lmhead_forward(self=model,
output = bert_lm_head_model_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index)
print(output['hidden_states'].shape)
assert output['hidden_states'].shape == (2, 3, 768)
else:
attention_mask = torch.ones((2, 3))
output = bert_lm_head_model_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index)
print(output[0].shape)
assert output[0].shape == (2, 3, 30522)
@ -93,7 +98,7 @@ def check_bert_lmhead_policy():
def run_dist_model(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_lm_head_model_forward()
def run_dist_policy(rank, world_size, port):
@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_lm_head_model_forward():
spawn(run_dist_model, 4)
@ -115,5 +120,5 @@ def test_bert_lmhead_policy():
if __name__ == "__main__": if __name__ == "__main__":
"""test the bert for pretraining model forward and bert for pretraining model policy""" """test the bert for pretraining model forward and bert for pretraining model policy"""
test_bert_lmhead_forward() test_bert_lm_head_model_forward()
test_bert_lmhead_policy() test_bert_lmhead_policy()

View File

@ -6,12 +6,14 @@ from transformers.models.bert.modeling_bert import BertModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_model_forward():
# this test may fail if the pretrained weights cannot be downloaded (it loads bert-base-uncased from the Hub)
model = BertModel.from_pretrained('bert-base-uncased')
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
@ -34,20 +36,25 @@ def check_bert_model_forward():
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_model_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index)
assert output['hidden_states'].shape == (2, 3, 768)
else:
attention_mask = torch.ones((2, 3))
output = bert_model_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager,
stage_index=stage_index)
print(output[0].shape)
assert output[0].shape == (2, 3, 768)
@ -112,4 +119,3 @@ if __name__ == "__main__":
"""test the bert model forward and bert model policy""" """test the bert model forward and bert model policy"""
#test_bert_model_forward() #test_bert_model_forward()
test_bert_model_policy() test_bert_model_policy()
# this test need config to run
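
Note: the two Policy helpers used above decide which encoder layers each pipeline stage owns. Below is a minimal, self-contained sketch of that idea; split_layers and stage_index are hypothetical names written only for illustration of an even split, not the actual Policy.distribute_layers / Policy.get_stage_index implementation.

# Illustrative stand-ins for the layer-partitioning helpers (assumed even split).
def split_layers(num_layers: int, num_stages: int):
    # distribute layers as evenly as possible, giving earlier stages the remainder
    quotient, remainder = divmod(num_layers, num_stages)
    return [quotient + (1 if i < remainder else 0) for i in range(num_stages)]

def stage_index(layers_per_stage, stage):
    # half-open [start, end) range of layer indices owned by this stage
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# e.g. 12 BERT encoder layers over 2 stages -> [6, 6]; stage 1 owns layers [6, 12)
assert split_layers(12, 2) == [6, 6]
assert stage_index([6, 6], 1) == (6, 12)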

View File

@@ -49,7 +49,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}
    # switch to train mode
    original_model.train()
    sharded_model.train()

View File

@@ -0,0 +1,164 @@
import random
from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple
import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
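
# PipelineOptimizer trims each optimizer param group down to the parameters that
# remain on this pipeline stage after sharding, so every rank only steps its
# local slice of the model.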
class PipelineOptimizer(OptimizerWrapper):

    def __init__(self, optim: Optimizer, model: Module):
        super().__init__(optim)
        # keep a stable reference set; reusing the name `params` inside the loop
        # would shadow it after the first param group is processed
        local_params = set(model.parameters())
        new_param_groups = []
        for group in optim.param_groups:
            params = [p for p in group['params'] if p in local_params]
            new_param_groups.append({**group, 'params': params})
        optim.__setstate__({'param_groups': new_param_groups})
# TODO: support amp
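
# PipelinedModel runs the module through ShardFormer so that this rank only keeps
# the layers assigned to its pipeline stage; the shared parameters reported by
# ShardFormer are stored so they can be synchronized across stages later.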
class PipelinedModel(ModelWrapper):
def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager
shardformer = ShardFormer(shard_config)
module, self.shared_params = shardformer.optimize(module)
self.shared_param_process_groups = []
super().__init__(module)
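
# Dataloader helper: a DistributedSampler plus a fixed worker seed keeps the
# sampled batches deterministic across runs.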
def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0):
sampler = DistributedSampler(
dataset,
#rank=self.pg_mesh.coordinate(DP_AXIS),
shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
)
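
# Drive one full 1F1B pass over the micro-batches: the schedule returns a dict
# carrying the loss (and, if requested, the outputs) on the last stage.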
def execute_pipeline(
data_iter: Iterator,
model: PipelinedModel,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: PipelineOptimizer,
return_loss: bool = True,
return_outputs: bool = False,
schedule: OneForwardOneBackwardSchedule = None,
) -> dict:
# return loss or outputs if needed
outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs)
return outputs
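
# Toy data source and loss: __getitem__ yields a fresh random batch of token ids,
# which is all the pipeline schedule needs for this smoke test.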
class data_iter():
def __getitem__(self, x):
return torch.randint(0, 100, (4, 128)).cuda()
def loss(x, y):
return (x[0].float().mean() - y[0].float().mean())
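
# Build a 2-stage pipeline for every llama model in the model zoo, run a single
# scheduled forward/backward step, and check what the last stage gets back.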
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
PP_DIM = 0
PP_SIZE = 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
from datasets import load_dataset
#dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi")
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
num_microbatches = 2
org_model = model_fn().cuda()
optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
#dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4)
schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
pipeline_stage_manager=stage_manager)
pipelined_model = PipelinedModel(org_model, shard_config, stage_manager)
pp_optimizer = PipelineOptimizer(optimizer, pipelined_model)
data_it = iter(data_iter())
results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule)
if stage_manager.is_last_stage():
assert results['loss'] is not None
assert results['outputs'] is None
torch.cuda.empty_cache()
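
# Entry point for each spawned process: initialize the distributed backend, then run the test body.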
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 2)
if __name__ == "__main__":
test_llama()

View File

@@ -45,25 +45,37 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                        enable_tensor_parallelism, use_lazy_init)
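        # BertForMultipleChoice takes (batch, num_choices, seq_len) inputs; intermediate
        # stages therefore see the choices flattened into the batch dimension,
        # i.e. (batch * num_choices, seq_len, hidden_size).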
        if name == 'transformers_bert_for_mcq':
            x = torch.randint(0, 1000, (2, 3, 3)).cuda()
            attention_mask = torch.ones_like(x).cuda()
            if stage_manager.stage == 0:
                output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
                assert output['hidden_states'].shape == (6, 3, 128)
            else:
                hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
                output = sharded_model(input_ids=x,
                                       hidden_states=hidden_states,
                                       attention_mask=attention_mask,
                                       stage_manager=stage_manager)
                assert output[0].shape == (2, 3)
        else:
            x = torch.randint(0, 1000, (2, 3)).cuda()
            # a batch of 2 single sentences, each with 3 tokens
            hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
            if stage_manager.stage == 0:
                attention_mask = torch.ones_like(x).cuda()
                output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
                assert output['hidden_states'].shape == (2, 3, 128)
            else:
                attention_mask = torch.ones((2, 3)).cuda()
                output = sharded_model(hidden_states=hidden_states,
                                       attention_mask=attention_mask,
                                       stage_manager=stage_manager)
                assert output[0].shape[0] == 2

    torch.cuda.empty_cache()