fix layers/schedule for hybrid parallelization (#111) (#112)

ver217
2022-01-04 20:52:31 +08:00
committed by GitHub
parent f03bcb359b
commit 7904baf6e1
6 changed files with 44 additions and 18 deletions

@@ -48,9 +48,13 @@ class PipelineSchedule(BaseSchedule):
         # Pipeline schedule just puts data in memory
         self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
         self.microbatch_offset = 0
-        assert self.batch_size % self.num_microbatches == 0, \
+        if isinstance(self.batch_data, torch.Tensor):
+            batch_size = self.batch_data.size(0)
+        else:
+            batch_size = next(iter(self.batch_data.values())).size(0)
+        assert batch_size % self.num_microbatches == 0, \
             "Batch size should divided by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+        self.microbatch_size = batch_size // self.num_microbatches
 
     def _get_data_slice(self, data, offset):
         if isinstance(data, torch.Tensor):
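
For context, the fix replaces the precomputed self.batch_size with a size derived from the batch itself, because under hybrid parallelization the batch can be a dict of tensors rather than a single tensor. The following is a minimal, self-contained sketch of that logic; the helper name _get_batch_size and the example dict keys are hypothetical illustrations, not part of the commit.

    import torch

    def _get_batch_size(batch_data):
        # Mirrors the committed logic: a plain tensor batch reports its
        # leading dimension; a dict batch reports the leading dimension
        # of any one of its values (all values share the batch dim).
        if isinstance(batch_data, torch.Tensor):
            return batch_data.size(0)
        return next(iter(batch_data.values())).size(0)

    # Tensor batch: 8 samples of 16 features.
    assert _get_batch_size(torch.zeros(8, 16)) == 8

    # Dict batch (hypothetical keys): both entries share the batch dimension.
    batch = {"input_ids": torch.zeros(8, 128, dtype=torch.long),
             "attention_mask": torch.ones(8, 128)}
    assert _get_batch_size(batch) == 8

    # Microbatch split as done in the schedule above.
    num_microbatches = 4
    batch_size = _get_batch_size(batch)
    assert batch_size % num_microbatches == 0
    microbatch_size = batch_size // num_microbatches  # -> 2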