mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[pipeline] refactor pipeline (#679)
* refactor pipeline---put runtime schedule into engine. * add type hint for schedule Optional[BaseSchedule] * preprocess schedule during engine initializing * infer pipeline schedule params from config
This commit is contained in:
@@ -20,7 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.builder.builder import build_gradient_handler
|
||||
from colossalai.context import Config, ConfigException, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
|
||||
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
|
||||
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.engine import Engine
|
||||
@@ -391,14 +391,18 @@ def initialize(model: nn.Module,
|
||||
|
||||
# initialize schedule for engine
|
||||
if is_using_pp():
|
||||
tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
|
||||
tensor_shape = get_tensor_shape()
|
||||
use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
|
||||
if gpc.is_initialized(ParallelMode.PARALLEL_1D):
|
||||
scatter_gather = True
|
||||
else:
|
||||
scatter_gather = False
|
||||
if use_interleaved:
|
||||
schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
|
||||
gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
|
||||
gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
|
||||
else:
|
||||
schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
|
||||
tensor_shape=tensor_shape, scatter_gather_tensors=True)
|
||||
tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
|
||||
else:
|
||||
schedule = NonPipelineSchedule()
|
||||
|
||||
|
Reference in New Issue
Block a user