[shardformer] update transformers (#5583)

* flash_attention forward upgrade

* llama_model_forward

* remove useless comment

* update the requirements.txt

* add the transformers version requirements

* remove the LATEST VERSION try

* [shardformer] update bloom model (#5518)

* update bloom model

* remove the version restriction

* [shardformer] update_falcon (#5520)

* [shardformer] update mistral model (#5511)

* [shardformer] update gpt2 (#5502)

* [shardformer] update gptj model (#5503)

* [shardformer] update opt (#5522)

* [shardformer] update t5 model (#5524)

* [shardformer] update whisper model (#5529)

* [shardformer] update vit model (#5530)

* update vit model

* remove the output_hidden_states

* [shardformer] fix llama modeling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [zero] support multiple (partial) backward passes (#5596)

* [zero] support multiple (partial) backward passes

* [misc] update requirements

* [zero] support multiple (partial) backward passes (#5596)

* [zero] support multiple (partial) backward passes

* [misc] update requirements

* fix conflicts

* [doc] fix ColossalMoE readme (#5599)

* fix readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* merge with main

* merge with main

* llama_model_forward

* remove useless comment

* remove the LATEST VERSION try

* [shardformer] update bloom model (#5518)

* update bloom model

* remove the version restriction

* [shardformer] update mistral model (#5511)

* [shardformer] update opt (#5522)

* [shardformer] update whisper model (#5529)

* [shardformer] fix llama modeling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606)

* fix no pad token bug

* fixed some auto parallel codegen bug, but might not run on torch 2.1

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] fix pipeline grad ckpt (#5620)

* [shardformer] fix pipeline grad ckpt

* [shardformer] fix whisper (#5628)

* [test] fix llama model test

* fix the opt upgrade (#5634)

* [shardformer] fix attn replacement (#5636)

* [shardformer] update flashattention replacement (#5637)

* update transformers

update transformers

fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [test] fix llama test (#5638)

* [gemini] fix buffer cast (#5639)

* Fix shardformer upgrade (#5640)

* fix llama model

* fix the mistral

* fix the shardformer model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [shardformer]support pipeline parallelism for mistral. (#5642)

* [shardformer] fix attn replacement (#5636)

* [shardformer] update flashattention replacement (#5637)

* update transformers

update transformers

fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Support LLaMA-3 CPT and ST (#5619)

* support LLaMA-3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Run pre-commit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [exampe] update llama example (#5626)

* [plugin] support dp inside for hybriad parallel

* [example] update llama benchmark

* [example] update llama benchmark

* [example] update llama readme

* [example] update llama readme

* [example] llama3 (#5631)

* release llama3

* [release] llama3

* [release] llama3

* [release] llama3

* [release] llama3

* [test] fix llama test (#5638)

* [gemini] fix buffer cast (#5639)

* support pp for mistral

* fix

* fix

fix

fix

* fix

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
This commit is contained in:
Wang Binluo
2024-04-24 22:51:50 +08:00
committed by GitHub
parent f4c5aafe29
commit 0d0a582033
27 changed files with 1155 additions and 441 deletions

View File

@@ -1,9 +1,16 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -99,11 +106,17 @@ def get_tp_falcon_decoder_layer_forward():
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
**kwargs,
):
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
if self.config.new_decoder_architecture:
@@ -117,10 +130,12 @@ def get_tp_falcon_decoder_layer_forward():
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
attention_output = attn_outputs[0]
@@ -154,87 +169,6 @@ def get_tp_falcon_decoder_layer_forward():
return forward
def get_falcon_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.falcon.modeling_falcon import FalconAttention
def forward(
self: FalconAttention,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
if alibi is not None:
attention_mask_float = (
attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
)
batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
tgt_len = key_layer_.size()[1]
attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
context_layer = me_attention(
query_layer_,
key_layer_,
value_layer_,
attn_bias=attention_mask_float,
scale=self.inv_norm_factor,
p=self.attention_dropout.p,
)
batch_size, seq_length, _, _ = context_layer.shape
context_layer = context_layer.reshape(batch_size, seq_length, -1)
output_tensor = self.dense(context_layer)
return output_tensor, present
return forward
class FalconPipelineForwards:
"""
This class serves as a micro library for falcon pipeline forwards.
@@ -246,6 +180,7 @@ class FalconPipelineForwards:
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -274,17 +209,6 @@ class FalconPipelineForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
else:
past_key_values = self._convert_to_rw_cache(past_key_values)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# case: First stage of training
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
@@ -295,16 +219,22 @@ class FalconPipelineForwards:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
@@ -312,22 +242,80 @@ class FalconPipelineForwards:
# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
past_key_values_length = past_key_values[0][0].shape[-2]
if self.use_alibi:
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
mask = (
torch.ones(
(batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
)
if attention_mask is None
else attention_mask
)
alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if alibi is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
elif head_mask is None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
attention_mask_2d = attention_mask
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# We take care to integrate alibi bias in the attention_mask here.
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
torch.finfo(alibi.dtype).min,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if seq_length > 1:
attention_mask = AttentionMaskConverter._unmask_unattended(
attention_mask, attention_mask_2d, unmasked_value=0.0
)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(
@@ -337,31 +325,23 @@ class FalconPipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
attention_mask,
position_ids,
head_mask[i],
layer_past,
use_cache,
output_attentions,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
@@ -382,9 +362,6 @@ class FalconPipelineForwards:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if presents is not None:
presents = self._convert_cache_to_standard_format(presents, batch_size)
if stage_manager.is_last_stage():
if not return_dict:
return tuple(