[pipeline] refactor pipeline (#679)

* refactor pipeline: move the runtime schedule into the engine

* add the type hint Optional[BaseSchedule] for schedule

* preprocess the schedule during engine initialization

* infer pipeline schedule params from the config
Author: YuliangLiu0306
Date: 2022-04-07 15:54:14 +08:00 (committed by GitHub)
Parent: eace69387d
Commit: 0ed7042f42
4 changed files with 36 additions and 9 deletions
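
At a high level, the refactor moves schedule ownership into the Engine itself. The following is a minimal sketch of the resulting shape using stand-in stub classes, not ColossalAI's actual definitions (the real Engine and BaseSchedule carry more state and parameters):

```python
from typing import Optional

# Stand-in stubs for illustration only; the real classes live in
# colossalai.engine and colossalai.engine.schedule.
class BaseSchedule:
    def pre_processing(self, engine: "Engine") -> None:
        """Hook run once while the engine is being initialized."""

class NonPipelineSchedule(BaseSchedule):
    pass

class Engine:
    def __init__(self, schedule: Optional[BaseSchedule] = None):
        # Default to the non-pipeline schedule when none is supplied,
        # matching the Optional[BaseSchedule] type hint from the commit.
        self._schedule = schedule if schedule is not None else NonPipelineSchedule()
        # Preprocess the schedule during engine initialization, as the
        # commit message describes.
        self._schedule.pre_processing(self)
```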


@@ -20,7 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
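
Per the "infer pipeline schedule params from the config" bullet, the newly imported get_tensor_shape presumably derives the pipeline communication shape from the global config instead of requiring an explicit TENSOR_SHAPE entry. A hedged, self-contained sketch of what such a helper might do, assuming config keys BATCH_SIZE, NUM_MICRO_BATCHES, SEQ_LENGTH, and HIDDEN_SIZE (the real implementation lives in colossalai.engine.schedule and may differ):

```python
from typing import Optional, Tuple

# Hypothetical reconstruction of the inference logic, not the library's
# actual code. Assumed config keys: BATCH_SIZE, NUM_MICRO_BATCHES,
# SEQ_LENGTH, HIDDEN_SIZE.
def infer_tensor_shape(config: dict) -> Optional[Tuple[int, int, int]]:
    # An explicit TENSOR_SHAPE still wins if the user provided one.
    if 'TENSOR_SHAPE' in config:
        return tuple(config['TENSOR_SHAPE'])
    keys = ('BATCH_SIZE', 'NUM_MICRO_BATCHES', 'SEQ_LENGTH', 'HIDDEN_SIZE')
    if all(k in config for k in keys):
        micro_batch_size = config['BATCH_SIZE'] // config['NUM_MICRO_BATCHES']
        # Pipeline stages exchange activations shaped
        # (micro_batch_size, seq_len, hidden_size).
        return (micro_batch_size, config['SEQ_LENGTH'], config['HIDDEN_SIZE'])
    # Fall back to None so shapes are inferred at run time.
    return None
```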
@@ -391,14 +391,18 @@ def initialize(model: nn.Module,
     # initialize schedule for engine
     if is_using_pp():
-        tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
+        tensor_shape = get_tensor_shape()
         use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
+            scatter_gather = True
+        else:
+            scatter_gather = False
         if use_interleaved:
             schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
         else:
             schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                        tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
     else:
         schedule = NonPipelineSchedule()
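
For context, a hedged example of a config file that would drive the branch logic above. Only NUM_MICRO_BATCHES, model.num_chunks, and the 1D tensor-parallel mode are visible in the diff; the surrounding field names follow common ColossalAI config conventions and are illustrative, not authoritative:

```python
# Example ColossalAI-style config file (plain Python module).

NUM_MICRO_BATCHES = 4

# mode='1d' initializes ParallelMode.PARALLEL_1D, which the diff uses
# to enable scatter_gather_tensors.
parallel = dict(
    pipeline=2,
    tensor=dict(size=2, mode='1d'),
)

# The presence of model.num_chunks selects InterleavedPipelineSchedule.
model = dict(num_chunks=2)
```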