Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-14 06:05:26 +00:00)

commit 06724492ca (parent a9bb7cb943)

[pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci
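
This is an automated commit: pre-commit.ci runs the repository's configured pre-commit hooks and pushes whatever they auto-fix. Judging from the hunks below, the fixes are purely mechanical: unused imports and one unused local binding are removed, the remaining imports are re-sorted and consolidated, trailing commas are added to multi-line argument lists, and blank lines around top-level definitions are normalized. No behavioral change is intended.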
@@ -1,10 +1,10 @@
 import math
-import warnings
-from typing import List, Optional, Tuple, Union, Callable
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (
     CohereForCausalLM,
@@ -13,13 +13,8 @@ from transformers.models.cohere.modeling_cohere import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils import logging
 from transformers.processing_utils import Unpack
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.models.cohere.modeling_cohere import eager_attention_forward
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-
-from functools import partial
+from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
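
Taken together, the two import hunks above leave the module header of what looks like the Command (Cohere) shardformer modeling file in roughly this shape. This is a sketch assembled only from the context and added lines shown in the diff; the two lines elided inside the parenthesized import are not visible here.

    import math
    from typing import List, Optional, Tuple, Union

    import torch
    from torch import nn
    from transformers.cache_utils import Cache, DynamicCache
    from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
    from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
    from transformers.models.cohere.modeling_cohere import (
        CohereForCausalLM,
        # two lines not shown in the diff
        apply_rotary_pos_emb,
        repeat_kv,
    )
    from transformers.processing_utils import Unpack
    from transformers.utils import logging

    from colossalai.pipeline.stage_manager import PipelineStageManager
    from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward

The removals (warnings, Callable, eager_attention_forward, ALL_ATTENTION_FUNCTIONS, functools.partial) are what an unused-import hook would drop, which suggests the parent commit left them unreferenced; the reordering matches the usual isort grouping of stdlib, third-party, and first-party imports.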
@@ -34,6 +29,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
 logger = logging.get_logger(__name__)
 
+
 class CommandPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Command models
@@ -200,7 +196,7 @@ class CommandPipelineForwards:
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -211,7 +207,7 @@ class CommandPipelineForwards:
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings
+                    position_embeddings=position_embeddings,
                 )
 
             hidden_states = layer_outputs[0]
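
The four one-character hunks in this file apply the same formatting rule: when a call is already split one argument per line, the formatter wants a trailing comma after the last argument (the "magic trailing comma", assuming a Black-compatible hook). A minimal runnable sketch with a stub in place of the real decoder layer:

    def decoder_layer(hidden_states, *, use_cache=False, cache_position=None, position_embeddings=None):
        # stub standing in for the real transformer block
        return (hidden_states,)


    layer_outputs = decoder_layer(
        [0.0],
        use_cache=False,
        cache_position=None,
        position_embeddings=None,  # the trailing comma the hook adds
    )

With the comma in place, adding or removing an argument later touches only one line of a future diff.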
@@ -348,7 +344,6 @@ class CommandPipelineForwards:
         return {"hidden_states": hidden_states}
 
 
-
 def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self,
@@ -457,6 +452,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
 
     return forward
 
+
 def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
 
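
The blank-line hunks around the two factory functions normalize spacing to the PEP 8 rule the formatter enforces: exactly two blank lines between top-level definitions (three were collapsed to two in one spot, one was widened to two in another). A generic sketch; the names here are placeholders, not from the repository:

    def first_factory():
        def forward():
            return None

        return forward


    def second_factory():
        # exactly two blank lines separate top-level definitions
        return None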
@@ -553,7 +549,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings
+                    position_embeddings,
                 )
 
             else:
@@ -565,7 +561,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings
+                    position_embeddings=position_embeddings,
                 )
 
             hidden_states = layer_outputs[0]
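
The hunks above come from the Command modeling code (the pipeline forwards and the flash-attention forward factories); the hunks below come from two further files, a shardformer policy class and a test configuration, whose paths are not preserved in this excerpt.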
@@ -6,8 +6,6 @@ from torch import Tensor
 from torch.nn import Module
 
 from colossalai.shardformer.layer import (
-    FusedLayerNorm,
-    LayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     LinearWithGradAccum,
@@ -38,11 +36,7 @@ class CommandPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.cohere.modeling_cohere import (
-            CohereAttention,
-            CohereDecoderLayer,
-            CohereModel,
-        )
+        from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
 
         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
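
Here the import sorter joins a parenthesized import onto a single line because it fits the configured line length; the joined line is just over 100 characters, so the project presumably allows longer lines than Black's default of 88, likely 120. The effect, shown directly:

    # before: exploded across five lines
    # from transformers.models.cohere.modeling_cohere import (
    #     CohereAttention,
    #     CohereDecoderLayer,
    #     CohereModel,
    # )

    # after: joined, since it fits the (assumed) 120-character limit
    from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel

The earlier removal of FusedLayerNorm and LayerNorm from the colossalai.shardformer.layer import is the same unused-import cleanup seen in the modeling file.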
@@ -62,7 +56,7 @@ class CommandPolicy(Policy):
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
-        sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        sp_mode in ["split_gather", "ring"]
         if sp_mode == "ring_attn" and not self.is_causal:
             raise ValueError("Ring attention is only meant for causal language modeling.")
 
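
One fix here deserves a flag: the unused-variable hook drops the `sp_partial_derived` binding (presumably nothing in the method reads it any more) but keeps the right-hand side, leaving a bare `sp_mode in ["split_gather", "ring"]` expression whose boolean result is discarded. It is harmless at runtime, though a human pass would normally delete the whole line. A minimal sketch of the pattern:

    sp_mode = "split_gather"  # illustrative value

    # before the autofix: the result was bound but never read again
    # sp_partial_derived = sp_mode in ["split_gather", "ring"]

    # after the autofix: only the assignment target is removed; the remaining
    # expression statement computes a bool and throws it away
    sp_mode in ["split_gather", "ring"]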
@@ -213,7 +213,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-
         {
             "tp_size": 1,
             "pp_size": 1,