mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-10 04:18:05 +00:00
upgrade_bloom
This commit is contained in:
parent
08787f0b6e
commit
5480b811c5
@ -6,7 +6,7 @@ import torch.distributed as dist
|
|||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
@ -108,7 +108,7 @@ class BloomPipelineForwards:
|
|||||||
def bloom_model_forward(
|
def bloom_model_forward(
|
||||||
self: BloomModel,
|
self: BloomModel,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
head_mask: Optional[torch.LongTensor] = None,
|
head_mask: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
@ -116,6 +116,7 @@ class BloomPipelineForwards:
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
@ -151,6 +152,8 @@ class BloomPipelineForwards:
|
|||||||
if use_cache:
|
if use_cache:
|
||||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
past_key_values = None
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape batch_size x num_heads x N x N
|
# attention_probs has shape batch_size x num_heads x N x N
|
||||||
@ -161,46 +164,60 @@ class BloomPipelineForwards:
|
|||||||
# case: First stage of training
|
# case: First stage of training
|
||||||
if stage_manager.is_first_stage():
|
if stage_manager.is_first_stage():
|
||||||
# check input_ids and inputs_embeds
|
# check input_ids and inputs_embeds
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
elif input_ids is not None:
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
batch_size, seq_length = input_ids.shape
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
|
||||||
|
|
||||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
|
||||||
# initialize in the first stage and then pass to the next stage
|
|
||||||
else:
|
|
||||||
input_shape = hidden_states.shape[:-1]
|
|
||||||
batch_size, seq_length = input_shape
|
|
||||||
|
|
||||||
# extra recording tensor should be generated in the first stage
|
|
||||||
|
|
||||||
presents = () if use_cache else None
|
|
||||||
all_self_attentions = () if output_attentions else None
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
if past_key_values is None:
|
if inputs_embeds is None:
|
||||||
past_key_values = tuple([None] * len(self.h))
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
# Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
|
|
||||||
seq_length_with_past = seq_length
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||||
past_key_values_length = 0
|
|
||||||
if past_key_values[0] is not None:
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
past_key_values_length = past_key_values[0][0].shape[2] # source_len
|
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
|
||||||
|
# initialize in the first stage and then pass to the next stage
|
||||||
|
else:
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(past_length, past_length + seq_length, device=hidden_states.device)
|
||||||
|
|
||||||
|
# extra recording tensor should be generated in the first stage
|
||||||
|
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
|
return_legacy_cache = True
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
else:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||||
|
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||||
|
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
|
||||||
|
past_length = 0
|
||||||
|
seq_length_with_past = seq_length + past_length
|
||||||
|
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||||
else:
|
else:
|
||||||
@ -209,13 +226,10 @@ class BloomPipelineForwards:
|
|||||||
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||||
|
|
||||||
# causal_mask is constructed every stage and its input is passed through different stages
|
# causal_mask is constructed every stage and its input is passed through different stages
|
||||||
causal_mask = _prepare_4d_causal_attention_mask(
|
causal_mask = self._update_causal_mask(
|
||||||
attention_mask,
|
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
|
||||||
input_shape=(batch_size, seq_length),
|
|
||||||
inputs_embeds=hidden_states,
|
|
||||||
past_key_values_length=past_key_values_length,
|
|
||||||
)
|
)
|
||||||
causal_mask = causal_mask.bool()
|
|
||||||
# split the input tensor along sequence dimension
|
# split the input tensor along sequence dimension
|
||||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||||
if shard_config and shard_config.enable_sequence_parallelism:
|
if shard_config and shard_config.enable_sequence_parallelism:
|
||||||
@ -228,9 +242,7 @@ class BloomPipelineForwards:
|
|||||||
)
|
)
|
||||||
|
|
||||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||||
for i, (block, layer_past) in enumerate(
|
for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):
|
||||||
zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
|
|
||||||
):
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@ -240,26 +252,28 @@ class BloomPipelineForwards:
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
alibi,
|
alibi,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
layer_past,
|
past_key_values,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=past_key_values,
|
||||||
attention_mask=causal_mask,
|
attention_mask=causal_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
alibi=alibi,
|
alibi=alibi,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = outputs[1]
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
presents = presents + (outputs[1],)
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
|
|
||||||
@ -277,20 +291,23 @@ class BloomPipelineForwards:
|
|||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
# TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if return_legacy_cache:
|
||||||
|
next_cache = next_cache.to_legacy_cache()
|
||||||
|
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage():
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
|
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
# attention_mask is not returned ; presents = past_key_values
|
# attention_mask is not returned ; presents = past_key_values
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=next_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
)
|
)
|
||||||
@ -845,7 +862,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||||||
def forward(
|
def forward(
|
||||||
self: BloomModel,
|
self: BloomModel,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
head_mask: Optional[torch.LongTensor] = None,
|
head_mask: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
@ -853,6 +870,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**deprecated_arguments,
|
**deprecated_arguments,
|
||||||
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||||
@ -864,7 +882,6 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||||||
)
|
)
|
||||||
if len(deprecated_arguments) > 0:
|
if len(deprecated_arguments) > 0:
|
||||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@ -872,62 +889,60 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
elif input_ids is not None:
|
|
||||||
batch_size, seq_length = input_ids.shape
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if past_key_values is None:
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
past_key_values = tuple([None] * len(self.h))
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
|
return_legacy_cache = True
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
else:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||||
|
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||||
|
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
seq_length_with_past = seq_length + past_length
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape batch_size x num_heads x N x N
|
# attention_probs has shape batch_size x num_heads x N x N
|
||||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
|
||||||
|
|
||||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||||
|
|
||||||
presents = () if use_cache else None
|
next_decoder_cache = None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||||
seq_length_with_past = seq_length
|
|
||||||
past_key_values_length = 0
|
|
||||||
if past_key_values[0] is not None:
|
|
||||||
past_key_values_length = past_key_values[0][0].shape[2]
|
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||||
else:
|
else:
|
||||||
attention_mask = attention_mask.to(hidden_states.device)
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
|
|
||||||
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||||
|
causal_mask = self._update_causal_mask(
|
||||||
causal_mask = _prepare_4d_causal_attention_mask(
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
attention_mask,
|
|
||||||
input_shape=(batch_size, seq_length),
|
|
||||||
inputs_embeds=hidden_states,
|
|
||||||
past_key_values_length=past_key_values_length,
|
|
||||||
)
|
)
|
||||||
causal_mask = causal_mask.bool()
|
|
||||||
# split the input tensor along sequence dimension
|
|
||||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
|
||||||
hidden_states = split_forward_gather_backward(
|
hidden_states = split_forward_gather_backward(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
dim=1,
|
dim=1,
|
||||||
@ -935,7 +950,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||||||
fp8_communication=shard_config.fp8_communication,
|
fp8_communication=shard_config.fp8_communication,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, block in enumerate(self.h):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@ -945,48 +960,49 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
alibi,
|
alibi,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
layer_past,
|
past_key_values,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=past_key_values,
|
||||||
attention_mask=causal_mask,
|
attention_mask=causal_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
alibi=alibi,
|
alibi=alibi,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
if use_cache:
|
||||||
presents = presents + (outputs[1],)
|
next_decoder_cache = outputs[1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
|
|
||||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
|
||||||
hidden_states = gather_forward_split_backward(
|
|
||||||
hidden_states,
|
|
||||||
dim=1,
|
|
||||||
process_group=shard_config.tensor_parallel_process_group,
|
|
||||||
fp8_communication=shard_config.fp8_communication,
|
|
||||||
)
|
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if return_legacy_cache:
|
||||||
|
next_cache = next_cache.to_legacy_cache()
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=next_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user