mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 09:59:38 +00:00
@@ -48,9 +48,13 @@ class PipelineSchedule(BaseSchedule):
|
||||
# Pipeline schedule just puts data in memory
|
||||
self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
|
||||
self.microbatch_offset = 0
|
||||
assert self.batch_size % self.num_microbatches == 0, \
|
||||
if isinstance(self.batch_data, torch.Tensor):
|
||||
batch_size = self.batch_data.size(0)
|
||||
else:
|
||||
batch_size = next(iter(self.batch_data.values())).size(0)
|
||||
assert batch_size % self.num_microbatches == 0, \
|
||||
"Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
self.microbatch_size = batch_size // self.num_microbatches
|
||||
|
||||
def _get_data_slice(self, data, offset):
|
||||
if isinstance(data, torch.Tensor):
|
||||
|
Reference in New Issue
Block a user