mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 13:11:05 +00:00
[pipeline] 1f1b schedule receive microbatch size (#4589)
This commit is contained in:
@@ -17,14 +17,26 @@ from .base import PipelineSchedule
|
||||
|
||||
class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
|
||||
def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None:
|
||||
def __init__(self,
|
||||
stage_manager: PipelineStageManager,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None) -> None:
|
||||
"""1F1B pipeline schedule.
|
||||
|
||||
Args:
|
||||
stage_manager (PipelineStageManager): Pipeline stage manager
|
||||
num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
|
||||
microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
|
||||
"""
|
||||
super().__init__(stage_manager)
|
||||
assert num_microbatches is not None or microbatch_size is not None, \
|
||||
"Either num_microbatches or microbatch_size should be provided"
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.num_microbatches = num_microbatches
|
||||
self.microbatch_size = microbatch_size
|
||||
self.batch: Optional[Any] = None
|
||||
self.batch_size: Optional[int] = None
|
||||
self.microbatch_offset: Optional[int] = None
|
||||
self.microbatch_size: Optional[int] = None
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
@@ -39,9 +51,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
self.microbatch_offset = 0
|
||||
assert self.batch_size % self.num_microbatches == 0, \
|
||||
"Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
if self.num_microbatches is not None:
|
||||
assert self.batch_size % self.num_microbatches == 0, \
|
||||
"Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
else:
|
||||
assert self.batch_size % self.microbatch_size == 0, \
|
||||
"Batch size should divided by the microbatch size"
|
||||
self.num_microbatches = self.batch_size // self.microbatch_size
|
||||
|
||||
def load_micro_batch(self) -> Any:
|
||||
"""Load a micro batch from the current batch.
|
||||
|
Reference in New Issue
Block a user