mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-01 17:52:05 +00:00
[hotfix] fix opt pipeline (#4293)
* opt forward and test
* pause
* finish opt model pipeline
* finish opt pipeline
* opt forward and test
* pause
* finish opt model pipeline
* finish opt pipeline
* fix opt
* set transformers version
* refactor the test pipeline
* fix bug
This commit is contained in:
parent
d8408d185c
commit
0a8f3c851a
@ -12,6 +12,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||||
|
|
||||||
|
from .._utils import getattr_, setattr_
|
||||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -198,8 +199,8 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||||||
|
|
||||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
opt_model = self.model
|
opt_model = self.model
|
||||||
num_stages = self.pipeline_stage_manager.num_stages
|
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||||
if self.pipeline_stage_manager and num_stages > 1:
|
num_stages = self.pipeline_stage_manager.num_stages
|
||||||
if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
|
if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
|
||||||
return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
|
return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user