mirror of https://github.com/hpcaitech/ColossalAI.git
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
@@ -6,8 +6,6 @@ from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import (
    FusedLayerNorm,
    LayerNorm,
    Linear1D_Col,
    Linear1D_Row,
    LinearWithGradAccum,

@@ -38,11 +36,7 @@ class CommandPolicy(Policy):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.cohere.modeling_cohere import (
            CohereAttention,
            CohereDecoderLayer,
            CohereModel,
        )
        from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel

        ATTN_IMPLEMENTATION = {
            "eager": CohereAttention,

@@ -58,11 +52,11 @@ class CommandPolicy(Policy):
        else:
            if self.tie_weight:
                embedding_cls = PaddingEmbedding

        sp_mode = self.shard_config.sequence_parallelism_mode or None
        sp_size = self.shard_config.sequence_parallel_size or None
        sp_group = self.shard_config.sequence_parallel_process_group or None
        sp_partial_derived = sp_mode in ["split_gather", "ring"]
        sp_mode in ["split_gather", "ring"]
        if sp_mode == "ring_attn" and not self.is_causal:
            raise ValueError("Ring attention is only meant for causal language modeling.")
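For context, the last hunk reads the sequence-parallel settings from the policy's shard config and rejects the "ring_attn" mode for non-causal models. Below is a minimal standalone sketch of that logic, assuming only the attribute names that appear in the diff; the ShardConfig stub and the validate_sequence_parallel helper are hypothetical illustrations, not ColossalAI's actual classes.

from dataclasses import dataclass
from typing import Optional, Tuple

import torch.distributed as dist


@dataclass
class ShardConfig:
    # Illustrative stub; the field names mirror the shard_config attributes read in the hunk above.
    sequence_parallelism_mode: Optional[str] = None
    sequence_parallel_size: Optional[int] = None
    sequence_parallel_process_group: Optional[dist.ProcessGroup] = None


def validate_sequence_parallel(config: ShardConfig, is_causal: bool) -> Tuple[Optional[str], bool]:
    """Return (sp_mode, sp_partial_derived) after the same checks as the hunk above."""
    sp_mode = config.sequence_parallelism_mode or None
    # "split_gather" and "ring" are the modes the policy flags as sp_partial_derived.
    sp_partial_derived = sp_mode in ["split_gather", "ring"]
    # Ring attention assumes a causal attention mask, hence the guard in the diff.
    if sp_mode == "ring_attn" and not is_causal:
        raise ValueError("Ring attention is only meant for causal language modeling.")
    return sp_mode, sp_partial_derived


# Example: ring-attention sequence parallelism is accepted only for causal models.
mode, partial = validate_sequence_parallel(ShardConfig(sequence_parallelism_mode="ring_attn"), is_causal=True)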