Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-22 13:41:43 +00:00
fix
This commit is contained in:
parent
06724492ca
commit
e78c4560c6
@@ -38,6 +38,7 @@ class CommandPolicy(Policy):
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
 
+        # The eager, flash_attention_2, and sdpa implementations will all be passed to CohereAttention in transformers v4.51.3.
         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
             "flash_attention_2": CohereAttention,
@@ -53,10 +54,11 @@ class CommandPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding
 
+        # CohereLayerNorm has no bias in v4.51.3 transformers, so we don't replace it.
 
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         if sp_mode == "ring_attn" and not self.is_causal:
             raise ValueError("Ring attention is only meant for causal language modeling.")
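
For context on the first hunk: transformers selects its attention backend from the model config's _attn_implementation field ("eager", "flash_attention_2", or "sdpa"), and in v4.51.3 all three keys resolve to the same CohereAttention class, which is why the mapping above points every key at it. Below is a minimal, illustrative sketch of that lookup; the helper name resolve_cohere_attn_cls and the fallback to "eager" are assumptions for illustration, not ColossalAI's actual API.

    # Illustrative sketch only (assumes transformers v4.51.x is installed).
    from transformers.models.cohere.modeling_cohere import CohereAttention

    # Every supported backend maps to the same attention class in v4.51.3.
    ATTN_IMPLEMENTATION = {
        "eager": CohereAttention,
        "flash_attention_2": CohereAttention,
        "sdpa": CohereAttention,
    }

    def resolve_cohere_attn_cls(config):
        # _attn_implementation is set on the config when the model is loaded,
        # e.g. AutoModelForCausalLM.from_pretrained(..., attn_implementation="sdpa").
        impl = getattr(config, "_attn_implementation", "eager")
        return ATTN_IMPLEMENTATION.get(impl, CohereAttention)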
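The second hunk reads the sequence-parallelism settings off the shard config and rejects the one unsupported combination: ring attention ("ring_attn") assumes a causal mask, so it is only valid for causal language modeling. A minimal sketch of that validation follows; the ShardCfg container and the function name check_sequence_parallel are stand-ins invented for illustration, not ColossalAI's ShardConfig API.

    # Illustrative sketch only; field names mirror the attributes used in the diff above.
    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class ShardCfg:
        sequence_parallelism_mode: Optional[str] = None
        sequence_parallel_size: Optional[int] = None
        sequence_parallel_process_group: Optional[Any] = None

    def check_sequence_parallel(cfg: ShardCfg, is_causal: bool) -> bool:
        sp_mode = cfg.sequence_parallelism_mode or None
        # The "split_gather" and "ring" modes set the sp_partial_derived flag used later in the policy.
        sp_partial_derived = sp_mode in ["split_gather", "ring"]
        if sp_mode == "ring_attn" and not is_causal:
            raise ValueError("Ring attention is only meant for causal language modeling.")
        return sp_partial_derived

For example, check_sequence_parallel(ShardCfg(sequence_parallelism_mode="ring_attn"), is_causal=False) raises the ValueError above, while "split_gather" and "ring" simply return True.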