[pipeline]: support arbitrary batch size in forward_only mode (#5201)

* fix: remove drop last in val & test dataloader

* feat: add run_forward_only, support arbitrary bs

* chore: modify ci script
This commit is contained in:
Wenhao Chen
2024-01-02 23:41:12 +08:00
committed by GitHub
parent 02d2328a04
commit 3c0d82b19b
4 changed files with 293 additions and 202 deletions

View File

@@ -88,24 +88,21 @@ class GLUEDataBuilder:
)
def val_dataloader(self):
    """Build the validation dataloader(s).

    Returns:
        A single dataloader when there is exactly one eval split, or a
        list of dataloaders (one per split) when there are several.

    drop_last is intentionally NOT set: forward-only pipeline execution
    supports arbitrary batch sizes (see commit message), so the final
    partial batch is kept rather than dropped.

    NOTE(review): implicitly returns None when self.eval_splits is
    empty — confirm callers tolerate that.
    """
    if len(self.eval_splits) == 1:
        # Single-split case: the split is always stored under "validation".
        return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
    elif len(self.eval_splits) > 1:
        # Multi-split case (e.g. MNLI matched/mismatched): one loader per split.
        return [
            self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
            for x in self.eval_splits
        ]
def test_dataloader(self):
if len(self.eval_splits) == 1:
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True)
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]