mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-04 05:09:46 +00:00
Merge 753db97eb3
into 46ed5d856b
This commit is contained in:
commit
ea281cd4c9
302
colossalai/shardformer/modeling/gemma2.py
Normal file
302
colossalai/shardformer/modeling/gemma2.py
Normal file
@ -0,0 +1,302 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.utils.checkpoint
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer._operation import gather_sp_output
|
||||
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import RingAttention, dist_cross_entropy
|
||||
|
||||
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
|
||||
|
||||
|
||||
class Gemma2PipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Llama models
|
||||
under pipeline setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def gemma2_model_forward(
|
||||
self: Gemma2Model,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_gather: bool = True, # Set to false only when computing cross entropy
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
disable_pp = stage_manager is None
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
device = hidden_states.device
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
# Support SP + PP
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
shard_config.sequence_parallel_process_group
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
# Generating full positions ids for modes that gather sequence before attn
|
||||
if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
|
||||
seq_length *= sp_size
|
||||
|
||||
past_seen_tokens = 0
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
|
||||
|
||||
seq_length + past_seen_tokens
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
attn_kwargs: torch.Tensor = self._update_causal_mask(
|
||||
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
|
||||
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_stages=stage_manager.num_stages,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attn_kwargs,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attn_kwargs,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode): # noqa
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
# always return dict for intermediate stage
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def gemma2_for_causal_lm_forward(
|
||||
self: Gemma2ForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
|
||||
# Split labels in a zigzag fashion too
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if attention_mask.bool().all():
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
||||
else:
|
||||
# [B, max_seqlen // sp_size]
|
||||
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = Gemma2PipelineForwards.gemma2_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_gather=False,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
disable_pp = stage_manager is None
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
@ -141,7 +141,9 @@ class LlamaPipelineForwards:
|
||||
invert=(sp_mode != "ring_attn"),
|
||||
)
|
||||
else:
|
||||
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
attn_kwargs: torch.Tensor = self._update_causal_mask(
|
||||
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
# Support SP + PP. Later stages have already received the split input.
|
||||
split_input = disable_pp or stage_manager.is_first_stage()
|
||||
|
@ -227,6 +227,10 @@ _POLICY_LIST = {
|
||||
"transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
|
||||
file_name="command", class_name="CommandForCausalLMPolicy"
|
||||
),
|
||||
# gemma2
|
||||
"transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM": PolicyLocation(
|
||||
file_name="gemma2", class_name="Gemma2ForCausalLMPolicy"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
153
colossalai/shardformer/policies/gemma2.py
Normal file
153
colossalai/shardformer/policies/gemma2.py
Normal file
@ -0,0 +1,153 @@
|
||||
from functools import partial
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.shardformer.layer import (
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
RMSNorm,
|
||||
VocabParallelEmbedding1D,
|
||||
VocabParallelLMHead1D,
|
||||
)
|
||||
|
||||
from ..modeling.gemma2 import Gemma2PipelineForwards
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]
|
||||
|
||||
|
||||
class Gemma2Policy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model
|
||||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
norm_cls = RMSNorm
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
tp_size = self.shard_config.tensor_parallel_size
|
||||
num_q_heads = self.model.config.num_attention_heads // tp_size
|
||||
decoder_attribute_replacement = {
|
||||
"self_attn.hidden_size": self.model.config.hidden_size // tp_size,
|
||||
"self_attn.num_heads": num_q_heads,
|
||||
}
|
||||
num_kv_heads = self.model.config.num_key_value_heads // tp_size
|
||||
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
|
||||
policy[Gemma2DecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
|
||||
SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
|
||||
SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=Gemma2Model,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
|
||||
SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
|
||||
SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
|
||||
SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Gemma2DecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=Gemma2Model,
|
||||
)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
class Gemma2ForCausalLMPolicy(Gemma2Policy):
|
||||
def module_policy(self):
|
||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=not self.shard_config.parallel_output,
|
||||
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=Gemma2ForCausalLM,
|
||||
)
|
||||
if self.shard_config.parallel_output:
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
|
||||
)
|
||||
}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
|
||||
)
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=PaddingLMHead,
|
||||
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=Gemma2ForCausalLM,
|
||||
)
|
||||
|
||||
return policy
|
Loading…
Reference in New Issue
Block a user