From e58cc441e2142f53b61d2b95558974753f9a6e68 Mon Sep 17 00:00:00 2001
From: jiaruifang
Date: Wed, 18 Jan 2023 12:00:08 +0800
Subject: [PATCH] polish code and fix dataloader bugs

---
 .../language/gpt/titans/dataset/webtext.py | 42 +++++++-------
 examples/language/gpt/titans/run.sh        |  3 +-
 examples/language/gpt/titans/train_gpt.py  | 55 ++++---------------
 3 files changed, 35 insertions(+), 65 deletions(-)

diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py
index 09d8870b5..64f5944a9 100644
--- a/examples/language/gpt/titans/dataset/webtext.py
+++ b/examples/language/gpt/titans/dataset/webtext.py
@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Optional
 
 import torch
 from torch.utils.data import Dataset
@@ -11,26 +12,29 @@ from colossalai.registry import DATASETS
 @DATASETS.register_module
 class WebtextDataset(Dataset):
 
-    def __init__(self, path, seq_len=1024) -> None:
+    def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
         super().__init__()
-        root = os.path.dirname(path)
-        encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
-        if os.path.isfile(encoded_data_cache_path):
-            seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
-            if seq_len_ == seq_len:
-                self.data = data
-                self.attention_mask = attention_mask
-                return
-        raw_data = []
-        with open(path) as f:
-            for line in f.readlines():
-                raw_data.append(json.loads(line)['text'])
-        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        tokenizer.pad_token = tokenizer.unk_token
-        encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
-        self.data = encoded_data['input_ids']
-        self.attention_mask = encoded_data['attention_mask']
-        torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
+        if path is not None:
+            root = os.path.dirname(path)
+            encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
+            if os.path.isfile(encoded_data_cache_path):
+                seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
+                if seq_len_ == seq_len:
+                    self.data = data
+                    self.attention_mask = attention_mask
+                    return
+            raw_data = []
+            with open(path) as f:
+                for line in f.readlines():
+                    raw_data.append(json.loads(line)['text'])
+            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+            tokenizer.pad_token = tokenizer.unk_token
+            encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
+            self.data = encoded_data['input_ids']
+            self.attention_mask = encoded_data['attention_mask']
+        else:
+            self.data = torch.randint(0, 50257, (10240, seq_len))
+            self.attention_mask = torch.ones_like(self.data)
 
     def __len__(self):
         return len(self.data)
diff --git a/examples/language/gpt/titans/run.sh b/examples/language/gpt/titans/run.sh
index 157bd377a..a1a7fc737 100644
--- a/examples/language/gpt/titans/run.sh
+++ b/examples/language/gpt/titans/run.sh
@@ -1,2 +1,3 @@
 export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
-colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch
+DUMMY_DATA=--use_dummy_dataset
+colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA
diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py
index 4db7a081f..66225d6c8 100644
--- a/examples/language/gpt/titans/train_gpt.py
+++ b/examples/language/gpt/titans/train_gpt.py
@@ -3,6 +3,7 @@ import os
 
 import torch
 import torch.nn as nn
+from dataset.webtext import WebtextDataset
 from titans.model.gpt import GPTLMLoss
 
 import colossalai
@@ -39,52 +40,16 @@ def main():
         colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
 
     logger = get_dist_logger()
-    if not args.use_dummy_dataset:
-        data_path = os.environ['DATA']
-        logger.info(f'Build data loader from path {data_path}', ranks=[0])
-        from dataset.webtext import WebtextDataset
-        train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
-        train_dataloader = utils.get_dataloader(train_ds,
-                                                seed=42,
-                                                batch_size=gpc.config.BATCH_SIZE,
-                                                pin_memory=True,
-                                                shuffle=True,
-                                                drop_last=True)
-    else:
-        # build a dummy train_dataloader
-        logger.info('Build data loader using dummy data', ranks=[0])
+    data_path = None if args.use_dummy_dataset else os.environ['DATA']
+    logger.info(f'Build data loader from path {data_path}', ranks=[0])
 
-        def get_data(batch_size, seq_len, vocab_size):
-            input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
-            attention_mask = torch.ones_like(input_ids)
-            return input_ids, attention_mask
-
-        # 10 iterations
-        input_ids, attn_mask = get_data(gpc.config.BATCH_SIZE * 10, gpc.config.SEQ_LEN, VOCAB_SIZE)
-        from torch.utils.data import DataLoader, Dataset
-
-        class TextSamplerDataset(Dataset):
-
-            def __init__(self, data, seq_len):
-                super().__init__()
-                self.data = data
-                self.seq_len = seq_len
-
-            def __getitem__(self, index):
-                rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
-                full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
-                return full_seq.cuda()
-
-            def __len__(self):
-                return self.data.size(0) // self.seq_len
-
-        def cycle(loader):
-            while True:
-                for data in loader:
-                    yield data
-
-        train_dataset = TextSamplerDataset(input_ids, gpc.config.SEQ_LEN)
-        train_dataloader = DataLoader(train_dataset, batch_size=gpc.config.BATCH_SIZE)
+    train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
+    train_dataloader = utils.get_dataloader(train_ds,
+                                            seed=42,
+                                            batch_size=gpc.config.BATCH_SIZE,
+                                            pin_memory=True,
+                                            shuffle=True,
+                                            drop_last=True)
 
     logger.info('Build model', ranks=[0])
     use_pipeline = is_using_pp()
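
The train_gpt.py hunk above removes the ad-hoc TextSamplerDataset / cycle() dummy pipeline and routes both modes through WebtextDataset, which now fabricates random GPT-2-range token ids (vocab size 50257) and an all-ones attention mask when path is None. The standalone sketch below mirrors that dummy branch so the resulting tensor shapes can be checked without the ColossalAI launcher; the class name DummyWebtext and the __getitem__ layout are illustrative assumptions, since the patch shows only __init__ and __len__ of the real class.

# Minimal sketch of the dummy-data branch added in webtext.py, for illustration only.
# Assumption: the __getitem__ layout below is hypothetical; the patch does not show
# WebtextDataset.__getitem__.
import torch
from torch.utils.data import DataLoader, Dataset


class DummyWebtext(Dataset):
    """Standalone stand-in for WebtextDataset(path=None, seq_len=...)."""

    def __init__(self, seq_len: int = 1024, num_samples: int = 10240, vocab_size: int = 50257):
        # Mirrors the patch: random token ids in the GPT-2 vocab range plus an all-ones mask.
        self.data = torch.randint(0, vocab_size, (num_samples, seq_len))
        self.attention_mask = torch.ones_like(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Hypothetical layout: (inputs dict, labels), as a GPT LM-loss pipeline would expect.
        return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]


if __name__ == '__main__':
    loader = DataLoader(DummyWebtext(seq_len=1024), batch_size=4, shuffle=True, drop_last=True)
    inputs, labels = next(iter(loader))
    print(inputs['input_ids'].shape, inputs['attention_mask'].shape, labels.shape)
    # torch.Size([4, 1024]) torch.Size([4, 1024]) torch.Size([4, 1024])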
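
The real-data branch keeps the seq_len-keyed tokenization cache from the original file: gpt_webtext_{seq_len}.pt holds a (seq_len, data, attention_mask) tuple and is reused only when the stored sequence length matches the requested one. Note that torch.save appears only on the removed lines, so after this patch an existing cache is read but not rewritten. A minimal sketch of that load/save round trip, with a hypothetical cache file name and random tensors, follows.

# Sketch of the seq_len-keyed cache pattern used by WebtextDataset's real-data branch.
# The file name and tensors here are illustrative, not taken from the repository.
import os
import torch


def save_cache(cache_path, seq_len, data, attention_mask):
    # Store the sequence length alongside the tensors so a stale cache can be detected.
    torch.save((seq_len, data, attention_mask), cache_path)


def try_load_cache(cache_path, seq_len):
    # Return (data, attention_mask) only when the cache exists and was built for this seq_len.
    if os.path.isfile(cache_path):
        seq_len_, data, attention_mask = torch.load(cache_path)
        if seq_len_ == seq_len:
            return data, attention_mask
    return None


if __name__ == '__main__':
    path = 'gpt_webtext_1024.pt'            # hypothetical cache file
    ids = torch.randint(0, 50257, (8, 1024))
    mask = torch.ones_like(ids)
    save_cache(path, 1024, ids, mask)
    assert try_load_cache(path, 1024) is not None    # same seq_len: cache hit
    assert try_load_cache(path, 2048) is None        # different seq_len: rebuild needed
    os.remove(path)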