Merge pull request #6298 from wangbluo/upgrade_command

upgrade command
This commit is contained in:
Hanks 2025-05-22 14:21:58 +08:00 committed by GitHub
commit 6a29abdefd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 134 deletions

View File

@ -1,19 +1,19 @@
import math
import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint
from torch import nn from torch import nn
from transformers.cache_utils import Cache, DynamicCache from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import ( from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereForCausalLM, CohereForCausalLM,
CohereModel, CohereModel,
StaticCache, StaticCache,
apply_rotary_pos_emb, apply_rotary_pos_emb,
repeat_kv, repeat_kv,
) )
from transformers.processing_utils import Unpack
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
@ -27,6 +27,8 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
logger = logging.get_logger(__name__)
class CommandPipelineForwards: class CommandPipelineForwards:
""" """
@ -37,22 +39,23 @@ class CommandPipelineForwards:
@staticmethod @staticmethod
def command_model_forward( def command_model_forward(
self: CohereModel, self: CohereModel,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
force_sp_output_gather: bool = True, force_sp_output_gather: bool = True,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
): ):
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -67,8 +70,6 @@ class CommandPipelineForwards:
) )
use_cache = False use_cache = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds # retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
@ -122,6 +123,7 @@ class CommandPipelineForwards:
# embed positions, for the first stage, hidden_states is the input embeddings, # embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage # for the other stages, hidden_states is the output of the previous stage
shard_config.enable_flash_attention = True
if shard_config.enable_flash_attention: if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor # in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length_with_past) mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
@ -133,7 +135,8 @@ class CommandPipelineForwards:
is_causal=True, is_causal=True,
) )
else: else:
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) # v4.51.3 transformers attention_mask calculation
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
if self.gradient_checkpointing and self.training and use_cache: if self.gradient_checkpointing and self.training and use_cache:
if use_cache: if use_cache:
@ -163,6 +166,8 @@ class CommandPipelineForwards:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
next_decoder_cache = None next_decoder_cache = None
# v4.51.3 transformers position_embeddings calculation
position_embeddings = self.rotary_emb(hidden_states, position_ids)
start_idx, end_idx = stage_index[0], stage_index[1] start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0 num_ckpt_layers = 0
@ -193,6 +198,7 @@ class CommandPipelineForwards:
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@ -203,6 +209,7 @@ class CommandPipelineForwards:
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@ -224,17 +231,6 @@ class CommandPipelineForwards:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_cache, past_key_values=next_cache,
@ -352,27 +348,22 @@ class CommandPipelineForwards:
def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward( def forward(
self, self: CohereAttention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if sp_mode is not None: if sp_mode is not None:
assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and ( assert (sp_size is not None) and (
sp_group is not None sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel" ), "Must specify sp_size and sp_group for sequence parallel"
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring # sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]: if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size q_len *= sp_size
@ -388,60 +379,36 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size() bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] cos, sin = position_embeddings
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads # repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = None
shard_config.enable_flash_attention = True
if shard_config.enable_flash_attention: if shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
else: else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) # attn_weights and attn_output calculation is modified on the v4.51.3 of transformers.models.cohere.modeling_cohere.CohereAttention.forward.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None: if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
raise ValueError( attn_weights = attn_weights + causal_mask
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
dropout = 0.0 if not self.training else self.attention_dropout
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
# sp: all-to-all comminucation when introducing sequence parallel # sp: all-to-all comminucation when introducing sequence parallel
@ -451,13 +418,11 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
) )
else: else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
if not output_attentions: return attn_output, attn_weights
attn_weights = None
return attn_output, attn_weights, past_key_value
return forward return forward
@ -467,24 +432,23 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
force_sp_output_gather: bool = True, force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]: **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds # retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
@ -516,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
shard_config.enable_flash_attention = True
# in this case, attention_mask is a dict rather than a tensor # in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention: if shard_config.enable_flash_attention:
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
@ -527,7 +493,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
is_causal=True, is_causal=True,
) )
else: else:
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) # v4.51.3 transformers attention_mask calculation
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
if sp_mode in ["ring", "split_gather"]: if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward( inputs_embeds = split_forward_gather_backward(
@ -544,6 +511,9 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
next_decoder_cache = None next_decoder_cache = None
# v4.51.3 transformers position_embeddings calculation
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for decoder_layer in self.layers: for decoder_layer in self.layers:
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
@ -557,8 +527,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
@ -568,16 +538,11 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
# Cases that don't support parallelizing cross entropy computation along sequence # Cases that don't support parallelizing cross entropy computation along sequence
@ -594,8 +559,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
next_cache = ( next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
) )
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,

View File

@ -6,8 +6,6 @@ from torch import Tensor
from torch.nn import Module from torch.nn import Module
from colossalai.shardformer.layer import ( from colossalai.shardformer.layer import (
FusedLayerNorm,
LayerNorm,
Linear1D_Col, Linear1D_Col,
Linear1D_Row, Linear1D_Row,
LinearWithGradAccum, LinearWithGradAccum,
@ -38,18 +36,13 @@ class CommandPolicy(Policy):
return self.model return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.cohere.modeling_cohere import ( from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
CohereAttention,
CohereDecoderLayer,
CohereFlashAttention2,
CohereModel,
CohereSdpaAttention,
)
# The eager, flash_attention_2, sdpa will all be passed to CohereAttention in v4.51.3 transformers.
ATTN_IMPLEMENTATION = { ATTN_IMPLEMENTATION = {
"eager": CohereAttention, "eager": CohereAttention,
"flash_attention_2": CohereFlashAttention2, "flash_attention_2": CohereAttention,
"sdpa": CohereSdpaAttention, "sdpa": CohereAttention,
} }
policy = {} policy = {}
@ -61,15 +54,11 @@ class CommandPolicy(Policy):
if self.tie_weight: if self.tie_weight:
embedding_cls = PaddingEmbedding embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization: # CohereLayerNorm has no bias in v4.51.3 transformers, so we don't replace it.
norm_cls = FusedLayerNorm
else:
norm_cls = LayerNorm
sp_mode = self.shard_config.sequence_parallelism_mode or None sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
if sp_mode == "ring_attn" and not self.is_causal: if sp_mode == "ring_attn" and not self.is_causal:
raise ValueError("Ring attention is only meant for causal language modeling.") raise ValueError("Ring attention is only meant for causal language modeling.")
@ -86,14 +75,15 @@ class CommandPolicy(Policy):
policy[attn_cls] = ModulePolicyDescription( policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
) )
self.append_or_create_method_replacement(
description={
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None: if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={ description={
@ -280,29 +270,6 @@ class CommandPolicy(Policy):
target_key=CohereModel, target_key=CohereModel,
) )
# optimization configuration
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
target_key=CohereDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=CohereModel,
)
return policy return policy
def postprocess(self): def postprocess(self):
@ -349,6 +316,7 @@ class CommandPolicy(Policy):
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = [] held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave: if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers)) layers_per_stage = stage_manager.distribute_layers(len(module.layers))

View File

@ -218,7 +218,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 1, "pp_size": 1,
"num_microbatches": 1, "num_microbatches": 1,
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather", "sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": False, "enable_flash_attention": False,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp16", "precision": "fp16",