This commit is contained in:
jing 2025-04-21 15:35:19 +00:00 committed by GitHub
commit ea281cd4c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 462 additions and 1 deletions

View File

@ -0,0 +1,302 @@
from typing import List, Optional
import torch
import torch.distributed
import torch.utils.checkpoint
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_sp_output
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig
from ..layer import RingAttention, dist_cross_entropy
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
class Gemma2PipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
under pipeline setting.
"""
@staticmethod
def gemma2_model_forward(
self: Gemma2Model,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
force_sp_gather: bool = True, # Set to false only when computing cross entropy
):
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
)
use_cache = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
disable_pp = stage_manager is None
# retrieve input_ids and inputs_embeds
if disable_pp or stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
device = hidden_states.device
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
# Support SP + PP
sp_mode = shard_config.sequence_parallelism_mode
shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size
# Generating full positions ids for modes that gather sequence before attn
if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
seq_length *= sp_size
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
seq_length + past_seen_tokens
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
attn_kwargs: torch.Tensor = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attn_kwargs,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attn_kwargs,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if disable_pp or stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode): # noqa
hidden_states = gather_sp_output(hidden_states, shard_config)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if disable_pp or stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# always return dict for intermediate stage
return {"hidden_states": hidden_states}
@staticmethod
def gemma2_for_causal_lm_forward(
self: Gemma2ForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**kwargs,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
# Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
else:
# [B, max_seqlen // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = Gemma2PipelineForwards.gemma2_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
force_sp_gather=False,
)
past_key_values = None
disable_pp = stage_manager is None
if disable_pp or stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}

View File

@ -141,7 +141,9 @@ class LlamaPipelineForwards:
invert=(sp_mode != "ring_attn"),
)
else:
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
attn_kwargs: torch.Tensor = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
)
# Support SP + PP. Later stages have already received the split input.
split_input = disable_pp or stage_manager.is_first_stage()

View File

@ -227,6 +227,10 @@ _POLICY_LIST = {
"transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
file_name="command", class_name="CommandForCausalLMPolicy"
),
# gemma2
"transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM": PolicyLocation(
file_name="gemma2", class_name="Gemma2ForCausalLMPolicy"
),
}

View File

@ -0,0 +1,153 @@
from functools import partial
from typing import Dict, Union
import torch.nn as nn
from colossalai.shardformer.layer import (
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
RMSNorm,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
from ..modeling.gemma2 import Gemma2PipelineForwards
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]
class Gemma2Policy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
norm_cls = RMSNorm
if self.shard_config.enable_tensor_parallelism:
tp_size = self.shard_config.tensor_parallel_size
num_q_heads = self.model.config.num_attention_heads // tp_size
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // tp_size,
"self_attn.num_heads": num_q_heads,
}
num_kv_heads = self.model.config.num_key_value_heads // tp_size
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
policy[Gemma2DecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=Gemma2Model,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
],
policy=policy,
target_key=Gemma2DecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=norm_cls,
),
policy=policy,
target_key=Gemma2Model,
)
return policy
def postprocess(self):
return self.model
class Gemma2ForCausalLMPolicy(Gemma2Policy):
def module_policy(self):
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=VocabParallelLMHead1D,
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
),
policy=policy,
target_key=Gemma2ForCausalLM,
)
if self.shard_config.parallel_output:
method_replacement = {
"forward": partial(
Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
),
policy=policy,
target_key=Gemma2ForCausalLM,
)
return policy