mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-08 23:05:41 +00:00
[upgrade]Upgrade transformers (#6320)
* fix for async io * test for upgrading transformers * add ci machine * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_fp16_torch.py * Update build_on_pr.yml * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fiux * fix * fix * fix * upgrade llama * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * upgrade_bert * upgrade_bloom * [upgrade] upgrade gpt2 (#6291) * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * upgrade command * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * add explanation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [upgrade]Upgrade qwen2 (#6302) * upgrade qwen2 * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * update_bloom * fix * add explantion * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade_sam * add the explanation * upgrade_t * fix * fix * fix * upgrade_gptj * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [upgrade]upgrade opt (#6307) * upgrade opt * fix * [upgrade]Upgrade mixtral (#6317) * upgrade mixtral * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade infer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * upgrade drafter * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * upgrade lazy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade mixtral --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [upgrade]Upgrade vit (#6308) * fix * fix * fix rotate embedding test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [upgrade]upgrade mistral (#6296) * upgrade mistral * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix falcon * fix * Update test_shard_deepseek.py * Update build_on_pr.yml * Update requirements.txt * fix (#6327) * fix (#6328) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update bert.py * fix (#6329) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hanks <hangxu0304@gmail.com> Co-authored-by: wangbluo <2538539015@qq.com> Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
This commit is contained in:
@@ -6,19 +6,16 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaConfig,
|
||||
LlamaDecoderLayer,
|
||||
LlamaDynamicNTKScalingRotaryEmbedding,
|
||||
LlamaForCausalLM,
|
||||
LlamaLinearScalingRotaryEmbedding,
|
||||
LlamaMLP,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
)
|
||||
|
||||
from colossalai.inference.spec import GlideInput
|
||||
@@ -156,15 +153,11 @@ def glide_llama_model_forward(
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
past_seen_tokens = 0
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
@@ -172,15 +165,17 @@ def glide_llama_model_forward(
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
|
||||
if hasattr(glide_input, "n_spec_tokens"):
|
||||
position_ids = position_ids + glide_input.n_spec_tokens
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -189,9 +184,9 @@ def glide_llama_model_forward(
|
||||
# GlideLlamaDecoderLayer
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
glide_input=glide_input,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@@ -200,9 +195,6 @@ def glide_llama_model_forward(
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
@@ -212,16 +204,11 @@ def glide_llama_model_forward(
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
@@ -267,31 +254,6 @@ class LlamaCrossAttention(nn.Module):
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
|
||||
self._init_rope()
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = LlamaRotaryEmbedding(
|
||||
self.large_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
||||
self.large_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
||||
self.large_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
@@ -299,9 +261,10 @@ class LlamaCrossAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
glide_input: GlideInput = None, # Used for glimpsing main model's KV caches
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Optional[torch.Tensor]:
|
||||
@@ -319,8 +282,7 @@ class LlamaCrossAttention(nn.Module):
|
||||
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
|
||||
|
||||
# for RoPE
|
||||
position_ids = position_ids + glide_input.n_spec_tokens
|
||||
cos, sin = self.rotary_emb(query_states, position_ids)
|
||||
cos, sin = position_embeddings
|
||||
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
|
||||
query_states = query_states.transpose(1, 2)
|
||||
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
|
||||
@@ -367,9 +329,10 @@ class GlideLlamaDecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: torch.Tensor = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
glide_input: GlideInput = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
@@ -399,10 +362,10 @@ class GlideLlamaDecoderLayer(nn.Module):
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@@ -425,9 +388,10 @@ class GlideLlamaDecoderLayer(nn.Module):
|
||||
|
||||
hidden_states = self.cross_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
position_ids=position_ids,
|
||||
glide_input=glide_input,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=True,
|
||||
)
|
||||
@@ -441,9 +405,6 @@ class GlideLlamaDecoderLayer(nn.Module):
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
@@ -478,9 +478,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
||||
attn_oproj=attn_oproj,
|
||||
process_group=process_group,
|
||||
model_shard_infer_config=model_shard_infer_config,
|
||||
num_heads=module.num_heads,
|
||||
hidden_size=module.hidden_size,
|
||||
num_key_value_heads=module.num_key_value_heads,
|
||||
num_heads=module.config.num_attention_heads,
|
||||
hidden_size=module.config.hidden_size,
|
||||
num_key_value_heads=module.config.num_key_value_heads,
|
||||
)
|
||||
|
||||
return attn_layer
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
@@ -93,9 +94,8 @@ class Drafter:
|
||||
|
||||
for _ in range(n_spec_tokens):
|
||||
# update past key values
|
||||
kwargs["past_key_values"] = past_key_values
|
||||
|
||||
outputs = self._drafter_model(input_ids, **kwargs)
|
||||
outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs)
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# NOTE Only use greedy search for speculating.
|
||||
@@ -114,6 +114,8 @@ class Drafter:
|
||||
speculated_length = len(token_ids) # For now, only support bsz 1
|
||||
logits = torch.concat(logits, dim=0)
|
||||
token_ids = torch.concat(token_ids, dim=-1)
|
||||
if isinstance(past_key_values, DynamicCache):
|
||||
past_key_values = past_key_values.to_legacy_cache()
|
||||
|
||||
out = DrafterOutput(
|
||||
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
|
||||
|
||||
Reference in New Issue
Block a user