Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 01:55:12 +00:00
[shardformer] write a shardformer example with BERT fine-tuning (#4126)
* [shardformer] add benchmark of shardformer
colossalai/shardformer/examples/data.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import datasets
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer

from colossalai.booster.plugin.dp_plugin_base import DPPluginBase


class GLUEDataBuilder:
    """Build train/validation/test dataloaders for a GLUE task.

    If a data-parallel plugin is given, dataloaders are created through
    ``plugin.prepare_dataloader`` so batches are sharded across processes;
    otherwise plain ``torch.utils.data.DataLoader`` objects are returned.
    """

    # Text columns used as model input for each GLUE task.
    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    # Number of labels per task (stsb is a regression task, hence 1).
    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    # Columns kept and converted to tensors for the dataloaders.
    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        plugin: DPPluginBase = None,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.plugin = plugin

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
        self.setup()

    def setup(self):
        # Download the GLUE task, tokenize every split and keep only the
        # columns the model consumes, formatted as torch tensors.
        self.dataset = datasets.load_dataset("glue", self.task_name)

        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def prepare_data(self):
        # Trigger dataset download and tokenizer caching ahead of time.
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        if self.plugin is None:
            return self.native_prepare_dataloader(self.dataset["train"],
                                                  batch_size=self.train_batch_size,
                                                  shuffle=True,
                                                  drop_last=True)
        return self.plugin.prepare_dataloader(self.dataset["train"],
                                              batch_size=self.train_batch_size,
                                              shuffle=True,
                                              drop_last=True)

    def val_dataloader(self):
        if self.plugin is None:
            return self.native_prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
        if len(self.eval_splits) == 1:
            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            # Some tasks (e.g. MNLI) expose several validation splits.
            return [
                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                for x in self.eval_splits
            ]

    def test_dataloader(self):
        if self.plugin is None:
            return self.native_prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
        if len(self.eval_splits) == 1:
            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [
                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                for x in self.eval_splits
            ]

    def convert_to_features(self, example_batch):
        # Either encode a single sentence or a sentence pair.
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs.
        features = self.tokenizer.batch_encode_plus(texts_or_text_pairs,
                                                    max_length=self.max_seq_length,
                                                    padding='max_length',
                                                    truncation=True)

        # Rename "label" to "labels" so a batch can be passed directly to the model's forward.
        features["labels"] = example_batch["label"]

        return features

    def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False):
        # Fallback used when no plugin is supplied: a plain single-process DataLoader.
        return DataLoader(dataset,
                          batch_size=batch_size,
                          sampler=None,
                          shuffle=shuffle,
                          drop_last=drop_last,
                          pin_memory=pin_memory)
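For orientation, a minimal usage sketch follows. It is not part of this commit: the bert-base-uncased checkpoint, the AdamW settings, the import path (from data import ...) and the training-loop body are illustrative assumptions. The sketch exercises the plugin=None path, which returns plain PyTorch DataLoaders; passing a DPPluginBase plugin instead routes the same calls through plugin.prepare_dataloader for distributed data-parallel loading.

# Hypothetical usage sketch (assumptions noted above; not part of this commit).
import torch
from transformers import BertForSequenceClassification

from data import GLUEDataBuilder  # assumes the script sits next to data.py

# plugin=None -> GLUEDataBuilder falls back to native_prepare_dataloader().
builder = GLUEDataBuilder("bert-base-uncased", plugin=None, task_name="mrpc")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=builder.num_labels)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

model.train()
for batch in builder.train_dataloader():
    # Each batch is a dict of torch tensors (input_ids, token_type_ids,
    # attention_mask, labels) thanks to set_format(type="torch") in setup().
    outputs = model(**batch)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()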