[shardformer] refactor pipeline grad ckpt config (#5646)

* [shardformer] refactor pipeline grad ckpt config

* [shardformer] refactor pipeline grad ckpt config

* [pipeline] fix stage manager
This commit is contained in:
Hongxin Liu
2024-04-25 15:19:30 +08:00
committed by GitHub
parent 7ef91606e1
commit 1b387ca9fe
11 changed files with 59 additions and 102 deletions

View File

@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
def __init__(self):
self.is_interleave = False
self.num_layers_per_stage = None
self.num_model_chunks = 1
@property
def num_stages(self):

View File

@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
def __init__(self):
self.is_interleave = False
self.num_layers_per_stage = None
self.num_model_chunks = 1
@property
def num_stages(self):