mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 18:09:06 +00:00
[pipeline] refactor pipeline (#679)
* refactor pipeline---put runtime schedule into engine. * add type hint for schedule Optional[BaseSchedule] * preprocess schedule during engine initializing * infer pipeline schedule params from config
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from ._base_schedule import BaseSchedule
|
||||
from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
|
||||
from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
|
||||
from ._non_pipeline_schedule import NonPipelineSchedule
|
||||
|
||||
__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule']
|
||||
__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
|
||||
|
@@ -16,6 +16,29 @@ from colossalai.zero.sharded_model import ShardedModelV2
|
||||
|
||||
from ._base_schedule import BaseSchedule
|
||||
|
||||
def get_tensor_shape():
|
||||
if hasattr(gpc.config, 'TENSOR_SHAPE'):
|
||||
return gpc.config.TENSOR_SHAPE
|
||||
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
return None
|
||||
|
||||
if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
dp_size = gpc.get_world_size(ParallelMode.DATA)
|
||||
else:
|
||||
dp_size = 1
|
||||
if gpc.is_initialized(ParallelMode.SEQUENCE):
|
||||
seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
else:
|
||||
seq_size = 1
|
||||
|
||||
tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
|
||||
gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
|
||||
gpc.config.HIDDEN_SIZE)
|
||||
return tensor_shape
|
||||
else:
|
||||
return None
|
||||
|
||||
def pack_return_tensors(return_tensors):
|
||||
output, label = tuple(zip(*return_tensors))
|
||||
|
Reference in New Issue
Block a user