[pipeline] 1f1b schedule receive microbatch size (#4589)

This commit is contained in:
Hongxin Liu
2023-09-01 21:45:14 +08:00
committed by GitHub
parent 38ccb8b1a3
commit 508ca36fe3
3 changed files with 30 additions and 7 deletions

View File

@@ -17,14 +17,26 @@ from .base import PipelineSchedule
class OneForwardOneBackwardSchedule(PipelineSchedule):
def __init__(self,
             stage_manager: PipelineStageManager,
             num_microbatches: Optional[int] = None,
             microbatch_size: Optional[int] = None) -> None:
    """1F1B pipeline schedule.

    Exactly how the batch is split is decided later, when a batch is loaded;
    the constructor only records which of the two sizing knobs the caller set.

    Args:
        stage_manager (PipelineStageManager): Pipeline stage manager.
        num_microbatches (Optional[int], optional): The number of microbatches.
            If not provided, it will be derived from microbatch size. Defaults to None.
        microbatch_size (Optional[int], optional): Microbatch size.
            If num_microbatches is provided, this will be ignored. Defaults to None.

    Raises:
        AssertionError: If neither num_microbatches nor microbatch_size is given.
    """
    super().__init__(stage_manager)
    assert num_microbatches is not None or microbatch_size is not None, \
        "Either num_microbatches or microbatch_size should be provided"
    self.comm = PipelineP2PCommunication(stage_manager)
    # Keep both values exactly as passed in. load_batch reads
    # self.microbatch_size when num_microbatches is None, so it must not be
    # reset to None after this point.
    self.num_microbatches = num_microbatches
    self.microbatch_size = microbatch_size
    # Per-batch state; stays None until a batch is loaded.
    self.batch: Optional[Any] = None
    self.batch_size: Optional[int] = None
    self.microbatch_offset: Optional[int] = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -39,9 +51,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
if self.num_microbatches is not None:
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
else:
assert self.batch_size % self.microbatch_size == 0, \
"Batch size should divided by the microbatch size"
self.num_microbatches = self.batch_size // self.microbatch_size
def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch.