mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[upgrade]Upgrade transformers (#6320)
* fix for async io * test for upgrading transformers * add ci machine * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_fp16_torch.py * Update build_on_pr.yml * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fiux * fix * fix * fix * upgrade llama * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * upgrade_bert * upgrade_bloom * [upgrade] upgrade gpt2 (#6291) * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * upgrade command * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * add explanation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [upgrade]Upgrade qwen2 (#6302) * upgrade qwen2 * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * update_bloom * fix * add explantion * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade_sam * add the explanation * upgrade_t * fix * fix * fix * upgrade_gptj * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [upgrade]upgrade opt (#6307) * upgrade opt * fix * [upgrade]Upgrade mixtral (#6317) * upgrade mixtral * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade infer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * upgrade drafter * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * upgrade lazy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade mixtral --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [upgrade]Upgrade vit (#6308) * fix * fix * fix rotate embedding test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [upgrade]upgrade mistral (#6296) * upgrade mistral * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix falcon * fix * Update test_shard_deepseek.py * Update build_on_pr.yml * Update requirements.txt * fix (#6327) * fix (#6328) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update bert.py * fix (#6329) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hanks <hangxu0304@gmail.com> Co-authored-by: wangbluo <2538539015@qq.com> Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
CohereAttention,
|
||||
CohereForCausalLM,
|
||||
CohereModel,
|
||||
StaticCache,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
@@ -27,6 +27,8 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
|
||||
|
||||
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class CommandPipelineForwards:
|
||||
"""
|
||||
@@ -37,22 +39,23 @@ class CommandPipelineForwards:
|
||||
@staticmethod
|
||||
def command_model_forward(
|
||||
self: CohereModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
):
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -67,8 +70,6 @@ class CommandPipelineForwards:
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
@@ -122,6 +123,7 @@ class CommandPipelineForwards:
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
shard_config.enable_flash_attention = True
|
||||
if shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
|
||||
@@ -133,7 +135,8 @@ class CommandPipelineForwards:
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
# v4.51.3 transformers attention_mask calculation
|
||||
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
if use_cache:
|
||||
@@ -163,6 +166,8 @@ class CommandPipelineForwards:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
# v4.51.3 transformers position_embeddings calculation
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
@@ -193,6 +198,7 @@ class CommandPipelineForwards:
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -203,6 +209,7 @@ class CommandPipelineForwards:
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -224,17 +231,6 @@ class CommandPipelineForwards:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
@@ -352,27 +348,22 @@ class CommandPipelineForwards:
|
||||
|
||||
def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||
def forward(
|
||||
self,
|
||||
self: CohereAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if sp_mode is not None:
|
||||
assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
|
||||
assert (sp_size is not None) and (
|
||||
sp_group is not None
|
||||
), "Must specify sp_size and sp_group for sequence parallel"
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# sp: modify sp_len when sequence parallel mode is ring
|
||||
if sp_mode in ["split_gather", "ring"]:
|
||||
q_len *= sp_size
|
||||
@@ -388,60 +379,36 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
cos, sin = position_embeddings
|
||||
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attn_weights = None
|
||||
|
||||
shard_config.enable_flash_attention = True
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# attn_weights and attn_output calculation is modified on the v4.51.3 of transformers.models.cohere.modeling_cohere.CohereAttention.forward.
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
dropout = 0.0 if not self.training else self.attention_dropout
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
@@ -451,13 +418,11 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights
|
||||
|
||||
return forward
|
||||
|
||||
@@ -467,24 +432,23 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
@@ -516,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
shard_config.enable_flash_attention = True
|
||||
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
|
||||
@@ -527,7 +493,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
# v4.51.3 transformers attention_mask calculation
|
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
@@ -544,6 +511,9 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
# v4.51.3 transformers position_embeddings calculation
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
@@ -557,8 +527,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
@@ -568,16 +538,11 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# Cases that don't support parallelizing cross entropy computation along sequence
|
||||
@@ -594,8 +559,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
|
Reference in New Issue
Block a user