fix pre-commit

This commit is contained in:
GuangyaoZhang
2024-06-14 08:09:24 +00:00
parent 98da648a4a
commit fe2e74c03a
7 changed files with 35 additions and 86 deletions

View File

@@ -7,12 +7,12 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import (
CohereLayerNorm,
FusedCohereLayerNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
CohereLayerNorm,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
@@ -383,7 +383,9 @@ class CommandForCausalLMPolicy(CommandPolicy):
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy
model_cls=CohereForCausalLM,
new_forward=CommandPipelineForwards.command_for_causal_lm_forward,
policy=policy,
)
return policy
@@ -410,4 +412,4 @@ class CommandForCausalLMPolicy(CommandPolicy):
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
return []