diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index 0ee78a303..520eccef4 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -6,19 +6,16 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn -from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.cache_utils import DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaConfig, LlamaDecoderLayer, - LlamaDynamicNTKScalingRotaryEmbedding, LlamaForCausalLM, - LlamaLinearScalingRotaryEmbedding, LlamaMLP, LlamaModel, LlamaRMSNorm, - LlamaRotaryEmbedding, ) from colossalai.inference.spec import GlideInput @@ -156,15 +153,11 @@ def glide_llama_model_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -172,15 +165,17 @@ def glide_llama_model_forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + if hasattr(glide_input, "n_spec_tokens"): + position_ids = position_ids + glide_input.n_spec_tokens # embed positions hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -189,9 +184,9 @@ def glide_llama_model_forward( # GlideLlamaDecoderLayer layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, glide_input=glide_input, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, @@ -200,9 +195,6 @@ def glide_llama_model_forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -212,16 +204,11 @@ def glide_llama_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -267,31 +254,6 @@ class LlamaCrossAttention(nn.Module): self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False) self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.large_head_dim, - max_position_embeddings=self.max_position_embeddings, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.large_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.large_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -299,9 +261,10 @@ class LlamaCrossAttention(nn.Module): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, glide_input: GlideInput = None, # Used for glimpsing main model's KV caches attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Optional[torch.Tensor]: @@ -319,8 +282,7 @@ class LlamaCrossAttention(nn.Module): query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2) # for RoPE - position_ids = position_ids + glide_input.n_spec_tokens - cos, sin = self.rotary_emb(query_states, position_ids) + cos, sin = position_embeddings query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids) query_states = query_states.transpose(1, 2) query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim) @@ -367,9 +329,10 @@ class GlideLlamaDecoderLayer(nn.Module): def forward( self, hidden_states: torch.Tensor, + position_embeddings: torch.Tensor = None, + position_ids: Optional[torch.LongTensor] = None, glide_input: GlideInput = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -399,10 +362,10 @@ class GlideLlamaDecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -425,9 +388,10 @@ class GlideLlamaDecoderLayer(nn.Module): hidden_states = self.cross_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, + position_ids=position_ids, glide_input=glide_input, attention_mask=attention_mask, - position_ids=position_ids, output_attentions=output_attentions, use_cache=True, ) @@ -441,9 +405,6 @@ class GlideLlamaDecoderLayer(nn.Module): outputs = (hidden_states,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index c7c7473ac..6c040dd22 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -478,9 +478,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_oproj=attn_oproj, process_group=process_group, model_shard_infer_config=model_shard_infer_config, - num_heads=module.num_heads, - hidden_size=module.hidden_size, - num_key_value_heads=module.num_key_value_heads, + num_heads=module.config.num_attention_heads, + hidden_size=module.config.hidden_size, + num_key_value_heads=module.config.num_key_value_heads, ) return attn_layer diff --git a/colossalai/inference/spec/drafter.py b/colossalai/inference/spec/drafter.py index 3144b2c90..81d26be5c 100644 --- a/colossalai/inference/spec/drafter.py +++ b/colossalai/inference/spec/drafter.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn from transformers import PreTrainedTokenizer +from transformers.cache_utils import DynamicCache from colossalai.utils import get_current_device @@ -93,9 +94,8 @@ class Drafter: for _ in range(n_spec_tokens): # update past key values - kwargs["past_key_values"] = past_key_values - outputs = self._drafter_model(input_ids, **kwargs) + outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs) next_token_logits = outputs.logits[:, -1, :] # NOTE Only use greedy search for speculating. @@ -114,6 +114,8 @@ class Drafter: speculated_length = len(token_ids) # For now, only support bsz 1 logits = torch.concat(logits, dim=0) token_ids = torch.concat(token_ids, dim=-1) + if isinstance(past_key_values, DynamicCache): + past_key_values = past_key_values.to_legacy_cache() out = DrafterOutput( speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 226951598..66f4cf3bb 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -69,7 +69,7 @@ def new_from_pretrained( _ = kwargs.pop("mirror", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) - _fast_init = kwargs.pop("_fast_init", True) + kwargs.pop("_fast_init", True) torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) @@ -286,7 +286,8 @@ def new_from_pretrained( config.name_or_path = pretrained_model_name_or_path # Instantiate model. - init_contexts = [no_init_weights(_enable=_fast_init)] + # init_contexts = [no_init_weights(_enable=_fast_init)] + init_contexts = [no_init_weights()] with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index a88db87bc..2d094040a 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -13,6 +13,7 @@ from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask_for_sdpa, ) from transformers.models.mixtral.modeling_mixtral import ( + MixtralModel, MixtralSparseMoeBlock, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, @@ -215,7 +216,7 @@ class MixtralPipelineForwards: @staticmethod def mixtral_model_forward( - self, + self: MixtralModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -225,6 +226,7 @@ class MixtralPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, @@ -340,11 +342,17 @@ class MixtralPipelineForwards: ) use_cache = False + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device + ) start_idx, end_idx = stage_index[0], stage_index[1] for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): @@ -370,6 +378,9 @@ class MixtralPipelineForwards: None, output_attentions, output_router_logits, + use_cache, + cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -380,6 +391,8 @@ class MixtralPipelineForwards: output_attentions, output_router_logits, use_cache, + cache_position, + position_embeddings, ) hidden_states = layer_outputs[0] @@ -559,14 +572,18 @@ class MixtralPipelineForwards: def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.models.mixtral.modeling_mixtral import eager_attention_forward def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: @@ -614,54 +631,23 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - ) if not _flash_supports_window_size: logger.warning_once( "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" " make sure to upgrade flash-attn library." ) if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout + 0.0 if not self.training else self.attention_dropout # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -689,14 +675,27 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = self._flash_attention_forward( + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, ) # sp: all-to-all comminucation when introducing sequence parallel @@ -712,7 +711,7 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights return forward @@ -731,6 +730,7 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -788,7 +788,7 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if self._attn_implementation == "flash_attention_2": + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._attn_implementation == "sdpa" and not output_attentions: @@ -820,6 +820,16 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz ) hidden_states = inputs_embeds + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -840,6 +850,8 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz output_attentions, output_router_logits, use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) else: layer_outputs = decoder_layer( @@ -850,6 +862,8 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index fab437c01..a9584db9b 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -40,21 +40,9 @@ class MixtralPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.mixtral.modeling_mixtral import ( - MixtralAttention, - MixtralDecoderLayer, - MixtralFlashAttention2, - MixtralModel, - MixtralSdpaAttention, - ) + from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel - ATTN_IMPLEMENTATION = { - "eager": MixtralAttention, - "flash_attention_2": MixtralFlashAttention2, - "sdpa": MixtralSdpaAttention, - } policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None @@ -76,7 +64,7 @@ class MixtralPolicy(Policy): num_kv_heads //= sp_size decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads - policy[attn_cls] = ModulePolicyDescription( + policy[MixtralAttention] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) if self.shard_config.enable_sequence_parallelism: @@ -89,7 +77,7 @@ class MixtralPolicy(Policy): "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, - target_key=attn_cls, + target_key=MixtralAttention, ) self.append_or_create_method_replacement( description={ @@ -330,7 +318,7 @@ class MixtralPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers))