Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-24 14:33:20 +00:00)
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

parent a9bb7cb943
commit 06724492ca
@@ -1,10 +1,10 @@
import math
import warnings
from typing import List, Optional, Tuple, Union, Callable
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
    CohereForCausalLM,

@@ -13,13 +13,8 @@ from transformers.models.cohere.modeling_cohere import (
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import logging
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.models.cohere.modeling_cohere import eager_attention_forward
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from functools import partial
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
@@ -34,6 +29,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]

logger = logging.get_logger(__name__)


class CommandPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Command models
@@ -168,7 +164,7 @@ class CommandPipelineForwards:
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        start_idx, end_idx = stage_index[0], stage_index[1]
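The `start_idx, end_idx = stage_index[0], stage_index[1]` line above bounds the slice of decoder layers that the current pipeline stage executes. A minimal standalone sketch of that pattern, with toy layers rather than the Command model (every name below is illustrative):

import torch
from torch import nn

# twelve toy "decoder layers"; suppose this stage owns layers 4..7
layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(12))
stage_index = (4, 8)
hidden_states = torch.randn(2, 8)

start_idx, end_idx = stage_index[0], stage_index[1]
for layer in layers[start_idx:end_idx]:
    hidden_states = layer(hidden_states)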
@@ -200,7 +196,7 @@ class CommandPipelineForwards:
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
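The positional argument list in this branch is the shape expected by activation-checkpointing wrappers, which re-run the layer during backward and therefore take its inputs as plain positional arguments; the wrapped call itself is not visible in this excerpt. A hedged sketch of the idea using torch.utils.checkpoint with a stand-in layer:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)                         # stand-in for a decoder layer
hidden_states = torch.randn(2, 16, requires_grad=True)

# the wrapper re-invokes `layer` in backward, so inputs are passed positionally
out = checkpoint(layer, hidden_states, use_reentrant=False)
out.sum().backward()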
@@ -211,7 +207,7 @@ class CommandPipelineForwards:
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]
@@ -348,7 +344,6 @@ class CommandPipelineForwards:
        return {"hidden_states": hidden_states}


def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
    def forward(
        self,
@@ -370,22 +365,22 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
        # sp: modify sp_len when sequence parallel mode is ring
        if sp_mode in ["split_gather", "ring"]:
            q_len *= sp_size

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # sp: all-to-all communication when introducing sequence parallel
        if sp_mode == "all_to_all":
            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
            bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
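For orientation, a minimal sketch (not ColossalAI's `all_to_all_comm`) of what the all-to-all step above achieves: each rank trades its local sequence shard of all attention heads for the full sequence of a subset of heads. It assumes `torch.distributed` is already initialized over `sp_size` ranks; the function name and shapes are illustrative only.

import torch
import torch.distributed as dist

def seq_to_head_all_to_all(x, sp_group):
    # x: [bsz, seq_len // sp_size, num_heads * head_dim]  (local sequence shard)
    sp_size = dist.get_world_size(sp_group)
    bsz, sub_seq, hidden = x.shape
    # split the head dimension into one chunk per rank
    inp = x.reshape(bsz, sub_seq, sp_size, hidden // sp_size).permute(2, 0, 1, 3).contiguous()
    out = torch.empty_like(inp)
    dist.all_to_all_single(out, inp, group=sp_group)  # swap sequence chunks for head chunks
    # stitch the gathered sequence chunks back together for the local head subset
    return out.permute(1, 0, 2, 3).reshape(bsz, sp_size * sub_seq, hidden // sp_size)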
@@ -409,7 +404,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = None

        if shard_config.enable_flash_attention:
            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
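`repeat_kv` (imported from transformers.models.cohere.modeling_cohere above) duplicates each key/value head across its query group so grouped-query attention reduces to a plain batched matmul. A small equivalent written independently of the transformers implementation:

import torch

def repeat_kv_sketch(hidden, n_rep):
    # hidden: [bsz, num_kv_heads, seq_len, head_dim] -> [bsz, num_kv_heads * n_rep, seq_len, head_dim]
    bsz, num_kv_heads, seq_len, head_dim = hidden.shape
    if n_rep == 1:
        return hidden
    hidden = hidden[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)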
@@ -452,11 +447,12 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()

        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights

    return forward


def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)
@@ -537,7 +533,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
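`position_embeddings` is the (cos, sin) pair produced by the rotary embedding and later consumed by `apply_rotary_pos_emb`. A generic rotate-half sketch for intuition; Cohere's actual rotary variant may interleave dimensions differently:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(q, k, cos, sin):
    # q, k: [bsz, num_heads, seq, head_dim]; cos, sin: [bsz, seq, head_dim]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin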
@@ -553,7 +549,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings
                    position_embeddings,
                )

            else:
@@ -565,7 +561,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]
@@ -6,8 +6,6 @@ from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import (
    FusedLayerNorm,
    LayerNorm,
    Linear1D_Col,
    Linear1D_Row,
    LinearWithGradAccum,
@@ -38,11 +36,7 @@ class CommandPolicy(Policy):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.cohere.modeling_cohere import (
            CohereAttention,
            CohereDecoderLayer,
            CohereModel,
        )
        from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel

        ATTN_IMPLEMENTATION = {
            "eager": CohereAttention,
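The dict returned by `module_policy` maps module classes such as CohereDecoderLayer to replacement descriptions that swap in tensor-parallel layers like Linear1D_Col / Linear1D_Row. A hedged sketch of one entry; the field names follow ColossalAI's ModulePolicyDescription and SubModuleReplacementDescription as I understand them and should be checked against the installed version:

from transformers.models.cohere.modeling_cohere import CohereDecoderLayer

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)

# one illustrative policy entry: shard the attention projections of each decoder layer
policy = {
    CohereDecoderLayer: ModulePolicyDescription(
        sub_module_replacement=[
            SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=Linear1D_Col),
            SubModuleReplacementDescription(suffix="self_attn.o_proj", target_module=Linear1D_Row),
        ],
    )
}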
@@ -58,11 +52,11 @@ class CommandPolicy(Policy):
        else:
            if self.tie_weight:
                embedding_cls = PaddingEmbedding

        sp_mode = self.shard_config.sequence_parallelism_mode or None
        sp_size = self.shard_config.sequence_parallel_size or None
        sp_group = self.shard_config.sequence_parallel_process_group or None
        sp_partial_derived = sp_mode in ["split_gather", "ring"]
        sp_mode in ["split_gather", "ring"]
        if sp_mode == "ring_attn" and not self.is_causal:
            raise ValueError("Ring attention is only meant for causal language modeling.")
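The `sp_mode`, `sp_size`, and `sp_group` values read here come from the shard config, which in practice is driven by the booster plugin. A hedged configuration sketch; the parameter names are taken from ColossalAI's HybridParallelPlugin as I recall them and should be verified against your version (distributed initialization via colossalai.launch is assumed to have happened already):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",  # one of _SUPPORTED_SP_MODE
    enable_flash_attention=True,
    precision="fp16",
)
booster = Booster(plugin=plugin)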
@@ -213,7 +213,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "precision": "fp16",
            "initial_scale": 1,
        },

        {
            "tp_size": 1,
            "pp_size": 1,
@@ -365,4 +364,4 @@ def test_command_3d():

if __name__ == "__main__":
    test_command()
    test_command_3d()
    test_command_3d()