mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
fix precommit
This commit is contained in:
@@ -7,12 +7,12 @@ from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.shardformer.layer import (
|
||||
CohereLayerNorm,
|
||||
FusedCohereLayerNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
CohereLayerNorm,
|
||||
VocabParallelEmbedding1D,
|
||||
VocabParallelLMHead1D,
|
||||
)
|
||||
@@ -383,7 +383,9 @@ class CommandForCausalLMPolicy(CommandPolicy):
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy
|
||||
model_cls=CohereForCausalLM,
|
||||
new_forward=CommandPipelineForwards.command_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
@@ -410,4 +412,4 @@ class CommandForCausalLMPolicy(CommandPolicy):
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
return []
|
||||
|
Reference in New Issue
Block a user