[engine] fixed bug in gradient accumulation dataloader to keep the last step (#1030)

Author: Frank Lee
Date: 2022-05-26 14:28:23 +08:00 (committed by GitHub)
parent 32291dd73f
commit e4685832f8


@@ -145,6 +145,7 @@ class GradAccumDataloader:
     def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
         if self._cur_step < self.steps_per_epoch:
             self._cur_step += 1
+            data = next(self._dataiter)
             if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
                 # this is to handle non standard pytorch dataloader
@@ -154,7 +155,7 @@ class GradAccumDataloader:
                     _ = next(self._dataiter)
                 except StopIteration:
                     break
-            return next(self._dataiter)
+            return data
         else:
             raise StopIteration
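
Why the ordering matters: before this patch, the final step first ran the drain loop (which empties the iterator when consume_remain_data is set) and only then called next(self._dataiter) to fetch the batch to return. By that point the drain loop had already consumed the last batch, so the call raised StopIteration and the final step was lost. The fix fetches the batch before draining. A minimal sketch replaying both orderings over a plain Python iterator (the helper name, flag, and toy data are illustrative, not part of ColossalAI):

from typing import Iterator, List


def replay_epoch(batches: List[int], steps_per_epoch: int, fetch_before_drain: bool) -> List[int]:
    """Replay the __next__ logic of the dataloader wrapper over a plain iterator.

    fetch_before_drain=True mirrors the patched code (grab the batch, then
    drain leftovers); False mirrors the old code (drain, then grab).
    """
    dataiter: Iterator[int] = iter(batches)
    returned = []
    for step in range(1, steps_per_epoch + 1):
        if fetch_before_drain:
            data = next(dataiter)  # patched order: secure the batch first
        if step == steps_per_epoch:
            # drain whatever is left, as the wrapper does for
            # non-standard dataloaders when consume_remain_data is set
            while True:
                try:
                    _ = next(dataiter)
                except StopIteration:
                    break
        if not fetch_before_drain:
            data = next(dataiter)  # old order: the drain loop already ate this batch
        returned.append(data)
    return returned


print(replay_epoch([0, 1, 2, 3], steps_per_epoch=3, fetch_before_drain=True))  # [0, 1, 2]
try:
    replay_epoch([0, 1, 2, 3], steps_per_epoch=3, fetch_before_drain=False)
except StopIteration:
    print("old order: the last step's batch was swallowed by the drain loop")

Running the sketch, the patched ordering yields all three batches while the old ordering raises StopIteration on the last step, which is exactly the behavior the commit title describes.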