diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py index 5bfe3f449..89c28c3be 100644 --- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py +++ b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py @@ -145,6 +145,7 @@ class GradAccumDataloader: def __next__(self) -> Union[Tensor, Tuple[Tensor]]: if self._cur_step < self.steps_per_epoch: self._cur_step += 1 + data = next(self._dataiter) if self._cur_step == self.steps_per_epoch and self.consume_remain_data: # this is to handle non standard pytorch dataloader @@ -154,7 +155,7 @@ class GradAccumDataloader: _ = next(self._dataiter) except StopIteration: break - return next(self._dataiter) + return data else: raise StopIteration