upgrade command

parent 46ed5d856b
commit a9bb7cb943
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -1,9 +1,8 @@
 import math
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -15,6 +14,12 @@ from transformers.models.cohere.modeling_cohere import (
     repeat_kv,
 )
 from transformers.utils import logging
+from transformers.processing_utils import Unpack
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.models.cohere.modeling_cohere import eager_attention_forward
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+from functools import partial
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
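
Note: the added imports follow the attention-dispatch interface of recent transformers releases, where every backend is exposed through one registry keyed by `config._attn_implementation`. A minimal sketch of that dispatch pattern (the helper name and its argument list are illustrative, not part of this diff):

    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
    from transformers.models.cohere.modeling_cohere import eager_attention_forward

    def dispatch_attention(module, query, key, value, attention_mask, **kwargs):
        # "eager" uses the reference kernel; other names ("sdpa",
        # "flash_attention_2") are looked up in the registry.
        interface = eager_attention_forward
        if module.config._attn_implementation != "eager":
            interface = ALL_ATTENTION_FUNCTIONS[module.config._attn_implementation]
        return interface(module, query, key, value, attention_mask, **kwargs)
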
@@ -27,6 +32,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
 
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
+logger = logging.get_logger(__name__)
 
 class CommandPipelineForwards:
     """
@@ -37,22 +43,23 @@ class CommandPipelineForwards:
     @staticmethod
     def command_model_forward(
         self: CohereModel,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
         force_sp_output_gather: bool = True,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ):
 
         logger = logging.get_logger(__name__)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
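
Note: `Unpack[FlashAttentionKwargs]` gives the new `**flash_attn_kwargs` catch-all a typed shape: transformers defines FlashAttentionKwargs as a TypedDict, and Unpack lets static checkers validate keyword names with no runtime effect. A self-contained illustration of the pattern (the TypedDict below is a stand-in with made-up keys, not the transformers definition):

    from typing_extensions import TypedDict, Unpack

    class _FlashKwargs(TypedDict, total=False):
        max_length_q: int  # illustrative key
        max_length_k: int  # illustrative key

    def forward(hidden_states, **kwargs: Unpack[_FlashKwargs]):
        # At runtime kwargs is an ordinary dict; the annotation only
        # constrains which keyword names a type checker accepts.
        return kwargs.get("max_length_q")
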
@@ -67,8 +74,6 @@ class CommandPipelineForwards:
             )
             use_cache = False
 
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
@@ -133,7 +138,7 @@ class CommandPipelineForwards:
                 is_causal=True,
             )
         else:
-            attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+            attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
 
         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
@@ -164,6 +169,8 @@ class CommandPipelineForwards:
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         start_idx, end_idx = stage_index[0], stage_index[1]
         num_ckpt_layers = 0
         if self.gradient_checkpointing and self.training:
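
Note: rotary position embeddings are now computed once per forward pass at the model level and handed to every decoder layer, matching the upstream Cohere refactor that moved `rotary_emb` onto the model. A reduced sketch of the resulting control flow (layer internals elided; this is not the full pipeline body):

    def run_layers(self, hidden_states, position_ids, attention_mask):
        # cos/sin are computed once from the current positions...
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for decoder_layer in self.layers:
            # ...and reused by every layer instead of recomputed per layer.
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
            )[0]
        return hidden_states
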
@@ -193,6 +200,7 @@ class CommandPipelineForwards:
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -203,6 +211,7 @@ class CommandPipelineForwards:
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings
                 )
 
             hidden_states = layer_outputs[0]
@@ -224,17 +233,6 @@ class CommandPipelineForwards:
             all_hidden_states += (hidden_states,)
         next_cache = next_decoder_cache if use_cache else None
         if stage_manager.is_last_stage():
-            if not return_dict:
-                return tuple(
-                    v
-                    for v in [
-                        hidden_states,
-                        next_cache,
-                        all_hidden_states,
-                        all_self_attns,
-                    ]
-                    if v is not None
-                )
             return BaseModelOutputWithPast(
                 last_hidden_state=hidden_states,
                 past_key_values=next_cache,
@@ -350,29 +348,25 @@ class CommandPipelineForwards:
         return {"hidden_states": hidden_states}
 
 
 def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if sp_mode is not None:
             assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
             assert (sp_size is not None) and (
                 sp_group is not None
             ), "Must specify sp_size and sp_group for sequence parallel"
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
 
         bsz, q_len, _ = hidden_states.size()
 
         # sp: modify sp_len when sequence parallel mode is ring
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
@@ -388,9 +382,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
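
Note: reshaping with `-1` instead of `self.num_heads` / `self.num_key_value_heads` matters under tensor parallelism, where the q/k/v projections are sharded and each rank holds only a fraction of the heads; the per-rank head count is then inferred from the tensor itself. A runnable illustration with assumed sizes:

    import torch

    bsz, q_len, head_dim, num_heads, tp_size = 2, 16, 128, 32, 4
    # On one tp=4 rank the projection emits num_heads // tp_size heads:
    query_states = torch.randn(bsz, q_len, (num_heads // tp_size) * head_dim)
    # view(..., -1, head_dim) infers 8 heads here, whereas a hard-coded
    # view(..., num_heads, head_dim) would raise a shape error on this rank.
    query_states = query_states.view(bsz, q_len, -1, head_dim).transpose(1, 2)
    assert query_states.shape == (bsz, num_heads // tp_size, q_len, head_dim)
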
@@ -403,7 +397,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
 
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        cos, sin = self.rotary_emb(value_states, position_ids)
+        cos, sin = position_embeddings
+
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
@@ -413,11 +408,14 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_weights = None
 
         if shard_config.enable_flash_attention:
             assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
             attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
         else:
+            # to be fixed:
+            # precision issue
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -451,40 +449,36 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
 
         attn_output = self.o_proj(attn_output)
 
-        if not output_attentions:
-            attn_weights = None
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
     return forward
 
 
 def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
 
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         force_sp_output_gather: bool = True,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+    ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
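
Note: the attention forward now returns the two-tuple `(attn_output, attn_weights)` expected by newer transformers decoder layers; the KV cache is no longer part of the return value because `Cache.update` mutates the cache object in place inside the forward. A hedged sketch of a call site under that contract (the helper and its names are assumed, not taken from this diff):

    def attend(layer, hidden_states, position_embeddings, attention_mask, past_key_value):
        # Two return values; past_key_value is updated in place, not returned.
        attn_output, attn_weights = layer.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
        )
        return attn_output, attn_weights
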
@@ -527,7 +521,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                 is_causal=True,
             )
         else:
-            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
 
         if sp_mode in ["ring", "split_gather"]:
             inputs_embeds = split_forward_gather_backward(
@@ -544,6 +538,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -557,6 +553,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings
                 )
 
             else:
@@ -568,16 +565,11 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings
                 )
 
             hidden_states = layer_outputs[0]
 
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
 
         # Cases that don't support parallelizing cross entropy computation along sequence
@@ -594,8 +586,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         next_cache = (
             next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
         )
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -41,15 +41,13 @@ class CommandPolicy(Policy):
         from transformers.models.cohere.modeling_cohere import (
             CohereAttention,
             CohereDecoderLayer,
-            CohereFlashAttention2,
             CohereModel,
-            CohereSdpaAttention,
         )
 
         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
-            "flash_attention_2": CohereFlashAttention2,
-            "sdpa": CohereSdpaAttention,
+            "flash_attention_2": CohereAttention,
+            "sdpa": CohereAttention,
         }
         policy = {}
 
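
Note: `CohereFlashAttention2` and `CohereSdpaAttention` no longer exist upstream; recent transformers collapses all backends into the single `CohereAttention` class and picks the kernel at call time from `config._attn_implementation`. The mapping keeps all three keys so lookups by implementation name still resolve, e.g. (hypothetical helper, not in this diff):

    from transformers.models.cohere.modeling_cohere import CohereAttention

    ATTN_IMPLEMENTATION = {
        "eager": CohereAttention,
        "flash_attention_2": CohereAttention,
        "sdpa": CohereAttention,
    }

    def resolve_attn_cls(config):
        # Every key maps to the same class; the backend choice happens
        # inside CohereAttention, not at class-selection time.
        return ATTN_IMPLEMENTATION.get(getattr(config, "_attn_implementation", "eager"), CohereAttention)
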
@@ -61,11 +59,6 @@ class CommandPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding
 
-        if self.shard_config.enable_fused_normalization:
-            norm_cls = FusedLayerNorm
-        else:
-            norm_cls = LayerNorm
-
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
@@ -280,29 +273,6 @@ class CommandPolicy(Policy):
             target_key=CohereModel,
         )
 
-        # optimization configuration
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="input_layernorm",
-                    target_module=norm_cls,
-                    kwargs={"sp_partial_derived": sp_partial_derived},
-                ),
-            ],
-            policy=policy,
-            target_key=CohereDecoderLayer,
-        )
-
-        self.append_or_create_submodule_replacement(
-            description=SubModuleReplacementDescription(
-                suffix="norm",
-                target_module=norm_cls,
-                kwargs={"sp_partial_derived": sp_partial_derived},
-            ),
-            policy=policy,
-            target_key=CohereModel,
-        )
-
         return policy
 
     def postprocess(self):
@@ -349,6 +319,7 @@ class CommandPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
 
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
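
Note: because `rotary_emb` now lives on the model and every stage computes `position_embeddings` locally, each pipeline stage must hold it, not only the embedding stage. A reduced sketch of the resulting stage layout (assuming the usual get_held_layers structure around this hunk):

    def get_held_layers_sketch(module, stage_manager):
        held_layers = [module.rotary_emb]            # needed on every stage
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)  # embeddings on stage 0 only
        return held_layers
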
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -213,17 +213,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 4,
-            "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 1,
             "pp_size": 1,
@@ -231,6 +221,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
@@ -241,6 +232,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 2,
             "enable_all_optimization": True,
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
@@ -252,6 +244,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "use_lazy_init": False,
+            "enable_flash_attention": True,
             "precision": "fp32",
             "enable_gradient_checkpointing": True,
             "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
@@ -260,6 +253,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "enable_all_optimization": True,
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
@@ -270,6 +264,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 2,
             "enable_all_optimization": True,
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 1,
             "precision": "fp16",
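
Note: the test matrix now turns on `enable_flash_attention` across the pipeline, zero, and tensor parallel configurations, and the redundant tp=4 split_gather entry is dropped. Each dict feeds a booster plugin in the test harness; a minimal hedged example of building one such configuration (kwargs mirror one test entry; the exact harness wiring is assumed):

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=2,
        num_microbatches=2,
        enable_flash_attention=True,
        precision="fp16",
    )
    booster = Booster(plugin=plugin)
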