From 0a8f3c851ab5a658869defa81227ea562eda1a30 Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Thu, 20 Jul 2023 17:21:28 +0800
Subject: [PATCH] [hotfix] fix opt pipeline (#4293)

* opt forward and test
* pause
* finish opt model pipeline
* finish opt pipeline
* opt forward and test
* pause
* finish opt model pipeline
* finish opt pipeline
* fix opt
* set transformers version
* refactor the test pipeline
* fix bug
---
 colossalai/shardformer/policies/opt.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 31934965e..244a0a54e 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -12,6 +12,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
+from .._utils import getattr_, setattr_
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -198,8 +199,8 @@ class OPTForCausalLMPolicy(OPTPolicy):
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         opt_model = self.model
-        num_stages = self.pipeline_stage_manager.num_stages
-        if self.pipeline_stage_manager and num_stages > 1:
+        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+            num_stages = self.pipeline_stage_manager.num_stages
             if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
                 return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
         return []
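
Note (not part of the patch): a minimal sketch of why the reordering above matters. With the old ordering, `num_stages = self.pipeline_stage_manager.num_stages` runs before the guard, so sharding a model without pipeline parallelism (manager is None) raises AttributeError before the `if` is ever evaluated; moving the attribute access behind the short-circuiting `and` avoids this. The `_StageManager` and `_Policy` classes below are hypothetical stand-ins that mirror the shapes of PipelineStageManager and OPTForCausalLMPolicy, not ColossalAI's real classes.

    # Sketch only: stand-in classes illustrating the guard-ordering fix.
    from typing import Dict, List, Optional


    class _StageManager:  # stand-in for PipelineStageManager
        def __init__(self, num_stages: int) -> None:
            self.num_stages = num_stages


    class _Policy:  # stand-in for the causal-LM policy
        def __init__(self, manager: Optional[_StageManager]) -> None:
            self.pipeline_stage_manager = manager

        def get_shared_params_old(self) -> List[Dict[int, str]]:
            # Old ordering: dereferences the manager before the None check,
            # so this raises AttributeError when no pipeline is configured.
            num_stages = self.pipeline_stage_manager.num_stages
            if self.pipeline_stage_manager and num_stages > 1:
                return [{0: "embed_tokens.weight", num_stages - 1: "lm_head.weight"}]
            return []

        def get_shared_params_new(self) -> List[Dict[int, str]]:
            # Fixed ordering: `and` short-circuits, so .num_stages is only
            # read once the manager is known to be non-None.
            if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
                num_stages = self.pipeline_stage_manager.num_stages
                return [{0: "embed_tokens.weight", num_stages - 1: "lm_head.weight"}]
            return []


    assert _Policy(None).get_shared_params_new() == []       # no pipeline: safe
    assert _Policy(_StageManager(4)).get_shared_params_new()  # 4 stages: shared pair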