mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
[pipeline] Bert pipeline for shardformer and its tests (#4197)
* add pipeline forward * complete pipeline forward check * fix bert forward without pipeline * fix comments * discard useless line * add todo * clean prints * fix distribute layers
This commit is contained in:
@@ -13,6 +13,8 @@ from transformers.modeling_outputs import (
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertForMaskedLM,
|
||||
BertForNextSentencePrediction,
|
||||
BertForPreTraining,
|
||||
BertForPreTrainingOutput,
|
||||
BertLMHeadModel,
|
||||
@@ -135,7 +137,6 @@ class BertPolicy(Policy):
|
||||
],
|
||||
policy=policy,
|
||||
target_key=BertLayer)
|
||||
|
||||
# handle embedding layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[SubModuleReplacementDescription(
|
||||
@@ -144,6 +145,7 @@ class BertPolicy(Policy):
|
||||
)],
|
||||
policy=policy,
|
||||
target_key=BertEmbeddings)
|
||||
|
||||
return policy
|
||||
|
||||
def add_lm_head_policy(self, base_policy):
|
||||
@@ -177,6 +179,15 @@ class BertModelPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
module_policy = super().module_policy()
|
||||
from transformers.models.bert.modeling_bert import BertModel
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
module_policy[BertModel] = ModulePolicyDescription(
|
||||
method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)})
|
||||
return module_policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
module = self.model
|
||||
@@ -444,6 +455,13 @@ def bert_model_forward(
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
||||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
||||
token_type_ids = buffered_token_type_ids_expanded
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
else:
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
@@ -466,14 +484,6 @@ def bert_model_forward(
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
||||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
||||
token_type_ids = buffered_token_type_ids_expanded
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
@@ -778,3 +788,131 @@ def bert_lmhead_forward(self: BertLMHeadModel,
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
# intermediate stage always return dict
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def bert_for_masked_lm_forward(
|
||||
self: BertForMaskedLM,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
):
|
||||
#-> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
||||
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
||||
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def bert_for_next_sentence_prediction_forward(
|
||||
self: BertForNextSentencePrediction,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
**kwargs,
|
||||
):
|
||||
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
|
||||
(see `input_ids` docstring). Indices should be in `[0, 1]`:
|
||||
|
||||
- 0 indicates sequence B is a continuation of sequence A,
|
||||
- 1 indicates sequence B is a random sequence.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, BertForNextSentencePrediction
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
>>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
|
||||
|
||||
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
|
||||
>>> logits = outputs.logits
|
||||
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
||||
```
|
||||
"""
|
||||
|
||||
if "next_sentence_label" in kwargs:
|
||||
warnings.warn(
|
||||
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
|
||||
" `labels` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("next_sentence_label")
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = outputs[1]
|
||||
seq_relationship_scores = self.cls(pooled_output)
|
||||
|
||||
next_sentence_loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (seq_relationship_scores,) + outputs[2:]
|
||||
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
||||
|
||||
return NextSentencePredictorOutput(
|
||||
loss=next_sentence_loss,
|
||||
logits=seq_relationship_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
# intermediate stage always return dict
|
||||
return {'hidden_states': hidden_states}
|
||||
|
Reference in New Issue
Block a user