diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py
index 36472413e..54170286e 100644
--- a/colossalai/engine/schedule/__init__.py
+++ b/colossalai/engine/schedule/__init__.py
@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
+from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from ._non_pipeline_schedule import NonPipelineSchedule
 
-__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule']
+__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 1226f8323..ae55f091b 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -16,6 +16,29 @@ from colossalai.zero.sharded_model import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
 
+def get_tensor_shape():
+    if hasattr(gpc.config, 'TENSOR_SHAPE'):
+        return gpc.config.TENSOR_SHAPE
+
+    if not gpc.is_initialized(ParallelMode.PIPELINE):
+        return None
+
+    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'NUM_MICRO_BATCHES') and hasattr(gpc.config, 'HIDDEN_SIZE'):
+        if gpc.is_initialized(ParallelMode.DATA):
+            dp_size = gpc.get_world_size(ParallelMode.DATA)
+        else:
+            dp_size = 1
+        if gpc.is_initialized(ParallelMode.SEQUENCE):
+            seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
+        else:
+            seq_size = 1
+
+        tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
+                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
+                        gpc.config.HIDDEN_SIZE)
+        return tensor_shape
+    else:
+        return None
 
 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 8521bd7d9..9435b37e3 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -20,7 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
 
@@ -391,14 +391,18 @@ def initialize(model: nn.Module,
 
     # initialize schedule for engine
     if is_using_pp():
-        tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
+        tensor_shape = get_tensor_shape()
         use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
+            scatter_gather = True
+        else:
+            scatter_gather = False
         if use_interleaved:
             schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
         else:
             schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                        tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
     else:
         schedule = NonPipelineSchedule()
 
diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py
index 70da89fff..65693732a 100644
--- a/colossalai/trainer/hooks/_metric_hook.py
+++ b/colossalai/trainer/hooks/_metric_hook.py
@@ -312,7 +312,7 @@ class AccuracyHook(MetricHook):
 
     def after_test_iter(self, trainer, logits, targets, *args):
         if self._is_stage_to_compute:
-            batch_size = trainer.schedule.batch_size
+            batch_size = trainer.engine.schedule.batch_size
             self.metric.update(logits, targets, batch_size)
 
 
@@ -392,7 +392,7 @@ class ThroughputHook(MetricHook):
 
     def after_train_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
 
     def before_test(self, trainer):
         if self._is_stage_to_compute:
@@ -400,4 +400,4 @@
 
     def after_test_iter(self, trainer, *args):
        if self._is_stage_to_compute:
-            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
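Note (not part of the patch): a minimal sketch of the config fields the new get_tensor_shape() reads, assuming a data-parallel world size of 2 and no sequence parallelism; the field names come from the diff, while the concrete values and parallel sizes are illustrative only.

# illustrative config values (assumed, not taken from the diff)
SEQ_LENGTH = 512
GLOBAL_BATCH_SIZE = 32
NUM_MICRO_BATCHES = 4
HIDDEN_SIZE = 768

# With dp_size = 2 and seq_size = 1, get_tensor_shape() would return
# (SEQ_LENGTH // 1, GLOBAL_BATCH_SIZE // 2 // NUM_MICRO_BATCHES, HIDDEN_SIZE) == (512, 4, 768),
# i.e. the per-micro-batch shape of the tensor passed between pipeline stages.
# An explicit TENSOR_SHAPE in the config still takes precedence, and the function
# returns None when any of these fields is missing, matching the old behavior of
# an unset TENSOR_SHAPE.

The scatter_gather_tensors flag is now enabled only when 1D tensor parallelism is initialized, presumably because the scatter-gather optimization splits the communicated tensor across the 1D tensor-parallel group and has no group to use otherwise.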