Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-18 19:58:17 +00:00

[pipeline] All bert models (#4233)

* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
  This reverts commit 8dee68a0a2. This policy should be reverted and copied to feature/bloom
* revert the bloom changes
* cancel unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
* add pure pipeline test
* finish some bert models
* finish all bert models
* finish bert tests
* fix bugs
* fix bugs
* fix test pipeline
* fix data gen for qa
* update the set pipeline forward
* shared params
* fix bugs

This commit is contained in:
parent a14d352088
commit e7cc62d735
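Note (not part of the diff): the core mechanism the PR relies on is replacing a HuggingFace model's forward with a stage-aware one whose pipeline arguments are pre-bound. A minimal, hypothetical sketch of that binding, using only the standard library — the real code registers the replacement through the policy's method-replacement machinery instead:

from functools import partial
from types import MethodType


def bind_pipeline_forward(model, new_forward, stage_manager, stage_index):
    # Freeze the pipeline arguments into the replacement forward, then bind it
    # to the instance so existing call sites (model(input_ids=...)) keep working.
    staged = partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)
    model.forward = MethodType(lambda self, *args, **kwargs: staged(self, *args, **kwargs), model)
    return model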
@@ -64,6 +64,9 @@ def _broadcast_object_list(object_list: List[Any],
    my_rank = dist.get_rank()
    # Serialize object_list elements to tensors on src rank.
    if my_rank == src:
        if torch.__version__ >= "1.13.0":
            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list])
        else:
            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
        object_sizes_tensor = torch.cat(size_list)
    else:
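Aside (not from the diff): the branch above exists because the private helper c10d._object_to_tensor grew a device argument in newer PyTorch releases. A hedged sketch of the same gate, using packaging.version instead of lexicographic string comparison; very recent PyTorch adds further parameters, so this mirrors only the two cases in the hunk:

import torch
from packaging import version
from torch.distributed import distributed_c10d as c10d


def serialize_objects(object_list, device=None):
    # Newer PyTorch (>= 1.13 according to the hunk above) accepts a device argument.
    if version.parse(torch.__version__.split('+')[0]) >= version.parse("1.13.0"):
        pairs = [c10d._object_to_tensor(obj, device=device) for obj in object_list]
    else:
        pairs = [c10d._object_to_tensor(obj) for obj in object_list]
    tensor_list, size_list = zip(*pairs)
    return list(tensor_list), torch.cat(size_list)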
@@ -205,7 +205,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            # the backward pass.
            input_obj = input_objs.pop(0)
            output_obj = output_objs.pop(0)

            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

            if last_iteration:
@@ -42,6 +42,8 @@ _POLICY_LIST = {
        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
    "transformers.models.bert.modeling_bert.BertForMultipleChoice":
        PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
    "transformers.models.bert.modeling_bert.BertForQuestionAnswering":
        PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),

    # LLaMA
    "transformers.models.llama.modeling_llama.LlamaModel":
@@ -1,22 +1,30 @@
from functools import partial
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.models.bert.modeling_bert import (
    BertForMaskedLM,
    BertForMultipleChoice,
    BertForNextSentencePrediction,
    BertForPreTraining,
    BertForPreTrainingOutput,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertLMHeadModel,
    BertModel,
)
@@ -31,9 +39,9 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe
logger = logging.get_logger(__name__)

__all__ = [
    'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
    'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMdHeadModelPolicy', 'BertForMaskedLMPolicy',
    'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
    'BertForMultipleChoicePolicy'
    'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy'
]
@@ -172,6 +180,25 @@ class BertPolicy(Policy):
    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under a pipeline-parallel setting, replace the original HuggingFace forward
        with the customized forward method and register this change in the policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "BertModel":
                module = self.model
            else:
                module = self.model.bert

            layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
            method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
            self.append_or_create_method_replacement(description=method_replacement,
                                                     policy=policy,
                                                     target_key=model_cls)

        return


# BertModel
class BertModelPolicy(BertPolicy):
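For orientation (illustrative, not from the diff): Policy.distribute_layers and Policy.get_stage_index, used in set_pipeline_forward above, split the encoder layers across pipeline stages and return the [start, end) slice owned by the current stage. A standalone sketch of plausible partitioning logic — the real static methods may break ties differently:

from typing import List


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Give every stage the same base number of layers and spread the remainder
    # over the last stages (illustrative tie-breaking only).
    base = num_layers // num_stages
    remainder = num_layers % num_stages
    layers_per_stage = [base] * num_stages
    for i in range(num_stages - remainder, num_stages):
        layers_per_stage[i] += 1
    return layers_per_stage


def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
    # Convert per-stage counts into the [start, end) layer slice for `stage`.
    start = sum(layers_per_stage[:stage])
    end = start + layers_per_stage[stage]
    return [start, end]


# Example: 12 BERT encoder layers over 4 stages -> [3, 3, 3, 3]; stage 2 owns layers [6, 9).
assert get_stage_index(distribute_layers(12, 4), 2) == [6, 9]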
@@ -180,13 +207,10 @@ class BertModelPolicy(BertPolicy):
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        policy = super().module_policy()
        from transformers.models.bert.modeling_bert import BertModel
        if self.pipeline_stage_manager:
            # set None as default
            module_policy[BertModel] = ModulePolicyDescription(
                method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)})
        return module_policy
        self.set_pipeline_forward(model_cls=BertModel, new_forward=bert_model_forward, policy=policy)
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
@ -214,15 +238,17 @@ class BertForPreTrainingPolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
module_policy = super().module_policy()
|
||||
module_policy = self.add_lm_head_policy(module_policy)
|
||||
return module_policy
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
from transformers.models.bert.modeling_bert import BertForPreTraining
|
||||
self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage"""
|
||||
module = self.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages)
|
||||
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
|
||||
held_layers = []
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.bert.embeddings)
|
||||
@ -237,11 +263,18 @@ class BertForPreTrainingPolicy(BertPolicy):
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
'''No shared params in bertmodel'''
|
||||
model = self.model
|
||||
if self.pipeline_stage_manager:
|
||||
if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
|
||||
#tie weights
|
||||
return [{
|
||||
0: model.bert.embeddings.word_embeddings.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
|
||||
}]
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
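Illustrative note (not part of the diff): the get_shared_params additions above detect weight tying by comparing parameter identity and then report the tied tensor under the first and last stage indices. A minimal sketch, assuming the BertForPreTraining-style attribute layout used in this file:

import torch.nn as nn


def tied_word_embedding_shared_params(model: nn.Module, num_stages: int):
    # If the word embedding and the LM-head decoder share one Parameter
    # (weight tying), report it as living on the first and last pipeline stages.
    emb = model.bert.embeddings.word_embeddings.weight
    dec = model.cls.predictions.decoder.weight
    if id(emb) == id(dec):
        return [{0: emb, num_stages - 1: dec}]
    return []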
@ -256,9 +289,11 @@ class BertLMHeadModelPolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
module_policy = super().module_policy()
|
||||
module_policy = self.add_lm_head_policy(module_policy)
|
||||
return module_policy
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
from transformers.models.bert.modeling_bert import BertLMHeadModel
|
||||
self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""
|
||||
@ -267,7 +302,7 @@ class BertLMHeadModelPolicy(BertPolicy):
|
||||
module = self.model
|
||||
held_layers = []
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages)
|
||||
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.bert.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
@ -278,11 +313,18 @@ class BertLMHeadModelPolicy(BertPolicy):
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
'''No shared params in bertmodel'''
|
||||
bert_model = self.model.bert
|
||||
if self.pipeline_stage_manager:
|
||||
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
|
||||
#tie weights
|
||||
return [{
|
||||
0: bert_model.embeddings.word_embeddings.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
|
||||
}]
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
@ -297,12 +339,42 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
module_policy = super().module_policy()
|
||||
module_policy = self.add_lm_head_policy(module_policy)
|
||||
return module_policy
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
from transformers.models.bert.modeling_bert import BertForMaskedLM
|
||||
self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""
|
||||
get pipeline layers for current stage
|
||||
"""
|
||||
module = self.model
|
||||
held_layers = []
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.bert.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.bert.pooler)
|
||||
held_layers.append(module.cls)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
bert_model = self.model.bert
|
||||
if self.pipeline_stage_manager:
|
||||
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
|
||||
#tie weights
|
||||
return [{
|
||||
0: bert_model.embeddings.word_embeddings.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
|
||||
}]
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
@ -319,7 +391,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForSequenceClassification
|
||||
|
||||
module_policy = super().module_policy()
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
addon_module = {
|
||||
@ -331,8 +403,35 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
policy.update(addon_module)
|
||||
|
||||
self.set_pipeline_forward(model_cls=BertForSequenceClassification,
|
||||
new_forward=bert_for_sequence_classification_forward,
|
||||
policy=policy)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""
|
||||
get pipeline layers for current stage
|
||||
"""
|
||||
module = self.model
|
||||
held_layers = []
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.bert.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.bert.pooler)
|
||||
held_layers.append(module.dropout)
|
||||
held_layers.append(module.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
# no shared params for sequence classification model
|
||||
return []
|
||||
|
||||
|
||||
# BertForTokenClassification
|
||||
@ -344,7 +443,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForTokenClassification
|
||||
|
||||
module_policy = super().module_policy()
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
addon_module = {
|
||||
@ -356,8 +455,35 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
policy.update(addon_module)
|
||||
|
||||
self.set_pipeline_forward(model_cls=BertForTokenClassification,
|
||||
new_forward=bert_for_token_classification_forward,
|
||||
policy=policy)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""
|
||||
get pipeline layers for current stage
|
||||
"""
|
||||
module = self.model
|
||||
held_layers = []
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.bert.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.bert.pooler)
|
||||
held_layers.append(module.dropout)
|
||||
held_layers.append(module.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
# no shared params for sequence classification model
|
||||
return []
|
||||
|
||||
|
||||
# BertForNextSentencePrediction
|
||||
@ -366,6 +492,36 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
|
||||
self.set_pipeline_forward(model_cls=BertForNextSentencePrediction,
|
||||
new_forward=bert_for_next_sentence_prediction_forward,
|
||||
policy=policy)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""
|
||||
get pipeline layers for current stage
|
||||
"""
|
||||
module = self.model
|
||||
held_layers = []
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.bert.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.bert.pooler)
|
||||
held_layers.append(module.cls)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
# no shared params for sequence classification model
|
||||
return []
|
||||
|
||||
|
||||
# BertForMultipleChoice
|
||||
class BertForMultipleChoicePolicy(BertPolicy):
|
||||
@ -376,7 +532,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForMultipleChoice
|
||||
|
||||
module_policy = super().module_policy()
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
addon_module = {
|
||||
@ -388,8 +544,71 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
policy.update(addon_module)
|
||||
|
||||
self.set_pipeline_forward(model_cls=BertForMultipleChoice,
|
||||
new_forward=bert_for_multiple_choice_forward,
|
||||
policy=policy)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""
|
||||
get pipeline layers for current stage
|
||||
"""
|
||||
module = self.model
|
||||
held_layers = []
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.bert.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.bert.pooler)
|
||||
held_layers.append(module.dropout)
|
||||
held_layers.append(module.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
# no shared params for sequence classification model
|
||||
return []
|
||||
|
||||
|
||||
class BertForQuestionAnsweringPolicy(BertPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForQuestionAnswering
|
||||
policy = super().module_policy()
|
||||
self.set_pipeline_forward(model_cls=BertForQuestionAnswering,
|
||||
new_forward=bert_for_question_answering_forward,
|
||||
policy=policy)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""
|
||||
get pipeline layers for current stage
|
||||
"""
|
||||
module = self.model
|
||||
held_layers = []
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.bert.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.bert.pooler)
|
||||
held_layers.append(module.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
# no shared params for sequence classification model
|
||||
return []
|
||||
|
||||
|
||||
def bert_model_forward(
|
||||
@ -403,13 +622,13 @@ def bert_model_forward(
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
# labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
# TODO: add explaination of the output here.
|
||||
r"""
|
||||
@@ -528,14 +747,10 @@ def bert_model_forward(
            use_cache = False
    next_decoder_cache = () if use_cache else None

    # calculate the num_layers
    num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
    start_layer = stage_manager.stage * num_layers_per_stage
    end_layer = (stage_manager.stage + 1) * num_layers_per_stage

    start_idx, end_idx = stage_index[0], stage_index[1]
    # layer_outputs
    layer_outputs = hidden_states if hidden_states is not None else None
    for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer):
    for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
        if stage_manager.is_first_stage() and idx == 0:
            encoder_attention_mask = encoder_extended_attention_mask
@@ -593,8 +808,9 @@ def bert_model_forward(
            return (sequence_output, pooled_output) + layer_outputs[1:]
        # return dict is not supported at this moment
        else:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=next_decoder_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
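Aside (illustrative): every pipeline forward added in this file follows the same contract — only the last stage assembles the usual HuggingFace output, while intermediate stages pass their activations on as a dict. Sketched in isolation:

def stage_output(stage_manager, hidden_states, build_task_output):
    # Last stage: build the task-specific output (loss/logits/ModelOutput).
    if stage_manager.is_last_stage():
        return build_task_output(hidden_states)
    # Intermediate stage: hand activations to the next stage.
    return {'hidden_states': hidden_states}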
@ -624,6 +840,7 @@ def bert_for_pretraining_forward(
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
@ -637,7 +854,8 @@ def bert_for_pretraining_forward(
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
|
||||
outputs = bert_model_forward(self.bert,
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -648,7 +866,9 @@ def bert_for_pretraining_forward(
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states if hidden_states is not None else None)
|
||||
hidden_states=hidden_states if hidden_states is not None else None,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
past_key_values = None
|
||||
all_hidden_states = None
|
||||
all_self_attentions = None
|
||||
@ -684,7 +904,8 @@ def bert_for_pretraining_forward(
|
||||
}
|
||||
|
||||
|
||||
def bert_lmhead_forward(self: BertLMHeadModel,
|
||||
def bert_lm_head_model_forward(
|
||||
self: BertLMHeadModel,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
@ -700,7 +921,9 @@ def bert_lmhead_forward(self: BertLMHeadModel,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None):
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
@ -754,7 +977,8 @@ def bert_lmhead_forward(self: BertLMHeadModel,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states if hidden_states is not None else None)
|
||||
hidden_states=hidden_states if hidden_states is not None else None,
|
||||
stage_index=stage_index)
|
||||
past_key_values = None
|
||||
all_hidden_states = None
|
||||
all_self_attentions = None
|
||||
@ -806,15 +1030,66 @@ def bert_for_masked_lm_forward(
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
#-> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
||||
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
||||
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
pass
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
|
||||
masked_lm_loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def bert_for_next_sentence_prediction_forward(
|
||||
@ -831,6 +1106,7 @@ def bert_for_next_sentence_prediction_forward(
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
||||
@ -881,8 +1157,7 @@ def bert_for_next_sentence_prediction_forward(
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
outputs = bert_model_forward(self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -892,7 +1167,10 @@ def bert_for_next_sentence_prediction_forward(
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = outputs[1]
|
||||
seq_relationship_scores = self.cls(pooled_output)
|
||||
@ -916,3 +1194,355 @@ def bert_for_next_sentence_prediction_forward(
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
# intermediate stage always return dict
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def bert_for_sequence_classification_forward(
|
||||
self: BertForSequenceClassification,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = bert_model_forward(self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def bert_for_token_classification_forward(
|
||||
self: BertForTokenClassification,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def bert_for_multiple_choice_forward(
|
||||
self: BertForMultipleChoice,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
||||
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
|
||||
`input_ids` above)
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# in our pipeline design,input ids are copied for every stage and shouldn't be none
|
||||
# the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
|
||||
if stage_manager.is_last_stage():
|
||||
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||
if inputs_embeds is not None else None)
|
||||
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return MultipleChoiceModelOutput(
|
||||
loss=loss,
|
||||
logits=reshaped_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def bert_for_question_answering_forward(
|
||||
self: BertForQuestionAnswering,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
start_positions: Optional[torch.Tensor] = None,
|
||||
end_positions: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
# NOTE: the arg start_position and end_position are used only for the last stage
|
||||
r"""
|
||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1).contiguous()
|
||||
end_logits = end_logits.squeeze(-1).contiguous()
|
||||
|
||||
total_loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions = start_positions.clamp(0, ignored_index)
|
||||
end_positions = end_positions.clamp(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
|
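Aside (illustrative, mirrors the last-stage branch of bert_for_question_answering_forward above): the span loss clamps out-of-range positions and averages the start/end cross-entropies:

from torch.nn import CrossEntropyLoss


def qa_span_loss(start_logits, end_logits, start_positions, end_positions):
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)
    loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
    return (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2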
@@ -212,11 +212,13 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in llama model"""
        llama_model = self.model.model
        if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
            # tie weights
            return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}]
            return [{
                0: llama_model.embed_tokens.weight,
                self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight
            }]
        return []
@@ -1 +1 @@
from .torchrec import *
#from .torchrec import *
@@ -87,6 +87,17 @@ def data_gen_for_mcq():
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)


def data_gen_for_qa():
    # generating data for question answering
    # no need for labels; use start and end positions instead
    data = data_gen()
    start_positions = torch.tensor([0], dtype=torch.int64)
    data['start_positions'] = start_positions
    end_positions = torch.tensor([1], dtype=torch.int64)
    data['end_positions'] = end_positions
    return data


# define output transform function
output_transform_fn = lambda x: x
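Quick sanity check (hypothetical, not part of the test suite): a batch shaped like the one data_gen_for_qa produces can be fed straight to a small BertForQuestionAnswering; providing start_positions/end_positions makes the model return a loss.

import torch
import transformers

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2,
                                 num_attention_heads=4, intermediate_size=256)
model = transformers.BertForQuestionAnswering(config)
data = dict(input_ids=torch.tensor([[101, 7592, 1010, 2088, 102]]),
            attention_mask=torch.ones(1, 5, dtype=torch.long),
            start_positions=torch.tensor([0], dtype=torch.int64),
            end_positions=torch.tensor([1], dtype=torch.int64))
outputs = model(**data)
assert outputs.loss is not None and outputs.start_logits.shape == (1, 5)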
@@ -150,3 +161,9 @@ model_zoo.register(name='transformers_bert_for_mcq',
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_question_answering',
                   model_fn=lambda: transformers.BertForQuestionAnswering(config),
                   data_gen_fn=data_gen_for_qa,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
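Usage note (hedged; mirrors the iteration pattern used elsewhere in these tests): the new entry can be pulled back out of the registry and exercised like the others, assuming the repo's test environment is available.

from tests.kit.model_zoo import model_zoo

# Fetch only the newly registered QA entry and run one forward pass per model.
sub_zoo = model_zoo.get_sub_registry('transformers_bert_for_question_answering')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    model = model_fn()
    outputs = output_transform_fn(model(**data_gen_fn()))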
@ -7,6 +7,7 @@ from transformers.models.bert.modeling_bert import BertForPreTraining
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
@ -35,16 +36,20 @@ def check_bert_for_pretraining_forward():
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
# print(rank)
|
||||
layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
|
||||
x = torch.randint(0, 1000, (2, 3))
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x)
|
||||
output = bert_for_pretraining_forward(self=model,
|
||||
output = bert_for_pretraining_forward(
|
||||
self=model,
|
||||
input_ids=x,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
print(output['hidden_states'].shape)
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
assert output['hidden_states'].shape == (2, 3, 768)
|
||||
|
||||
else:
|
||||
@ -52,8 +57,8 @@ def check_bert_for_pretraining_forward():
|
||||
output = bert_for_pretraining_forward(self=model,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
print(output[0].shape)
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
assert output[0].shape == (2, 3, 30522)
|
||||
# assert output[1].shape == (2, 768)
|
||||
|
||||
|
@ -7,12 +7,13 @@ from transformers.models.bert.modeling_bert import BertLMHeadModel
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_bert_lmhead_forward():
|
||||
def check_bert_lm_head_model_forward():
|
||||
configuration = BertConfig()
|
||||
model = BertLMHeadModel(configuration)
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
@ -35,24 +36,28 @@ def check_bert_lmhead_forward():
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
# print(rank)
|
||||
|
||||
layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
x = torch.randint(0, 1000, (2, 3))
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x)
|
||||
output = bert_lmhead_forward(self=model,
|
||||
|
||||
output = bert_lm_head_model_forward(self=model,
|
||||
input_ids=x,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
print(output['hidden_states'].shape)
|
||||
assert output['hidden_states'].shape == (2, 3, 768)
|
||||
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3))
|
||||
output = bert_lmhead_forward(self=model,
|
||||
output = bert_lm_head_model_forward(self=model,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
print(output[0].shape)
|
||||
assert output[0].shape == (2, 3, 30522)
|
||||
|
||||
@ -93,7 +98,7 @@ def check_bert_lmhead_policy():
|
||||
|
||||
def run_dist_model(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_bert_lmhead_forward()
|
||||
check_bert_lm_head_model_forward()
|
||||
|
||||
|
||||
def run_dist_policy(rank, world_size, port):
|
||||
@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port):
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bert_lmhead_forward():
|
||||
def test_bert_lm_head_model_forward():
|
||||
spawn(run_dist_model, 4)
|
||||
|
||||
|
||||
@ -115,5 +120,5 @@ def test_bert_lmhead_policy():
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""test the bert for pretraining model forward and bert for pretraining model policy"""
|
||||
test_bert_lmhead_forward()
|
||||
test_bert_lm_head_model_forward()
|
||||
test_bert_lmhead_policy()
|
@ -6,12 +6,14 @@ from transformers.models.bert.modeling_bert import BertModel
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_bert_model_forward():
|
||||
# this test may crash for internet reasons
|
||||
model = BertModel.from_pretrained('bert-base-uncased')
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
@ -34,20 +36,25 @@ def check_bert_model_forward():
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
# print(rank)
|
||||
|
||||
layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
x = torch.randint(0, 1000, (2, 3))
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x)
|
||||
output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
|
||||
print(output['hidden_states'].shape)
|
||||
output = bert_model_forward(self=model,
|
||||
input_ids=x,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
assert output['hidden_states'].shape == (2, 3, 768)
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3))
|
||||
output = bert_model_forward(self=model,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
print(output[0].shape)
|
||||
assert output[0].shape == (2, 3, 768)
|
||||
|
||||
@ -112,4 +119,3 @@ if __name__ == "__main__":
|
||||
"""test the bert model forward and bert model policy"""
|
||||
#test_bert_model_forward()
|
||||
test_bert_model_policy()
|
||||
# this test need config to run
|
||||
|
@ -49,7 +49,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
|
||||
# prepare input
|
||||
data = data_gen_fn()
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
|
||||
# switch to train mode
|
||||
original_model.train()
|
||||
sharded_model.train()
|
||||
|
tests/test_shardformer/test_model/test_pure_pipeline.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import random
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Iterator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
|
||||
|
||||
|
||||
class PipelineOptimizer(OptimizerWrapper):
|
||||
|
||||
def __init__(self, optim: Optimizer, model: Module):
|
||||
super().__init__(optim)
|
||||
params = set(model.parameters())
|
||||
new_param_groups = []
|
||||
for group in optim.param_groups:
|
||||
params = [p for p in group['params'] if p in params]
|
||||
new_param_groups.append({**group, 'params': params})
|
||||
optim.__setstate__({'param_groups': new_param_groups})
|
||||
# TODO: support amp
|
||||
|
||||
|
||||
class PipelinedModel(ModelWrapper):
|
||||
|
||||
def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None:
|
||||
self.stage_manager = stage_manager
|
||||
shardformer = ShardFormer(shard_config)
|
||||
module, self.shared_params = shardformer.optimize(module)
|
||||
self.shared_param_process_groups = []
|
||||
super().__init__(module)
|
||||
|
||||
|
||||
def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0):
|
||||
sampler = DistributedSampler(
|
||||
dataset,
|
||||
#rank=self.pg_mesh.coordinate(DP_AXIS),
|
||||
shuffle=shuffle)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
|
||||
def execute_pipeline(
|
||||
data_iter: Iterator,
|
||||
model: PipelinedModel,
|
||||
criterion: Callable[[Any, Any], torch.Tensor],
|
||||
optimizer: PipelineOptimizer,
|
||||
return_loss: bool = True,
|
||||
return_outputs: bool = False,
|
||||
schedule: OneForwardOneBackwardSchedule = None,
|
||||
) -> dict:
|
||||
# return loss or outputs if needed
|
||||
outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class data_iter():
|
||||
|
||||
def __getitem__(self, x):
|
||||
return torch.randint(0, 100, (4, 128)).cuda()
|
||||
|
||||
|
||||
def loss(x, y):
|
||||
return (x[0].float().mean() - y[0].float().mean())
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
PP_DIM = 0
|
||||
PP_SIZE = 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
from datasets import load_dataset
|
||||
|
||||
#dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi")
|
||||
pg_mesh = ProcessGroupMesh(PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
num_microbatches = 2
|
||||
org_model = model_fn().cuda()
|
||||
optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
|
||||
#dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4)
|
||||
schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
|
||||
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
|
||||
enable_tensor_parallelism=enable_tensor_parallelism,
|
||||
pipeline_stage_manager=stage_manager)
|
||||
pipelined_model = PipelinedModel(org_model, shard_config, stage_manager)
|
||||
pp_optimizer = PipelineOptimizer(optimizer, pipelined_model)
|
||||
data_it = iter(data_iter())
|
||||
results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule)
|
||||
if stage_manager.is_last_stage():
|
||||
assert results['loss'] is not None
|
||||
assert results['outputs'] is None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_llama_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_llama():
|
||||
spawn(check_llama, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llama()
|
@ -45,25 +45,37 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
if name == 'transformers_bert':
|
||||
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
|
||||
if name == 'transformers_bert_for_mcq':
|
||||
x = torch.randint(0, 1000, (2, 3, 3)).cuda()
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
if stage_manager.stage == 0:
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
|
||||
assert output['hidden_states'].shape == (6, 3, 128)
|
||||
else:
|
||||
hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
|
||||
output = sharded_model(input_ids=x,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
assert output[0].shape == (2, 3)
|
||||
else:
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
# one batch, 2 single sentences, each sentence has 3 tokens
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
|
||||
# print(output['hidden_states'].shape)
|
||||
assert output['hidden_states'].shape == (2, 3, 128)
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3)).cuda()
|
||||
output = sharded_model(hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
# print(output[0].shape)
|
||||
assert output[0].shape == (2, 3, 128)
|
||||
assert output[0].shape[0] == 2
|
||||
|
||||
torch.cuda.empty_cache()