Merge pull request #6299 from wangbluo/upgrade_bloom

Upgrade bloom
This commit is contained in:
Hanks 2025-05-14 10:19:44 +08:00 committed by GitHub
commit 1ace29b54d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 239 additions and 131 deletions

View File

@ -6,7 +6,7 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F from torch.nn import functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import ( from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions,
@ -21,6 +21,7 @@ from transformers.models.bloom.modeling_bloom import (
BloomForSequenceClassification, BloomForSequenceClassification,
BloomForTokenClassification, BloomForTokenClassification,
BloomModel, BloomModel,
dropout_add,
) )
from transformers.utils import logging from transformers.utils import logging
@ -108,7 +109,7 @@ class BloomPipelineForwards:
def bloom_model_forward( def bloom_model_forward(
self: BloomModel, self: BloomModel,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None,
@ -116,6 +117,7 @@ class BloomPipelineForwards:
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
@ -151,6 +153,8 @@ class BloomPipelineForwards:
if use_cache: if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False use_cache = False
past_key_values = None
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N # attention_probs has shape batch_size x num_heads x N x N
@ -161,46 +165,60 @@ class BloomPipelineForwards:
# case: First stage of training # case: First stage of training
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
# check input_ids and inputs_embeds # check input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None: if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
elif input_ids is not None: if self.gradient_checkpointing and self.training and use_cache:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
# initialize in the first stage and then pass to the next stage
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
# extra recording tensor should be generated in the first stage
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
) )
use_cache = False use_cache = False
if past_key_values is None: if inputs_embeds is None:
past_key_values = tuple([None] * len(self.h)) inputs_embeds = self.word_embeddings(input_ids)
# Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
seq_length_with_past = seq_length hidden_states = self.word_embeddings_layernorm(inputs_embeds)
past_key_values_length = 0
if past_key_values[0] is not None: batch_size, seq_length, _ = inputs_embeds.shape
past_key_values_length = past_key_values[0][0].shape[2] # source_len past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
# initialize in the first stage and then pass to the next stage
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(past_length, past_length + seq_length, device=hidden_states.device)
# extra recording tensor should be generated in the first stage
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
# Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
past_length = 0
seq_length_with_past = seq_length + past_length
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else: else:
@ -209,13 +227,10 @@ class BloomPipelineForwards:
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
# causal_mask is constructed every stage and its input is passed through different stages # causal_mask is constructed every stage and its input is passed through different stages
causal_mask = _prepare_4d_causal_attention_mask( causal_mask = self._update_causal_mask(
attention_mask, attention_mask, hidden_states, cache_position, past_key_values, output_attentions
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
) )
causal_mask = causal_mask.bool()
# split the input tensor along sequence dimension # split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config and shard_config.enable_sequence_parallelism: if shard_config and shard_config.enable_sequence_parallelism:
@ -228,9 +243,7 @@ class BloomPipelineForwards:
) )
start_idx, end_idx = stage_index[0], stage_index[1] start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate( for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):
zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
@ -240,26 +253,28 @@ class BloomPipelineForwards:
hidden_states, hidden_states,
alibi, alibi,
causal_mask, causal_mask,
layer_past, past_key_values,
head_mask[i], head_mask[i],
use_cache, use_cache,
output_attentions, output_attentions,
cache_position,
) )
else: else:
outputs = block( outputs = block(
hidden_states, hidden_states,
layer_past=layer_past, layer_past=past_key_values,
attention_mask=causal_mask, attention_mask=causal_mask,
head_mask=head_mask[i], head_mask=head_mask[i],
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
alibi=alibi, alibi=alibi,
cache_position=cache_position,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
if use_cache:
next_decoder_cache = outputs[1]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@ -277,20 +292,23 @@ class BloomPipelineForwards:
# Add last hidden state # Add last hidden state
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
# TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
if not return_dict: if not return_dict:
return tuple( return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
) )
# attention_mask is not returned ; presents = past_key_values # attention_mask is not returned ; presents = past_key_values
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=presents, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
) )
@ -718,35 +736,24 @@ def get_jit_fused_bloom_attention_forward():
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
): ):
batch_size, q_length, _ = hidden_states.shape
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, num_heads, seq_length, head_dim]
query_layer, key_layer, value_layer = self._reshape(fused_qkv)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past cache_kwargs = {"cache_position": cache_position}
# concatenate along seq_length dimension: key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=2)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, _, kv_length = key_layer.shape # reshape qkv for further computations
query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
if use_cache is True: key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
present = (key_layer, value_layer) value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
else:
present = None
# [batch_size * num_heads, q_length, kv_length] # [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 attention_scores = alibi.baddbmm(
matmul_result = alibi.baddbmm(
batch1=query_layer, batch1=query_layer,
batch2=key_layer, batch2=key_layer,
beta=self.beta, beta=self.beta,
@ -754,15 +761,13 @@ def get_jit_fused_bloom_attention_forward():
) )
# change view to [batch_size, num_heads, q_length, kv_length] # change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
attn_weights = attn_weights + causal_mask
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
input_dtype = attention_scores.dtype attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16:
attention_scores = attention_scores.to(torch.float)
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
# [batch_size, num_heads, q_length, kv_length] # [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs) attention_probs = self.attention_dropout(attention_probs)
@ -771,12 +776,12 @@ def get_jit_fused_bloom_attention_forward():
attention_probs = attention_probs * head_mask attention_probs = attention_probs * head_mask
# change view [batch_size x num_heads, q_length, kv_length] # change view [batch_size x num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
# matmul: [batch_size * num_heads, q_length, head_dim] # matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer) context_layer = torch.bmm(attention_probs_reshaped, value_layer)
# change view [batch_size, num_heads, q_length, head_dim] # change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer) context_layer = self._merge_heads(context_layer)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
@ -791,9 +796,9 @@ def get_jit_fused_bloom_attention_forward():
else: else:
output_tensor = self.dense(context_layer) output_tensor = self.dense(context_layer)
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, present) outputs = (output_tensor, layer_past)
if output_attentions: if output_attentions:
outputs += (attention_probs,) outputs += (attention_probs,)
@ -839,13 +844,99 @@ def get_jit_fused_bloom_gelu_forward():
return forward return forward
# Fixed the q_length args when doing the sequence parallelism in bloom model.
def get_bloom_sequence_parallel_attention_forward(shard_config: ShardConfig):
from transformers.models.bloom.modeling_bloom import BloomAttention
def forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Cache] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
batch_size, q_length, _ = hidden_states.shape
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, num_heads, seq_length, head_dim]
query_layer, key_layer, value_layer = self._reshape(fused_qkv)
if layer_past is not None:
cache_kwargs = {"cache_position": cache_position}
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
# reshape qkv for further computations
query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
# [batch_size * num_heads, q_length, kv_length]
attention_scores = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=self.beta,
alpha=self.inv_norm_factor,
)
if shard_config.enable_sequence_parallelism:
_, q_length, _ = query_layer.shape
# change view to [batch_size, num_heads, q_length, kv_length]
attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
attn_weights = attn_weights + causal_mask
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size x num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, layer_past)
if output_attentions:
outputs += (attention_probs,)
return outputs
return forward
def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
from transformers import BloomModel from transformers import BloomModel
def forward( def forward(
self: BloomModel, self: BloomModel,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None,
@ -853,6 +944,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**deprecated_arguments, **deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False: if deprecated_arguments.pop("position_ids", False) is not False:
@ -864,7 +956,6 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
) )
if len(deprecated_arguments) > 0: if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -872,62 +963,60 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None: if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None: if self.gradient_checkpointing and self.training and use_cache:
past_key_values = tuple([None] * len(self.h)) logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
batch_size, seq_length, _ = inputs_embeds.shape
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
seq_length_with_past = seq_length + past_length
if cache_position is None:
cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N # attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer) head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds) hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None next_decoder_cache = None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation # Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else: else:
attention_mask = attention_mask.to(hidden_states.device) attention_mask = attention_mask.to(hidden_states.device)
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
causal_mask = self._update_causal_mask(
causal_mask = _prepare_4d_causal_attention_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
) )
causal_mask = causal_mask.bool()
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states,
dim=1, dim=1,
@ -935,7 +1024,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
fp8_communication=shard_config.fp8_communication, fp8_communication=shard_config.fp8_communication,
) )
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): for i, block in enumerate(self.h):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
@ -945,25 +1034,27 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states, hidden_states,
alibi, alibi,
causal_mask, causal_mask,
layer_past, past_key_values,
head_mask[i], head_mask[i],
use_cache, use_cache,
output_attentions, output_attentions,
cache_position,
) )
else: else:
outputs = block( outputs = block(
hidden_states, hidden_states,
layer_past=layer_past, layer_past=past_key_values,
attention_mask=causal_mask, attention_mask=causal_mask,
head_mask=head_mask[i], head_mask=head_mask[i],
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
alibi=alibi, alibi=alibi,
cache_position=cache_position,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
if use_cache is True: if use_cache:
presents = presents + (outputs[1],) next_decoder_cache = outputs[1]
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@ -975,18 +1066,25 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
process_group=shard_config.tensor_parallel_process_group, process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication, fp8_communication=shard_config.fp8_communication,
) )
# Add last hidden state # Add last hidden state
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=presents, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
) )

View File

@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.bloom import ( from ..modeling.bloom import (
BloomPipelineForwards, BloomPipelineForwards,
build_bloom_alibi_tensor_fn, build_bloom_alibi_tensor_fn,
get_bloom_sequence_parallel_attention_forward,
get_bloom_sequence_parallel_forward_fn, get_bloom_sequence_parallel_forward_fn,
get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward, get_jit_fused_bloom_gelu_forward,
@ -61,6 +62,15 @@ class BloomPolicy(Policy):
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_bloom_sequence_parallel_attention_forward(self.shard_config),
},
policy=policy,
target_key=BloomAttention,
)
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
assert ( assert (
self.model.config.n_head % self.shard_config.tensor_parallel_size == 0 self.model.config.n_head % self.shard_config.tensor_parallel_size == 0