[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot]
2025-05-08 08:13:33 +00:00
parent a9bb7cb943
commit 06724492ca
3 changed files with 21 additions and 32 deletions

@@ -6,8 +6,6 @@ from torch import Tensor
 from torch.nn import Module
 
 from colossalai.shardformer.layer import (
-    FusedLayerNorm,
-    LayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     LinearWithGradAccum,
@@ -38,11 +36,7 @@ class CommandPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.cohere.modeling_cohere import (
-            CohereAttention,
-            CohereDecoderLayer,
-            CohereModel,
-        )
+        from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
 
         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
@@ -58,11 +52,11 @@ class CommandPolicy(Policy):
         else:
             if self.tie_weight:
                 embedding_cls = PaddingEmbedding
 
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
-        sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        sp_mode in ["split_gather", "ring"]
 
         if sp_mode == "ring_attn" and not self.is_causal:
             raise ValueError("Ring attention is only meant for causal language modeling.")
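For context, the last hunk touches the part of the policy that reads sequence-parallelism settings from the shard config and rejects ring attention for non-causal models. Below is a minimal, self-contained sketch of that pattern, assuming a simplified stand-in for ColossalAI's ShardConfig; the dataclass and function names are illustrative, not part of the repository.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ShardConfig:
    # Hypothetical stand-in, limited to the fields referenced in the hunk above.
    sequence_parallelism_mode: Optional[str] = None
    sequence_parallel_size: Optional[int] = None
    sequence_parallel_process_group: Optional[object] = None


def resolve_sequence_parallel(shard_config: ShardConfig, is_causal: bool):
    # Fall back to None when a setting is unset or falsy, as in the diff.
    sp_mode = shard_config.sequence_parallelism_mode or None
    sp_size = shard_config.sequence_parallel_size or None
    sp_group = shard_config.sequence_parallel_process_group or None
    # The assignment dropped by the hook recorded whether activations are
    # partially derived along the sequence dimension for these two modes.
    sp_partial_derived = sp_mode in ["split_gather", "ring"]

    # Ring attention assumes a causal mask, so reject non-causal models early.
    if sp_mode == "ring_attn" and not is_causal:
        raise ValueError("Ring attention is only meant for causal language modeling.")

    return sp_mode, sp_size, sp_group, sp_partial_derived


# Example: a non-causal model configured with ring attention is rejected.
try:
    resolve_sequence_parallel(ShardConfig(sequence_parallelism_mode="ring_attn"), is_causal=False)
except ValueError as err:
    print(err)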