[booster] update prepare dataloader method for plugin (#3706)

* [booster] add prepare dataloader method for plug

* [booster] update examples and docstr
This commit is contained in:
Hongxin Liu
2023-05-08 15:44:03 +08:00
committed by GitHub
parent f83ea813f5
commit 3bf09efe74
9 changed files with 41 additions and 40 deletions

View File

@@ -49,14 +49,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
download=True)
# Data loader
train_dataloader = plugin.prepare_train_dataloader(train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True)
test_dataloader = plugin.prepare_train_dataloader(test_dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
return train_dataloader, test_dataloader

View File

@@ -63,14 +63,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
download=True)
# Data loader
train_dataloader = plugin.prepare_train_dataloader(train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True)
test_dataloader = plugin.prepare_train_dataloader(test_dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
return train_dataloader, test_dataloader

View File

@@ -84,26 +84,26 @@ class GLUEDataBuilder:
AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
def train_dataloader(self):
return self.plugin.prepare_train_dataloader(self.dataset["train"],
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True)
return self.plugin.prepare_dataloader(self.dataset["train"],
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True)
def val_dataloader(self):
if len(self.eval_splits) == 1:
return self.plugin.prepare_train_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]
def test_dataloader(self):
if len(self.eval_splits) == 1:
return self.plugin.prepare_train_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]