diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 31934965e..244a0a54e 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -12,6 +12,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
+from .._utils import getattr_, setattr_
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -198,8 +199,8 @@ class OPTForCausalLMPolicy(OPTPolicy):
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         opt_model = self.model
-        num_stages = self.pipeline_stage_manager.num_stages
-        if self.pipeline_stage_manager and num_stages > 1:
+        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+            num_stages = self.pipeline_stage_manager.num_stages
             if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
                 return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
 
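Context for the second hunk: the old code dereferenced `self.pipeline_stage_manager.num_stages` before the `None` check on the manager, so any run without pipeline parallelism (where `pipeline_stage_manager` is `None`) raised `AttributeError` instead of falling through. The fix moves the attribute access behind the short-circuiting `and`. A minimal, self-contained sketch of the failure and the guard pattern (the `StageManager` stub and `shared_params_*` helpers are hypothetical illustrations, not ColossalAI APIs):

```python
class StageManager:
    """Hypothetical stand-in for PipelineStageManager."""

    def __init__(self, num_stages: int):
        self.num_stages = num_stages


def shared_params_buggy(stage_manager):
    # Reads .num_stages unconditionally: raises AttributeError
    # when stage_manager is None, before the guard even runs.
    num_stages = stage_manager.num_stages
    if stage_manager and num_stages > 1:
        return ["shared"]
    return []


def shared_params_fixed(stage_manager):
    # Short-circuit: .num_stages is only read once the None check passes.
    if stage_manager and stage_manager.num_stages > 1:
        num_stages = stage_manager.num_stages  # safe to use below
        return ["shared"] * num_stages
    return []


print(shared_params_fixed(None))             # [] -- no crash
print(shared_params_fixed(StageManager(4)))  # ['shared', 'shared', 'shared', 'shared']
# shared_params_buggy(None) would raise AttributeError
```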