Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 13:11:27 +00:00)

Commit 0dede489d6: Merge branch 'upgrade_transformers' into upgrade_falcon
@@ -58,7 +58,7 @@ class BertPipelineForwards:
 hidden_states: Optional[torch.FloatTensor] = None,  # this is from the previous stage
 stage_index: Optional[List[int]] = None,
 shard_config: ShardConfig = None,
-):
+) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
 # TODO(jianghai): add explaination of the output here.
 r"""
 encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1037,6 +1037,90 @@ def get_jit_fused_bert_output_forward():
 return forward


+# Fix the tgt_len size in sequence parallel attention:
+# same with the one in BertSdpaSelfAttention forward in v4.51.3 transformers except the
+# _, _, tgt_len, _ = query_layer.shape
+def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
+from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
+
+def forward(
+self: BertSdpaSelfAttention,
+hidden_states: torch.Tensor,
+attention_mask: Optional[torch.FloatTensor] = None,
+head_mask: Optional[torch.FloatTensor] = None,
+encoder_hidden_states: Optional[torch.FloatTensor] = None,
+encoder_attention_mask: Optional[torch.FloatTensor] = None,
+past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+output_attentions: Optional[bool] = False,
+) -> Tuple[torch.Tensor]:
+bsz, tgt_len, _ = hidden_states.size()
+
+query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
+# mask needs to be such that the encoder's padding tokens are not attended to.
+is_cross_attention = encoder_hidden_states is not None
+
+current_states = encoder_hidden_states if is_cross_attention else hidden_states
+attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
+if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+key_layer, value_layer = past_key_value
+else:
+key_layer = self.transpose_for_scores(self.key(current_states))
+value_layer = self.transpose_for_scores(self.value(current_states))
+if past_key_value is not None and not is_cross_attention:
+key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+if self.is_decoder:
+# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+# Further calls to cross_attention layer can then reuse all cross-attention
+# key/value_states (first "if" case)
+# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+# all previous decoder key/value_states. Further calls to uni-directional self-attention
+# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+# if encoder bi-directional self-attention `past_key_value` is always `None`
+past_key_value = (key_layer, value_layer)
+
+# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+# Reference: https://github.com/pytorch/pytorch/issues/112577
+if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+query_layer = query_layer.contiguous()
+key_layer = key_layer.contiguous()
+value_layer = value_layer.contiguous()
+
+# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+# a causal mask in case tgt_len == 1.
+is_causal = (
+True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+)
+attn_output = torch.nn.functional.scaled_dot_product_attention(
+query_layer,
+key_layer,
+value_layer,
+attn_mask=attention_mask,
+dropout_p=self.dropout_prob if self.training else 0.0,
+is_causal=is_causal,
+)
+
+attn_output = attn_output.transpose(1, 2)
+_, _, tgt_len, _ = query_layer.shape
+attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+outputs = (attn_output,)
+if self.is_decoder:
+outputs = outputs + (past_key_value,)
+return outputs
+
+return forward
+
+
 def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
 def forward(
 self,
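Aside: a minimal, standalone sketch (not part of the patch) of the `is_causal` dispatch used by the SDPA call above. With no explicit attn_mask and tgt_len > 1, scaled_dot_product_attention builds the causal mask itself; single-token decoding (tgt_len == 1) needs no mask, which is why the tgt_len check exists:

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 4, 6, 16)  # [batch, heads, tgt_len, head_dim]
    k, v = torch.randn_like(q), torch.randn_like(q)

    tgt_len = q.shape[2]
    is_causal = tgt_len > 1  # mirrors the decoder self-attention case with attention_mask=None
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal)
    print(out.shape)  # torch.Size([1, 4, 6, 16])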
@@ -6,7 +6,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import (
 BaseModelOutputWithPastAndCrossAttentions,
 CausalLMOutputWithCrossAttentions,
@@ -21,6 +21,7 @@ from transformers.models.bloom.modeling_bloom import (
 BloomForSequenceClassification,
 BloomForTokenClassification,
 BloomModel,
+dropout_add,
 )
 from transformers.utils import logging

@@ -108,7 +109,7 @@ class BloomPipelineForwards:
 def bloom_model_forward(
 self: BloomModel,
 input_ids: Optional[torch.LongTensor] = None,
-past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
 attention_mask: Optional[torch.Tensor] = None,
 head_mask: Optional[torch.LongTensor] = None,
 inputs_embeds: Optional[torch.LongTensor] = None,
@@ -116,6 +117,7 @@ class BloomPipelineForwards:
 output_attentions: Optional[bool] = None,
 output_hidden_states: Optional[bool] = None,
 return_dict: Optional[bool] = None,
+cache_position: Optional[torch.LongTensor] = None,
 stage_manager: Optional[PipelineStageManager] = None,
 hidden_states: Optional[torch.FloatTensor] = None,
 stage_index: Optional[List[int]] = None,
@@ -151,6 +153,8 @@ class BloomPipelineForwards:
 if use_cache:
 logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
 use_cache = False
+past_key_values = None

 # Prepare head mask if needed
 # 1.0 in head_mask indicate we keep the head
 # attention_probs has shape batch_size x num_heads x N x N
@@ -161,46 +165,60 @@ class BloomPipelineForwards:
 # case: First stage of training
 if stage_manager.is_first_stage():
 # check input_ids and inputs_embeds
-if input_ids is not None and inputs_embeds is not None:
+if (input_ids is None) ^ (inputs_embeds is not None):
-raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-elif input_ids is not None:
+if self.gradient_checkpointing and self.training and use_cache:
-batch_size, seq_length = input_ids.shape
-elif inputs_embeds is not None:
-batch_size, seq_length, _ = inputs_embeds.shape
-else:
-raise ValueError("You have to specify either input_ids or inputs_embeds")

-if inputs_embeds is None:
-inputs_embeds = self.word_embeddings(input_ids)

-hidden_states = self.word_embeddings_layernorm(inputs_embeds)
-# initialize in the first stage and then pass to the next stage
-else:
-input_shape = hidden_states.shape[:-1]
-batch_size, seq_length = input_shape

-# extra recording tensor should be generated in the first stage

-presents = () if use_cache else None
-all_self_attentions = () if output_attentions else None
-all_hidden_states = () if output_hidden_states else None

-if self.gradient_checkpointing and self.training:
-if use_cache:
 logger.warning_once(
 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 )
 use_cache = False

-if past_key_values is None:
+if inputs_embeds is None:
-past_key_values = tuple([None] * len(self.h))
+inputs_embeds = self.word_embeddings(input_ids)
-# Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
-seq_length_with_past = seq_length
+hidden_states = self.word_embeddings_layernorm(inputs_embeds)
-past_key_values_length = 0
-if past_key_values[0] is not None:
+batch_size, seq_length, _ = inputs_embeds.shape
-past_key_values_length = past_key_values[0][0].shape[2]  # source_len
+past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+if cache_position is None:
+cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
+# initialize in the first stage and then pass to the next stage
+else:
+input_shape = hidden_states.shape[:-1]
+batch_size, seq_length = input_shape
+past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+if cache_position is None:
+cache_position = torch.arange(past_length, past_length + seq_length, device=hidden_states.device)
+
+# extra recording tensor should be generated in the first stage
+
+all_self_attentions = () if output_attentions else None
+all_hidden_states = () if output_hidden_states else None
+
+if self.gradient_checkpointing and self.training and use_cache:
+logger.warning_once(
+"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+)
+use_cache = False
+
+# kept for BC (non `Cache` `past_key_values` inputs)
+return_legacy_cache = False
+if use_cache and not isinstance(past_key_values, Cache):
+return_legacy_cache = True
+if past_key_values is None:
+past_key_values = DynamicCache()
+else:
+past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+logger.warning_once(
+"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+)
+
+# Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
+past_length = 0
+seq_length_with_past = seq_length + past_length

-seq_length_with_past = seq_length_with_past + past_key_values_length
 if attention_mask is None:
 attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
 else:
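Aside: a hedged sketch of the legacy-cache migration this hunk performs, assuming a transformers version whose `transformers.cache_utils` exposes Cache/DynamicCache with these methods; tensor shapes here are illustrative:

    import torch
    from transformers.cache_utils import DynamicCache

    # legacy format: one (key, value) pair per layer, shaped [batch, heads, seq, head_dim]
    legacy = ((torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8)),)

    cache = DynamicCache.from_legacy_cache(legacy)  # wrap the deprecated tuple-of-tuples
    past_length = cache.get_seq_length()            # 3, used above to build cache_position

    # each attention layer appends its new key/value states and gets back the full states
    k, v = cache.update(torch.zeros(1, 4, 2, 8), torch.zeros(1, 4, 2, 8), layer_idx=0)
    print(k.shape)                                  # torch.Size([1, 4, 5, 8])

    legacy_again = cache.to_legacy_cache()          # converted back when return_legacy_cache is True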
@@ -209,13 +227,10 @@ class BloomPipelineForwards:
 alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

 # causal_mask is constructed every stage and its input is passed through different stages
-causal_mask = _prepare_4d_causal_attention_mask(
+causal_mask = self._update_causal_mask(
-attention_mask,
+attention_mask, hidden_states, cache_position, past_key_values, output_attentions
-input_shape=(batch_size, seq_length),
-inputs_embeds=hidden_states,
-past_key_values_length=past_key_values_length,
 )
-causal_mask = causal_mask.bool()
 # split the input tensor along sequence dimension
 # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
 if shard_config and shard_config.enable_sequence_parallelism:
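Aside: an illustrative sketch (not the ColossalAI split_forward_gather_backward implementation) of the sequence-parallel split mentioned above; each tensor-parallel rank keeps only its 1/TP_size slice of the sequence dimension in the forward pass, and gradients are gathered back in the backward pass:

    import torch

    tp_size, rank = 4, 1                    # hypothetical process-group size and this rank's index
    hidden_states = torch.randn(2, 16, 64)  # [batch_size, seq_len, hidden_size]

    local_chunk = torch.chunk(hidden_states, tp_size, dim=1)[rank]
    print(local_chunk.shape)                # torch.Size([2, 4, 64])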
@@ -228,9 +243,7 @@ class BloomPipelineForwards:
 )

 start_idx, end_idx = stage_index[0], stage_index[1]
-for i, (block, layer_past) in enumerate(
+for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):
-zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
-):
 if output_hidden_states:
 all_hidden_states = all_hidden_states + (hidden_states,)

@@ -240,26 +253,28 @@ class BloomPipelineForwards:
 hidden_states,
 alibi,
 causal_mask,
-layer_past,
+past_key_values,
 head_mask[i],
 use_cache,
 output_attentions,
+cache_position,
 )
 else:
 outputs = block(
 hidden_states,
-layer_past=layer_past,
+layer_past=past_key_values,
 attention_mask=causal_mask,
 head_mask=head_mask[i],
 use_cache=use_cache,
 output_attentions=output_attentions,
 alibi=alibi,
+cache_position=cache_position,
 )

 hidden_states = outputs[0]
+if use_cache:
+next_decoder_cache = outputs[1]

-if use_cache is True:
-presents = presents + (outputs[1],)
 if output_attentions:
 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

@@ -277,20 +292,23 @@ class BloomPipelineForwards:
 # Add last hidden state
 hidden_states = self.ln_f(hidden_states)

-# TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents
 if output_hidden_states:
 all_hidden_states = all_hidden_states + (hidden_states,)

+next_cache = next_decoder_cache if use_cache else None
+if return_legacy_cache:
+next_cache = next_cache.to_legacy_cache()
+
 if stage_manager.is_last_stage():
 if not return_dict:
 return tuple(
-v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
+v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
 )

 # attention_mask is not returned ; presents = past_key_values
 return BaseModelOutputWithPastAndCrossAttentions(
 last_hidden_state=hidden_states,
-past_key_values=presents,
+past_key_values=next_cache,
 hidden_states=all_hidden_states,
 attentions=all_self_attentions,
 )
@@ -718,35 +736,24 @@ def get_jit_fused_bloom_attention_forward():
 head_mask: Optional[torch.Tensor] = None,
 use_cache: bool = False,
 output_attentions: bool = False,
+cache_position: Optional[torch.LongTensor] = None,
 ):
+batch_size, q_length, _ = hidden_states.shape
 fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+# 3 x [batch_size, num_heads, seq_length, head_dim]
+query_layer, key_layer, value_layer = self._reshape(fused_qkv)

-# 3 x [batch_size, seq_length, num_heads, head_dim]
-(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

-batch_size, q_length, _, _ = query_layer.shape

-query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
-key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
-value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
 if layer_past is not None:
-past_key, past_value = layer_past
+cache_kwargs = {"cache_position": cache_position}
-# concatenate along seq_length dimension:
+key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
-# - key: [batch_size * self.num_heads, head_dim, kv_length]
-# - value: [batch_size * self.num_heads, kv_length, head_dim]
-key_layer = torch.cat((past_key, key_layer), dim=2)
-value_layer = torch.cat((past_value, value_layer), dim=1)

-_, _, kv_length = key_layer.shape
+# reshape qkv for further computations
+query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
-if use_cache is True:
+key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
-present = (key_layer, value_layer)
+value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
-else:
-present = None

 # [batch_size * num_heads, q_length, kv_length]
-# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+attention_scores = alibi.baddbmm(
-matmul_result = alibi.baddbmm(
 batch1=query_layer,
 batch2=key_layer,
 beta=self.beta,
@@ -754,15 +761,13 @@ def get_jit_fused_bloom_attention_forward():
 )

 # change view to [batch_size, num_heads, q_length, kv_length]
-attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
+attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
+if attention_mask is not None:  # no matter the length, we just slice it
+causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
+attn_weights = attn_weights + causal_mask

-# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
-input_dtype = attention_scores.dtype
+attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
-# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
-if input_dtype == torch.float16:
-attention_scores = attention_scores.to(torch.float)
-attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
-attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

 # [batch_size, num_heads, q_length, kv_length]
 attention_probs = self.attention_dropout(attention_probs)
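Aside: a small sketch of the fp32-softmax pattern introduced above, which replaces the older explicit float16 masked_fill handling; the reduction runs in float32 and the result is cast back to the working dtype:

    import torch
    import torch.nn.functional as F

    scores = torch.randn(2, 4, 5, 5, dtype=torch.float16)
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(scores.dtype)
    print(probs.dtype)  # torch.float16, but the softmax itself was computed in float32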
@@ -771,12 +776,12 @@ def get_jit_fused_bloom_attention_forward():
 attention_probs = attention_probs * head_mask

 # change view [batch_size x num_heads, q_length, kv_length]
-attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
+attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)

 # matmul: [batch_size * num_heads, q_length, head_dim]
 context_layer = torch.bmm(attention_probs_reshaped, value_layer)

-# change view [batch_size, num_heads, q_length, head_dim]
+# change view [batch_size, q_length, num_heads * head_dim]
 context_layer = self._merge_heads(context_layer)

 # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
@@ -791,9 +796,9 @@ def get_jit_fused_bloom_attention_forward():
 else:
 output_tensor = self.dense(context_layer)

-output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

-outputs = (output_tensor, present)
+outputs = (output_tensor, layer_past)
 if output_attentions:
 outputs += (attention_probs,)

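Aside: a hedged sketch of how the attention scores above are formed; Tensor.baddbmm fuses the ALiBi bias with the batched Q·K^T product, i.e. scores = beta * alibi + alpha * (Q @ K^T), with alpha typically 1/sqrt(head_dim):

    import torch

    bh, q_len, kv_len, head_dim = 2, 5, 5, 8      # bh = batch_size * num_heads
    alibi = torch.zeros(bh, q_len, kv_len)        # per-head linear position bias
    q = torch.randn(bh, q_len, head_dim)
    k_t = torch.randn(bh, head_dim, kv_len)       # keys already transposed, as in the patch

    alpha = 1.0 / head_dim**0.5
    scores = alibi.baddbmm(batch1=q, batch2=k_t, beta=1.0, alpha=alpha)
    assert torch.allclose(scores, alibi + alpha * torch.bmm(q, k_t), atol=1e-5)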
@@ -839,13 +844,99 @@ def get_jit_fused_bloom_gelu_forward():
 return forward


+# Fixed the q_length args when doing the sequence parallelism in bloom model.
+def get_bloom_sequence_parallel_attention_forward(shard_config: ShardConfig):
+from transformers.models.bloom.modeling_bloom import BloomAttention
+
+def forward(
+self: BloomAttention,
+hidden_states: torch.Tensor,
+residual: torch.Tensor,
+alibi: torch.Tensor,
+attention_mask: torch.Tensor,
+layer_past: Optional[Cache] = None,
+head_mask: Optional[torch.Tensor] = None,
+use_cache: bool = False,
+output_attentions: bool = False,
+cache_position: Optional[torch.LongTensor] = None,
+):
+batch_size, q_length, _ = hidden_states.shape
+fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+# 3 x [batch_size, num_heads, seq_length, head_dim]
+query_layer, key_layer, value_layer = self._reshape(fused_qkv)
+
+if layer_past is not None:
+cache_kwargs = {"cache_position": cache_position}
+key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
+
+# reshape qkv for further computations
+query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
+key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
+value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
+
+# [batch_size * num_heads, q_length, kv_length]
+attention_scores = alibi.baddbmm(
+batch1=query_layer,
+batch2=key_layer,
+beta=self.beta,
+alpha=self.inv_norm_factor,
+)
+if shard_config.enable_sequence_parallelism:
+_, q_length, _ = query_layer.shape
+# change view to [batch_size, num_heads, q_length, kv_length]
+attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
+if attention_mask is not None:  # no matter the length, we just slice it
+causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
+attn_weights = attn_weights + causal_mask
+
+# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
+attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
+
+# [batch_size, num_heads, q_length, kv_length]
+attention_probs = self.attention_dropout(attention_probs)
+
+if head_mask is not None:
+attention_probs = attention_probs * head_mask
+
+# change view [batch_size x num_heads, q_length, kv_length]
+attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
+
+# matmul: [batch_size * num_heads, q_length, head_dim]
+context_layer = torch.bmm(attention_probs_reshaped, value_layer)
+
+# change view [batch_size, q_length, num_heads * head_dim]
+context_layer = self._merge_heads(context_layer)
+
+# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+if self.pretraining_tp > 1 and self.slow_but_exact:
+slices = self.hidden_size / self.pretraining_tp
+output_tensor = torch.zeros_like(context_layer)
+for i in range(self.pretraining_tp):
+output_tensor = output_tensor + F.linear(
+context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+)
+else:
+output_tensor = self.dense(context_layer)
+
+output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+
+outputs = (output_tensor, layer_past)
+if output_attentions:
+outputs += (attention_probs,)
+
+return outputs
+
+return forward
+
+
 def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 from transformers import BloomModel

 def forward(
 self: BloomModel,
 input_ids: Optional[torch.LongTensor] = None,
-past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
 attention_mask: Optional[torch.Tensor] = None,
 head_mask: Optional[torch.LongTensor] = None,
 inputs_embeds: Optional[torch.LongTensor] = None,
@@ -853,6 +944,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 output_attentions: Optional[bool] = None,
 output_hidden_states: Optional[bool] = None,
 return_dict: Optional[bool] = None,
+cache_position: Optional[torch.LongTensor] = None,
 **deprecated_arguments,
 ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
 if deprecated_arguments.pop("position_ids", False) is not False:
@@ -864,7 +956,6 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 )
 if len(deprecated_arguments) > 0:
 raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -872,62 +963,60 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 use_cache = use_cache if use_cache is not None else self.config.use_cache
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-if input_ids is not None and inputs_embeds is not None:
+if (input_ids is None) ^ (inputs_embeds is not None):
-raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-elif input_ids is not None:
-batch_size, seq_length = input_ids.shape
-elif inputs_embeds is not None:
-batch_size, seq_length, _ = inputs_embeds.shape
-else:
-raise ValueError("You have to specify either input_ids or inputs_embeds")

+if self.gradient_checkpointing and self.training and use_cache:
+logger.warning_once(
+"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+)
+use_cache = False
+
+if inputs_embeds is None:
+inputs_embeds = self.word_embeddings(input_ids)
+
+# kept for BC (non `Cache` `past_key_values` inputs)
+return_legacy_cache = False
+if use_cache and not isinstance(past_key_values, Cache):
+return_legacy_cache = True
 if past_key_values is None:
-past_key_values = tuple([None] * len(self.h))
+past_key_values = DynamicCache()
+else:
+past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+logger.warning_once(
+"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+)
+
+batch_size, seq_length, _ = inputs_embeds.shape
+past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+seq_length_with_past = seq_length + past_length
+if cache_position is None:
+cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
+
 # Prepare head mask if needed
 # 1.0 in head_mask indicate we keep the head
 # attention_probs has shape batch_size x num_heads x N x N
 # head_mask has shape n_layer x batch x num_heads x N x N
 head_mask = self.get_head_mask(head_mask, self.config.n_layer)

-if inputs_embeds is None:
-inputs_embeds = self.word_embeddings(input_ids)

 hidden_states = self.word_embeddings_layernorm(inputs_embeds)

-presents = () if use_cache else None
+next_decoder_cache = None
 all_self_attentions = () if output_attentions else None
 all_hidden_states = () if output_hidden_states else None

-if self.gradient_checkpointing and self.training:
-if use_cache:
-logger.warning_once(
-"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-)
-use_cache = False

 # Compute alibi tensor: check build_alibi_tensor documentation
-seq_length_with_past = seq_length
-past_key_values_length = 0
-if past_key_values[0] is not None:
-past_key_values_length = past_key_values[0][0].shape[2]
-seq_length_with_past = seq_length_with_past + past_key_values_length
 if attention_mask is None:
 attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
 else:
 attention_mask = attention_mask.to(hidden_states.device)

 alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+causal_mask = self._update_causal_mask(
-causal_mask = _prepare_4d_causal_attention_mask(
+attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-attention_mask,
-input_shape=(batch_size, seq_length),
-inputs_embeds=hidden_states,
-past_key_values_length=past_key_values_length,
 )
-causal_mask = causal_mask.bool()
-# split the input tensor along sequence dimension
-# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
 hidden_states = split_forward_gather_backward(
 hidden_states,
 dim=1,
@@ -935,7 +1024,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 fp8_communication=shard_config.fp8_communication,
 )

-for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+for i, block in enumerate(self.h):
 if output_hidden_states:
 all_hidden_states = all_hidden_states + (hidden_states,)

@@ -945,25 +1034,27 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 hidden_states,
 alibi,
 causal_mask,
-layer_past,
+past_key_values,
 head_mask[i],
 use_cache,
 output_attentions,
+cache_position,
 )
 else:
 outputs = block(
 hidden_states,
-layer_past=layer_past,
+layer_past=past_key_values,
 attention_mask=causal_mask,
 head_mask=head_mask[i],
 use_cache=use_cache,
 output_attentions=output_attentions,
 alibi=alibi,
+cache_position=cache_position,
 )

 hidden_states = outputs[0]
-if use_cache is True:
+if use_cache:
-presents = presents + (outputs[1],)
+next_decoder_cache = outputs[1]

 if output_attentions:
 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@@ -975,18 +1066,25 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 process_group=shard_config.tensor_parallel_process_group,
 fp8_communication=shard_config.fp8_communication,
 )

 # Add last hidden state
 hidden_states = self.ln_f(hidden_states)

 if output_hidden_states:
 all_hidden_states = all_hidden_states + (hidden_states,)

+next_cache = next_decoder_cache if use_cache else None
+if return_legacy_cache:
+next_cache = next_cache.to_legacy_cache()
+
 if not return_dict:
-return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+return tuple(
+v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
+)

 return BaseModelOutputWithPastAndCrossAttentions(
 last_hidden_state=hidden_states,
-past_key_values=presents,
+past_key_values=next_cache,
 hidden_states=all_hidden_states,
 attentions=all_self_attentions,
 )
@@ -21,6 +21,7 @@ from transformers.models.falcon.modeling_falcon import (
 build_alibi_tensor,
 )
 from transformers.utils import logging
+import warnings

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
@@ -103,7 +104,7 @@ def get_tp_falcon_decoder_layer_forward():
 alibi: Optional[torch.Tensor],
 attention_mask: torch.Tensor,
 position_ids: Optional[torch.LongTensor] = None,
-layer_past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
 head_mask: Optional[torch.Tensor] = None,
 use_cache: bool = False,
 output_attentions: bool = False,
@@ -113,7 +114,10 @@ def get_tp_falcon_decoder_layer_forward():
 ] = None,  # Add cache_position and position_embeddings args for v4.51.3 transformers
 **kwargs,
 ):
+if "padding_mask" in kwargs:
+warnings.warn(
+"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+)
 residual = hidden_states

 # same as v4.51.3 transformers
@@ -170,7 +174,7 @@ def get_tp_falcon_decoder_layer_forward():
 else:
 outputs = (output,) + outputs[1:]

-return outputs  # hidden_states, past_kv, attentions
+return outputs  # hidden_states, present, attentions

 return forward

@@ -226,7 +230,7 @@ class FalconPipelineForwards:
 elif inputs_embeds is not None:
 batch_size, seq_length, _ = inputs_embeds.shape
 else:
-raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+raise ValueError("You have to specify either input_ids or inputs_embeds")
 if inputs_embeds is None:
 inputs_embeds = self.word_embeddings(input_ids)
 hidden_states = inputs_embeds
@@ -236,7 +240,7 @@ class FalconPipelineForwards:

 if self.gradient_checkpointing and self.training:
 if use_cache:
-logger.warning_once(
+logger.warning(
 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 )
 use_cache = False
@@ -318,7 +322,7 @@ class FalconPipelineForwards:

 hidden_states = outputs[0]
 if use_cache is True:
-outputs[1]
+next_decoder_cache = outputs[1]

 if output_attentions:
 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@@ -48,6 +48,7 @@ def _get_attention_mask(
 sp_mode = shard_config.sequence_parallelism_mode
 # If a 2D or 3D attention mask is provided for the cross-attention
 # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

 if self.config.add_cross_attention and encoder_hidden_states is not None:
 assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
 encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
@@ -55,7 +56,7 @@ def _get_attention_mask(
 encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
 (encoder_batch_size, 1, seq_len, encoder_sequence_length),
 dtype=hidden_states.dtype,
-dtype2=encoder_hidden_states.dtype,
+device=encoder_hidden_states.device,
 q_padding_mask=attention_mask,
 kv_padding_mask=encoder_attention_mask,
 )
@@ -77,7 +78,6 @@ def _get_attention_mask(
 if shard_config.enable_flash_attention:
 if attention_mask is not None:
 attention_mask = attention_mask.view(batch_size, -1)

 attention_mask = ColoAttention.prepare_attn_kwargs(
 (batch_size, 1, seq_len, seq_len + past_key_values_length),
 hidden_states.dtype,
@@ -835,9 +835,12 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
 attention_mask = encoder_attention_mask
 else:
 query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-query = self._split_heads(query, self.num_heads, self.head_dim)
-key = self._split_heads(key, self.num_heads, self.head_dim)
+shape_q = (*query.shape[:-1], -1, self.head_dim)
-value = self._split_heads(value, self.num_heads, self.head_dim)
+shape_kv = (*key.shape[:-1], -1, self.head_dim)
+query = query.view(shape_q).transpose(1, 2)
+key = key.view(shape_kv).transpose(1, 2)
+value = value.view(shape_kv).transpose(1, 2)

 if layer_past is not None:
 past_key, past_value = layer_past
@@ -871,7 +874,9 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
 )
 else:
 attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
-attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
+attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
 attn_output = self.c_proj(attn_output)
 attn_output = self.resid_dropout(attn_output)
 outputs = (attn_output, present, None)
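Aside: a standalone sketch of the reshape-based head split/merge that replaces _split_heads/_merge_heads above; the view/transpose and permute/reshape pair round-trip exactly:

    import torch

    bsz, seq_len, num_heads, head_dim = 2, 5, 12, 64
    hidden = torch.randn(bsz, seq_len, num_heads * head_dim)

    # split: [bsz, seq, hidden] -> [bsz, heads, seq, head_dim]
    shape_q = (*hidden.shape[:-1], -1, head_dim)
    q = hidden.view(shape_q).transpose(1, 2)

    # merge: [bsz, heads, seq, head_dim] -> [bsz, seq, hidden]
    merged = q.permute(0, 2, 1, 3).contiguous()
    merged = merged.reshape(*merged.shape[:-2], -1)
    assert torch.equal(merged, hidden)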
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -141,7 +140,9 @@ class LlamaPipelineForwards:
 invert=(sp_mode != "ring_attn"),
 )
 else:
-attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+attn_kwargs: torch.Tensor = self._update_causal_mask(
+attention_mask, hidden_states, cache_position, past_key_values
+)

 # Support SP + PP. Later stages have already received the split input.
 split_input = disable_pp or stage_manager.is_first_stage()
@@ -177,6 +178,7 @@ class LlamaPipelineForwards:
 all_self_attns = () if output_attentions else None
 next_decoder_cache = None
 start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
+position_embeddings = self.rotary_emb(hidden_states, position_ids)

 num_ckpt_layers = 0
 if self.gradient_checkpointing and self.training:
@@ -204,6 +206,7 @@ class LlamaPipelineForwards:
 output_attentions,
 use_cache,
 cache_position,
+position_embeddings,
 )
 else:
 layer_outputs = decoder_layer(
@@ -214,6 +217,7 @@ class LlamaPipelineForwards:
 output_attentions=output_attentions,
 use_cache=use_cache,
 cache_position=cache_position,
+position_embeddings=position_embeddings,
 )
 hidden_states = layer_outputs[0]

@ -486,8 +490,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||||
attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
|
attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
@ -505,27 +509,11 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||||||
)
|
)
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
# sp: modify sp_len when sequence parallel mode is ring
|
# sp: modify sp_len when sequence parallel mode is ring
|
||||||
if is_share_sp_tp(sp_mode):
|
if is_share_sp_tp(sp_mode):
|
||||||
q_len *= sp_size
|
q_len *= sp_size
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
|
||||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
|
||||||
query_slices = self.q_proj.weight.split(
|
|
||||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
|
||||||
)
|
|
||||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
|
|
||||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
query_states = torch.cat(query_states, dim=-1)
|
|
||||||
|
|
||||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
key_states = torch.cat(key_states, dim=-1)
|
|
||||||
|
|
||||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
value_states = torch.cat(value_states, dim=-1)
|
|
||||||
else:
|
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = self.q_proj(hidden_states)
|
||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = self.v_proj(hidden_states)
|
||||||
@@ -537,9 +525,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
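The reshape above now infers the head count with `-1` instead of reading `self.num_heads` from the config. A minimal sketch of why this matters under tensor/sequence parallelism (the sizes below are made up for illustration and are not part of the commit): once the q/k/v projections are sharded, the local head count is the projected width divided by `head_dim`, which no longer matches the global config value.

import torch

# Illustrative sizes only (hypothetical): 32 global heads split across 4 tensor-parallel ranks.
bsz, q_len, head_dim = 2, 16, 64
num_heads_global, tp_size = 32, 4
local_width = num_heads_global * head_dim // tp_size  # width of the sharded q_proj output on one rank

query_states = torch.randn(bsz, q_len, local_width)
# view(..., -1, head_dim) derives the local head count (8 here) from the tensor itself,
# so the reshape stays correct whether or not the projection was sharded.
query_states = query_states.view(bsz, q_len, -1, head_dim).transpose(1, 2)
assert query_states.shape == (bsz, num_heads_global // tp_size, q_len, head_dim)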
@@ -552,7 +540,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s

             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

-        cos, sin = self.rotary_emb(value_states, position_ids)
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
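The hunk above swaps the per-layer `self.rotary_emb(value_states, position_ids)` call for a precomputed `(cos, sin)` tuple. A hedged sketch of the calling convention this assumes (simplified, not the transformers implementation): the model computes the rotary embeddings once per forward pass and hands the same tuple to every decoder layer, matching the `position_embeddings = self.rotary_emb(hidden_states, position_ids)` lines added to the model forwards later in this diff.

from typing import Callable, Iterable, Tuple
import torch

def run_layers_with_shared_rope(
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    rotary_emb: Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    layers: Iterable[Callable[..., torch.Tensor]],
) -> torch.Tensor:
    # cos/sin are computed once here instead of inside every attention layer.
    position_embeddings = rotary_emb(hidden_states, position_ids)
    for layer in layers:
        hidden_states = layer(hidden_states, position_embeddings=position_embeddings)
    return hidden_states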
@@ -610,17 +598,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()

-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)

         if not output_attentions:
             attn_weights = None
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights

     return forward
@@ -57,6 +57,7 @@ class Qwen2PipelineForwards:
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
@@ -131,14 +132,6 @@ class Qwen2PipelineForwards:
         else:
             position_ids = position_ids.view(-1, seq_length).long()

-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
-            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-            if is_padding_right:
-                raise ValueError(
-                    "You are attempting to perform batched generation with padding_side='right'"
-                    " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
-                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-                )
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
@@ -152,16 +145,16 @@ class Qwen2PipelineForwards:
                 is_causal=True,
             )
         else:
-            if self._attn_implementation == "flash_attention_2":
+            if self.config._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            elif self._attn_implementation == "sdpa" and not output_attentions:
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
                 # output_attentions=True can not be supported when using SDPA, and we fall back on
                 # the manual implementation that requires a 4D causal mask in all cases.
                 attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                     attention_mask,
                     (batch_size, seq_length),
-                    inputs_embeds,
+                    hidden_states,
                     past_key_values_length,
                 )
             else:
@@ -195,6 +188,8 @@ class Qwen2PipelineForwards:
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         start_idx, end_idx = stage_index[0], stage_index[1]
         num_ckpt_layers = 0
         if self.gradient_checkpointing and self.training:
@@ -214,7 +209,7 @@ class Qwen2PipelineForwards:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            past_key_values[idx] if past_key_values is not None else None

             if idx - start_idx < num_ckpt_layers:
                 layer_outputs = self._gradient_checkpointing_func(
@@ -225,15 +220,19 @@ class Qwen2PipelineForwards:
                     past_key_values,
                     output_attentions,
                     use_cache,
+                    cache_position,
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
                 )

             hidden_states = layer_outputs[0]
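Both branches above now pass the layer arguments positionally and in the same order, with `cache_position` and `position_embeddings` appended at the end. A small sketch of why the orders have to match (toy module, not the Qwen2 layer): `torch.utils.checkpoint` forwards `*args` to the wrapped callable, so the plain call in the `else` branch must mirror the argument order the checkpointed call relies on.

import torch
import torch.utils.checkpoint as checkpoint

class ToyLayer(torch.nn.Module):
    # Stand-in for a decoder layer: arguments are accepted positionally.
    def forward(self, hidden_states, attention_mask=None, scale=1.0):
        if attention_mask is not None:
            hidden_states = hidden_states + attention_mask
        return hidden_states * scale

layer = ToyLayer()
x = torch.randn(2, 4, requires_grad=True)
mask = torch.zeros(2, 4)

out_ckpt = checkpoint.checkpoint(layer, x, mask, 2.0, use_reentrant=False)  # positional order matters
out_plain = layer(x, mask, 2.0)                                             # same order, no checkpointing
assert torch.allclose(out_ckpt, out_plain)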
@@ -491,11 +490,10 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
     def forward(
         self: Qwen2Attention,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if sp_mode is not None:
@@ -519,9 +517,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -533,9 +531,8 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         # Because the input can be padded, the absolute sequence length depends on the max position id.
-        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
@@ -563,7 +560,7 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                     attention_mask = attention_mask[:, slicing_tokens:]
                     attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         # repeat k/v heads if n_kv_heads < n_heads
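The `cache_kwargs` dict above now carries `cache_position` alongside the rotary `sin`/`cos`. A minimal sketch of the cache call this feeds into, assuming the `DynamicCache` from recent transformers releases (illustrative shapes, not part of the commit):

import torch
from transformers.cache_utils import DynamicCache

past_key_value = DynamicCache()
layer_idx = 0
# (bsz, num_kv_heads, seq_len, head_dim), made-up sizes for illustration.
key_states = torch.randn(1, 2, 4, 8)
value_states = torch.randn(1, 2, 4, 8)
cos = sin = torch.randn(1, 4, 8)
cache_position = torch.arange(4)

cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# update() appends to the cache for this layer and returns the full key/value tensors so far.
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
print(key_states.shape)  # torch.Size([1, 2, 4, 8]) on the first call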
@@ -605,11 +602,11 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            attn_output = attn_output.reshape(bsz, q_len, -1)

         attn_output = self.o_proj(attn_output)

-        return attn_output, None, past_key_value
+        return attn_output, None

     return forward

@@ -627,6 +624,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         force_sp_output_gather: bool = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
@@ -648,6 +646,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
         seq_length_with_past = seq_length
         past_key_values_length = 0

@@ -664,9 +665,6 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         else:
             position_ids = position_ids.view(-1, seq_length).long()

-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
         # embed positions
         hidden_states = inputs_embeds

@@ -700,6 +698,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)

         if sp_mode in ["ring", "split_gather"]:
             hidden_states = split_forward_gather_backward(
@@ -723,22 +722,23 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
                     past_key_values,
                     output_attentions,
                     use_cache,
+                    cache_position,
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
                 )

             hidden_states = layer_outputs[0]

-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

@@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
+    get_bert_sequence_parallel_attention_forward,
     get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
@@ -48,6 +49,7 @@ class BertPolicy(Policy):
             BertLayer,
             BertModel,
             BertOutput,
+            BertSdpaSelfAttention,
             BertSelfOutput,
         )

@@ -77,6 +79,16 @@ class BertPolicy(Policy):

         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

+        if self.shard_config.enable_sequence_parallelism:
+            # Fix the tgt_len size in bert sequence parallel attention forward.
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
+                },
+                policy=policy,
+                target_key=BertSdpaSelfAttention,
+            )
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
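The new `append_or_create_method_replacement` entry registers the sequence-parallel forward against `BertSdpaSelfAttention`. A toy illustration of what a `{"forward": ...}` replacement amounts to once the policy is applied (this is not ColossalAI's actual machinery, just the idea): the factory returns a plain function that is later bound to the target module in place of its original `forward`.

import types

class TinySelfAttention:
    def forward(self, x):
        return x

def get_patched_forward(scale):
    # Stands in for get_bert_sequence_parallel_attention_forward(shard_config): returns a new forward.
    def forward(self, x):
        return x * scale
    return forward

attn = TinySelfAttention()
attn.forward = types.MethodType(get_patched_forward(2.0), attn)  # per-instance monkey patch
print(attn.forward(3.0))  # 6.0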
@@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.bloom import (
     BloomPipelineForwards,
     build_bloom_alibi_tensor_fn,
+    get_bloom_sequence_parallel_attention_forward,
     get_bloom_sequence_parallel_forward_fn,
     get_jit_fused_bloom_attention_forward,
     get_jit_fused_bloom_gelu_forward,
@@ -61,6 +62,15 @@ class BloomPolicy(Policy):

         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

+        if self.shard_config.enable_sequence_parallelism:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_bloom_sequence_parallel_attention_forward(self.shard_config),
+                },
+                policy=policy,
+                target_key=BloomAttention,
+            )
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
@@ -246,7 +246,6 @@ class FalconPolicy(Policy):
         module = self.model.transformer
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.h))
@@ -38,14 +38,8 @@ class GPT2Policy(Policy):
     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model

-        ATTN_IMPLEMENTATION = {
-            "eager": GPT2Attention,
-        }
-
         policy = {}

-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
-
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -280,7 +274,7 @@ class GPT2Policy(Policy):
                     "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=GPT2Attention,
             )

         if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
@@ -33,22 +33,9 @@ class LlamaPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import (
-            LlamaAttention,
-            LlamaDecoderLayer,
-            LlamaFlashAttention2,
-            LlamaModel,
-            LlamaSdpaAttention,
-        )
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel

-        ATTN_IMPLEMENTATION = {
-            "eager": LlamaAttention,
-            "flash_attention_2": LlamaFlashAttention2,
-            "sdpa": LlamaSdpaAttention,
-        }
         policy = {}
-
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
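The per-implementation lookup table disappears here because newer transformers releases fold the eager/SDPA/flash-attention variants into a single `LlamaAttention` class and select the kernel internally from `config._attn_implementation`, so the policy can always target `LlamaAttention`. A hedged sketch of the selection logic the old table emulated (simplified, not the transformers or ColossalAI code):

from typing import Dict, Optional

def pick_attention_target(origin_attn_implement: str, legacy_classes: Optional[Dict[str, type]] = None) -> type:
    if legacy_classes is not None:
        # Old behaviour: one attention subclass per implementation, chosen by the policy itself.
        return legacy_classes[origin_attn_implement]
    # New behaviour: a single attention class; the kernel choice happens inside transformers.
    from transformers.models.llama.modeling_llama import LlamaAttention
    return LlamaAttention

print(pick_attention_target("sdpa").__name__)  # LlamaAttention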
@@ -82,7 +69,7 @@ class LlamaPolicy(Policy):
                 num_kv_heads //= sp_size
                 decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads

-            policy[attn_cls] = ModulePolicyDescription(
+            policy[LlamaAttention] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
@@ -91,7 +78,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=LlamaAttention,
             )

         if self.pipeline_stage_manager is None:
@@ -354,6 +341,7 @@ class LlamaPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
@@ -65,15 +65,9 @@ class Qwen2Policy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        ATTN_IMPLEMENTATION = {
-            "eager": Qwen2Attention,
-            "flash_attention_2": Qwen2FlashAttention2,
-            "sdpa": Qwen2SdpaAttention,
-        }
-
         policy = {}

-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -93,7 +87,7 @@ class Qwen2Policy(Policy):
             if getattr(self.model.config, "num_key_value_heads", False):
                 decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size

-            policy[attn_cls] = ModulePolicyDescription(
+            policy[Qwen2Attention] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )

@@ -301,12 +295,13 @@ class Qwen2Policy(Policy):
             )

         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
+            print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention)
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=Qwen2Attention,
             )
         if self.pipeline_stage_manager is None:
             # replace qwen2 model forward method
@@ -370,6 +365,7 @@ class Qwen2Policy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
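`held_layers` decides which submodules a pipeline stage keeps; appending `module.rotary_emb` before the stage split means every stage holds the rotary-embedding module, since each stage now builds `position_embeddings` locally from it (the same change appears in the Llama policy above). A rough sketch of the bookkeeping, with a made-up even stage split for illustration:

def held_layers_sketch(stage: int, num_stages: int, module) -> list:
    held = [module.rotary_emb]                      # kept on every stage (this hunk's change)
    if stage == 0:
        held.append(module.embed_tokens)            # first stage only
    held.extend(module.layers[stage::num_stages])   # hypothetical even split of decoder layers
    if stage == num_stages - 1:
        held.append(module.norm)                    # last stage only
    return held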
@@ -180,7 +180,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
             "enable_flash_attention": True,
-            "use_lazy_init": True,
+            "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
         },
@@ -238,7 +238,7 @@ def run_gpt2_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp32",
             "initial_scale": 1,
@@ -247,7 +247,7 @@ def run_gpt2_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "zero_stage": 1,