[tutorial] edited hands-on practices (#1899)
* Add handson to ColossalAI.
* Change names of handsons and edit sequence parallel example.
* Edit wrong folder name
* resolve conflict
* delete readme
143 examples/tutorial/sequence_parallel/README.md Normal file
@@ -0,0 +1,143 @@
# Hands-on 2: Sequence Parallelism with BERT

In this example, we implement BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
activations along the sequence dimension. This method improves memory efficiency and allows us to train with a larger batch size and a longer sequence length.

Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)

## How to Prepare the Wikipedia Dataset

First, let's prepare the Wikipedia dataset from scratch. To generate a preprocessed dataset, we need four items:
1. the raw Wikipedia dump
2. WikiExtractor (extracts data from the raw dump)
3. a vocabulary file
4. the preprocessing scripts (generate the final data from the extracted data)

For the preprocessing step, we thank Megatron-LM for providing the script that generates the corpus file.

```bash
# download raw data
mkdir data && cd ./data
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2

# install wiki extractor
git clone https://github.com/FrankLeeeee/wikiextractor.git
pip install ./wikiextractor

# extract data from the raw dump
wikiextractor --json enwiki-latest-pages-articles.xml.bz2
cat text/*/* > ./corpus.json
cd ..

# download vocab file
mkdir vocab && cd ./vocab
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt
cd ..

# preprocess the data
git clone https://github.com/NVIDIA/Megatron-LM.git
cd ./Megatron-LM
python tools/preprocess_data.py \
    --input ../data/corpus.json \
    --output-prefix my-bert \
    --vocab ../vocab/bert-large-uncased-vocab.txt \
    --dataset-impl mmap \
    --tokenizer-type BertWordPieceLowerCase \
    --split-sentences \
    --workers 24
```

After running the preprocessing scripts, you will obtain two files:
1. `my-bert_text_sentence.bin`
2. `my-bert_text_sentence.idx`

If you happen to encounter an `index out of range` error when running Megatron's script,
it is probably because a sentence starts with a punctuation mark and cannot be tokenized. A workaround is to update the `Encoder.encode` method with the code below (make sure `string` is imported at the top of the file):

```python
class Encoder(object):
    def __init__(self, args):
        ...

    def initializer(self):
        ...

    def encode(self, json_line):
        data = json.loads(json_line)
        ids = {}
        for key in self.args.json_keys:
            text = data[key]
            doc_ids = []

            # lsg: avoid sentences which start with a punctuation mark
            # as they cannot be tokenized by the splitter
            if len(text) > 0 and text[0] in string.punctuation:
                text = text[1:]

            for sentence in Encoder.splitter.tokenize(text):
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.append(sentence_ids)
            if len(doc_ids) > 0 and self.args.append_eod:
                doc_ids[-1].append(Encoder.tokenizer.eod)
            ids[key] = doc_ids
        return ids, len(json_line)
```

## How to Train with Sequence Parallelism

We provide `train.py` for you to execute training. Before invoking the script, there are several
steps to perform.

### Step 1. Set the data path and vocab path

At the top of `config.py`, you can see two global variables `DATA_PATH` and `VOCAB_FILE_PATH`.

```python
DATA_PATH = <data-path>
VOCAB_FILE_PATH = <vocab-path>
```

`DATA_PATH` refers to the path of the data files generated by Megatron's script. For example, in the section above, you should get two data files (`my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`). You just need to set `DATA_PATH` to the path of the bin file without the file extension.

For example, if your `my-bert_text_sentence.bin` is located at `/home/Megatron-LM/my-bert_text_sentence.bin`, then you should set

```python
DATA_PATH = '/home/Megatron-LM/my-bert_text_sentence'
```

`VOCAB_FILE_PATH` refers to the path of the vocabulary file downloaded when you prepared the dataset
(e.g. `bert-large-uncased-vocab.txt`).

### Step 2. Build the Dataset Helper

Build the BERT dataset helper. The requirements are `CUDA`, `g++`, `pybind11` and `make`.

```bash
cd ./data/datasets
make
```

### Step 3. Configure your parameters

The provided `config.py` defines a set of parameters, including the training scheme, the model, etc.
You can also modify the ColossalAI settings. For example, if you wish to parallelize over the
sequence dimension on 8 GPUs, change `size=4` to `size=8`. If you wish to use pipeline parallelism, set `pipeline=<num_of_pipeline_stages>`.
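
For instance, the relevant lines of `config.py` for 8-way sequence parallelism with 2 pipeline stages might look like the sketch below (illustrative values, not the tutorial defaults):

```python
# illustrative: 8-way sequence parallelism with 2 pipeline stages
parallel = dict(pipeline=2, tensor=dict(size=8, mode='sequence'))

# only used when pipeline > 1
NUM_MICRO_BATCHES = 4
```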

### Step 4. Invoke parallel training

Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your
machine setup.

- If you are using a single machine with multiple GPUs, the PyTorch launch utility can easily
start your script. A sample command is shown below:

```bash
python -m torch.distributed.launch --nproc_per_node <num_gpus_on_this_machine> --master_addr localhost --master_port 29500 train.py
```

- If you are using multiple machines with multiple GPUs, we suggest that you refer to `colossalai.launch_from_slurm`
or `colossalai.launch_from_openmpi`, as it is easier to use SLURM and OpenMPI
to start multiple processes over multiple nodes. If you have your own launcher, you can fall back
to the default `colossalai.launch` function. A rough initialization sketch is shown after this list.
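
As a rough sketch (not the exact code in `train.py`), initializing ColossalAI from a SLURM allocation could look like the following; the host and port values are placeholders you would adapt to your cluster:

```python
import colossalai

# assumed initialization for a SLURM allocation; the real train.py may differ in detail
colossalai.launch_from_slurm(config='./config.py',            # the tutorial's config file
                             host='<master-node-hostname>',   # placeholder
                             port=29500)                      # placeholder: any free TCP port
```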

40 examples/tutorial/sequence_parallel/config.py Normal file
@@ -0,0 +1,40 @@
from colossalai.amp import AMP_TYPE

DATA_PATH = ''
VOCAB_FILE_PATH = ''

# hyper-parameters
TRAIN_ITERS = 1000000
DECAY_ITERS = 990000
WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 32    # dp world size * sentences per GPU
EVAL_ITERS = 10
EVAL_INTERVAL = 10
LR = 0.0001
MIN_LR = 1e-05
WEIGHT_DECAY = 0.01
SEQ_LENGTH = 512

# BERT config
DEPTH = 12
NUM_ATTENTION_HEADS = 12
HIDDEN_SIZE = 768

# model config
ADD_BINARY_HEAD = False

# random seed
SEED = 1234

# pipeline config
# only enabled when pipeline > 1
NUM_MICRO_BATCHES = 4

# colossalai config
parallel = dict(pipeline=1, tensor=dict(size=4, mode='sequence'))

fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True)

clip_grad_norm = 1.0

gradient_handler = [dict(type='SequenceParallelGradientHandler')]

102 examples/tutorial/sequence_parallel/data/__init__.py Normal file
@@ -0,0 +1,102 @@
from colossalai.context.parallel_context import ParallelContext
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.context import ParallelMode
|
||||
from .datasets.data_samplers import build_pretraining_data_loader
|
||||
from .datasets.builder import build_train_valid_test_datasets
|
||||
import torch
|
||||
|
||||
|
||||
def cyclic_iter(iter):
|
||||
while True:
|
||||
for x in iter:
|
||||
yield x
|
||||
|
||||
|
||||
def build_train_valid_test_data_iterators(train_iters,
|
||||
global_batch_size,
|
||||
eval_interval,
|
||||
eval_iters,
|
||||
dataloader_type='single',
|
||||
**kwargs
|
||||
):
|
||||
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
|
||||
|
||||
logger = get_dist_logger()
|
||||
logger.info('> building train, validation, and test datasets ...', ranks=[0])
|
||||
|
||||
# Backward compatibility, assume fixed batch size.
|
||||
# if iteration > 0 and consumed_train_samples == 0:
|
||||
# assert train_samples is None, \
|
||||
# 'only backward compatibility support for iteration-based training'
|
||||
# consumed_train_samples = iteration * global_batch_size
|
||||
# if iteration > 0 and consumed_valid_samples == 0:
|
||||
# if train_samples is None:
|
||||
# consumed_valid_samples = (iteration // eval_interval) * \
|
||||
# eval_iters * global_batch_size
|
||||
|
||||
# Data loader only on rank 0 of each model parallel group.
|
||||
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
||||
# Number of train/valid/test samples.
|
||||
train_samples = train_iters * global_batch_size
|
||||
eval_iters_ = (train_iters // eval_interval + 1) * eval_iters
|
||||
test_iters = eval_iters
|
||||
train_val_test_num_samples = [train_samples,
|
||||
eval_iters_ * global_batch_size,
|
||||
test_iters * global_batch_size]
|
||||
logger.info(' > datasets target sizes (minimum size):')
|
||||
logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0])
|
||||
logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0])
|
||||
logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0])
|
||||
|
||||
# Build the datasets.
|
||||
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
|
||||
train_valid_test_num_samples=train_val_test_num_samples, **kwargs)
|
||||
|
||||
# Build dataloaders.
|
||||
dp_size = gpc.get_world_size(ParallelMode.DATA)
|
||||
train_dataloader = build_pretraining_data_loader(
|
||||
train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size)
|
||||
valid_dataloader = build_pretraining_data_loader(
|
||||
valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size)
|
||||
test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size)
|
||||
|
||||
# Flags to know if we need to do training/validation/testing.
|
||||
do_train = train_dataloader is not None and train_iters > 0
|
||||
do_valid = valid_dataloader is not None and eval_iters > 0
|
||||
do_test = test_dataloader is not None and eval_iters > 0
|
||||
# Need to broadcast num_tokens and num_type_tokens.
|
||||
flags = torch.cuda.LongTensor(
|
||||
[int(do_train), int(do_valid), int(do_test)])
|
||||
else:
|
||||
flags = torch.cuda.LongTensor([0, 0, 0])
|
||||
|
||||
# Broadcast num tokens.
|
||||
torch.distributed.broadcast(flags,
|
||||
gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
|
||||
# Build iterators.
|
||||
dl_type = dataloader_type
|
||||
assert dl_type in ['single', 'cyclic']
|
||||
|
||||
if train_dataloader is not None:
|
||||
train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
|
||||
else iter(cyclic_iter(train_dataloader))
|
||||
else:
|
||||
train_data_iterator = None
|
||||
|
||||
if valid_dataloader is not None:
|
||||
valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
|
||||
else iter(cyclic_iter(valid_dataloader))
|
||||
else:
|
||||
valid_data_iterator = None
|
||||
|
||||
if test_dataloader is not None:
|
||||
test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
|
||||
else iter(cyclic_iter(test_dataloader))
|
||||
else:
|
||||
test_data_iterator = None
|
||||
|
||||
return train_data_iterator, valid_data_iterator, test_data_iterator
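

# --- illustrative usage sketch (added note, not part of the original file) ---
# Assumes colossalai has already been launched with config='./config.py', so that
# gpc.config (imported at the top of this file) exposes the tutorial's settings.
# The keyword arguments below are forwarded to build_train_valid_test_datasets;
# values marked "assumption" are illustrative and not taken from the tutorial.
def _example_build_iterators():
    return build_train_valid_test_data_iterators(
        train_iters=gpc.config.TRAIN_ITERS,
        global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
        eval_interval=gpc.config.EVAL_INTERVAL,
        eval_iters=gpc.config.EVAL_ITERS,
        data_prefix=[gpc.config.DATA_PATH],
        data_impl='mmap',          # assumption: matches --dataset-impl mmap
        splits_string='949,50,1',  # assumption: train/valid/test split ratios
        max_seq_length=gpc.config.SEQ_LENGTH,
        masked_lm_prob=0.15,       # assumption: standard BERT masking ratio
        short_seq_prob=0.1,        # assumption
        seed=gpc.config.SEED,
        skip_warmup=True,          # assumption
        binary_head=gpc.config.ADD_BINARY_HEAD)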

165 examples/tutorial/sequence_parallel/data/bert_helper.py Normal file
@@ -0,0 +1,165 @@
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
import torch
|
||||
|
||||
_MAX_DATA_DIM = 5
|
||||
|
||||
|
||||
def _build_key_size_numel_dictionaries(keys, data):
|
||||
"""Build the size on rank 0 and broadcast."""
|
||||
max_dim = _MAX_DATA_DIM
|
||||
sizes = [0 for _ in range(max_dim) for _ in keys]
|
||||
|
||||
# Pack the sizes on rank zero.
|
||||
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
offset = 0
|
||||
for key in keys:
|
||||
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
|
||||
size = data[key].size()
|
||||
for i, s in enumerate(size):
|
||||
sizes[i + offset] = s
|
||||
offset += max_dim
|
||||
|
||||
# Move to GPU and broadcast.
|
||||
sizes_cuda = torch.cuda.LongTensor(sizes)
|
||||
torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
|
||||
# Move back to cpu and unpack.
|
||||
sizes_cpu = sizes_cuda.cpu()
|
||||
key_size = {}
|
||||
key_numel = {}
|
||||
total_numel = 0
|
||||
offset = 0
|
||||
for key in keys:
|
||||
i = 0
|
||||
size = []
|
||||
numel = 1
|
||||
while sizes_cpu[offset + i] > 0:
|
||||
this_size = sizes_cpu[offset + i]
|
||||
size.append(this_size)
|
||||
numel *= this_size
|
||||
i += 1
|
||||
key_size[key] = size
|
||||
key_numel[key] = numel
|
||||
total_numel += numel
|
||||
offset += max_dim
|
||||
|
||||
return key_size, key_numel, total_numel
|
||||
|
||||
|
||||
def broadcast_data(keys, data, datatype):
|
||||
"""Broadcast data from rank zero of each model parallel group to the
|
||||
members of the same model parallel group.
|
||||
|
||||
Arguments:
|
||||
keys: list of keys in the data dictionary to be broadcasted
|
||||
data: data dictionary of string keys and cpu tensor values.
|
||||
datatype: torch data type of all tensors in data associated
|
||||
with keys.
|
||||
"""
|
||||
# Build (key, size) and (key, number of elements) dictionaries along
|
||||
# with the total number of elements on all ranks.
|
||||
key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
|
||||
data)
|
||||
|
||||
# Pack on rank zero.
|
||||
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# Check that all keys have the same data type.
|
||||
# Flatten the data associated with the keys
|
||||
flatten_data = torch.cat(
|
||||
[data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
|
||||
else:
|
||||
flatten_data = torch.empty(total_numel,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=datatype)
|
||||
|
||||
# Broadcast
|
||||
torch.distributed.broadcast(flatten_data,
|
||||
gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
|
||||
# Unpack
|
||||
output = {}
|
||||
offset = 0
|
||||
for key in keys:
|
||||
size = key_size[key]
|
||||
numel = key_numel[key]
|
||||
output[key] = flatten_data.narrow(0, offset, numel).view(size)
|
||||
offset += numel
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_batch(data_iterator):
|
||||
"""Build the batch."""
|
||||
|
||||
# Items and their type.
|
||||
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
|
||||
datatype = torch.int64
|
||||
|
||||
# Broadcast data.
|
||||
if data_iterator is not None:
|
||||
data = next(data_iterator)
|
||||
else:
|
||||
data = None
|
||||
data_b = broadcast_data(keys, data, datatype)
|
||||
|
||||
# Unpack.
|
||||
tokens = data_b['text'].long()
|
||||
types = data_b['types'].long()
|
||||
sentence_order = data_b['is_random'].long()
|
||||
loss_mask = data_b['loss_mask'].float()
|
||||
lm_labels = data_b['labels'].long()
|
||||
padding_mask = data_b['padding_mask'].long()
|
||||
|
||||
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
|
||||
|
||||
|
||||
def get_batch_for_sequence_parallel(data_iterator):
|
||||
"""Build the batch."""
|
||||
|
||||
# Items and their type.
|
||||
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
|
||||
datatype = torch.int64
|
||||
|
||||
# Broadcast data.
|
||||
if data_iterator is not None:
|
||||
data = next(data_iterator)
|
||||
else:
|
||||
data = None
|
||||
|
||||
# unpack
|
||||
data_b = broadcast_data(keys, data, datatype)
|
||||
|
||||
# # get tensor parallel local rank
|
||||
global_rank = torch.distributed.get_rank()
|
||||
local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR)
|
||||
local_rank = global_rank % local_world_size
|
||||
seq_length = data_b['text'].size(1)
|
||||
sub_seq_length = seq_length // local_world_size
|
||||
sub_seq_start = local_rank * sub_seq_length
|
||||
sub_seq_end = (local_rank+1) * sub_seq_length
|
||||
#
|
||||
# # Unpack.
|
||||
tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long()
|
||||
types = data_b['types'][:, sub_seq_start:sub_seq_end].long()
|
||||
sentence_order = data_b['is_random'].long()
|
||||
loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float()
|
||||
lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long()
|
||||
padding_mask = data_b['padding_mask'].long()
|
||||
|
||||
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
|
||||
|
||||
|
||||
class SequenceParallelDataIterator:
|
||||
|
||||
def __init__(self, data_iter):
|
||||
self.data_iter = data_iter
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
return self.data_iter
|
||||
|
||||
def __next__(self):
|
||||
return get_batch_for_sequence_parallel(self.data_iter)
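

# --- illustrative usage sketch (added note, not part of the original file) ---
# `train_iter` is assumed to come from build_train_valid_test_data_iterators in
# data/__init__.py. The wrapper slices each batch along the sequence dimension
# so that every tensor-parallel rank only receives its own sub-sequence.
def _example_wrap_iterator(train_iter):
    sp_train_iter = SequenceParallelDataIterator(train_iter)
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = next(sp_train_iter)
    # tokens has shape (batch, seq_length // sequence-parallel world size)
    return tokens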
|
||||
9 examples/tutorial/sequence_parallel/data/datasets/Makefile Normal file
@@ -0,0 +1,9 @@
|
||||
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
|
||||
CPPFLAGS += $(shell python3 -m pybind11 --includes)
|
||||
LIBNAME = helpers
|
||||
LIBEXT = $(shell python3-config --extension-suffix)
|
||||
|
||||
default: $(LIBNAME)$(LIBEXT)
|
||||
|
||||
%$(LIBEXT): %.cpp
|
||||
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
|
||||
1 examples/tutorial/sequence_parallel/data/datasets/__init__.py Normal file
@@ -0,0 +1 @@
|
||||
from . import indexed_dataset
|
||||
225 examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py Normal file
@@ -0,0 +1,225 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""BERT Style dataset."""
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ..tokenizer import get_tokenizer
|
||||
from .dataset_utils import (get_a_and_b_segments, truncate_segments, create_tokens_and_tokentypes,
|
||||
create_masked_lm_predictions, pad_and_convert_to_numpy)
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
import time
|
||||
import os
|
||||
from . import helpers
|
||||
|
||||
|
||||
class BertDataset(Dataset):
|
||||
|
||||
def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
|
||||
short_seq_prob, seed, binary_head):
|
||||
|
||||
# Params to store.
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.masked_lm_prob = masked_lm_prob
|
||||
self.max_seq_length = max_seq_length
|
||||
self.binary_head = binary_head
|
||||
|
||||
# Dataset.
|
||||
self.indexed_dataset = indexed_dataset
|
||||
|
||||
# Build the samples mapping.
|
||||
self.samples_mapping = get_samples_mapping_(
|
||||
self.indexed_dataset,
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
self.max_seq_length - 3, # account for added tokens,
|
||||
short_seq_prob,
|
||||
self.seed,
|
||||
self.name,
|
||||
self.binary_head)
|
||||
|
||||
# Vocab stuff.
|
||||
tokenizer = get_tokenizer()
|
||||
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
|
||||
self.vocab_id_to_token_dict = tokenizer.inv_vocab
|
||||
self.cls_id = tokenizer.cls
|
||||
self.sep_id = tokenizer.sep
|
||||
self.mask_id = tokenizer.mask
|
||||
self.pad_id = tokenizer.pad
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_mapping.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
start_idx, end_idx, seq_length = self.samples_mapping[idx]
|
||||
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
|
||||
# Note that this rng state should be numpy and not python since
|
||||
# python randint is inclusive whereas the numpy one is exclusive.
|
||||
# We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
|
||||
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
|
||||
return build_training_sample(
|
||||
sample,
|
||||
seq_length,
|
||||
self.max_seq_length, # needed for padding
|
||||
self.vocab_id_list,
|
||||
self.vocab_id_to_token_dict,
|
||||
self.cls_id,
|
||||
self.sep_id,
|
||||
self.mask_id,
|
||||
self.pad_id,
|
||||
self.masked_lm_prob,
|
||||
np_rng,
|
||||
self.binary_head)
|
||||
|
||||
|
||||
def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob,
|
||||
seed, name, binary_head):
|
||||
logger = get_dist_logger()
|
||||
if not num_epochs:
|
||||
if not max_num_samples:
|
||||
raise ValueError("Need to specify either max_num_samples "
|
||||
"or num_epochs")
|
||||
num_epochs = np.iinfo(np.int32).max - 1
|
||||
if not max_num_samples:
|
||||
max_num_samples = np.iinfo(np.int64).max - 1
|
||||
|
||||
# Filename of the index mapping
|
||||
indexmap_filename = data_prefix
|
||||
indexmap_filename += '_{}_indexmap'.format(name)
|
||||
if num_epochs != (np.iinfo(np.int32).max - 1):
|
||||
indexmap_filename += '_{}ep'.format(num_epochs)
|
||||
if max_num_samples != (np.iinfo(np.int64).max - 1):
|
||||
indexmap_filename += '_{}mns'.format(max_num_samples)
|
||||
indexmap_filename += '_{}msl'.format(max_seq_length)
|
||||
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
|
||||
indexmap_filename += '_{}s'.format(seed)
|
||||
indexmap_filename += '.npy'
|
||||
|
||||
# Build the indexed mapping if not exist.
|
||||
if torch.distributed.get_rank() == 0 and \
|
||||
not os.path.isfile(indexmap_filename):
|
||||
print(' > WARNING: could not find index map file {}, building '
|
||||
'the indices on rank 0 ...'.format(indexmap_filename))
|
||||
|
||||
# Make sure the types match the helpers input types.
|
||||
assert indexed_dataset.doc_idx.dtype == np.int64
|
||||
assert indexed_dataset.sizes.dtype == np.int32
|
||||
|
||||
# Build samples mapping
|
||||
verbose = torch.distributed.get_rank() == 0
|
||||
start_time = time.time()
|
||||
logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0])
|
||||
# First compile and then import.
|
||||
samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs,
|
||||
max_num_samples, max_seq_length, short_seq_prob, seed, verbose,
|
||||
2 if binary_head else 1)
|
||||
logger.info('\n > done building samples index mapping', ranks=[0])
|
||||
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
|
||||
logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0])
|
||||
# Make sure all the ranks have built the mapping
|
||||
logger.info('\n > elapsed time to build and save samples mapping '
|
||||
'(seconds): {:4f}'.format(time.time() - start_time),
|
||||
ranks=[0])
|
||||
# This should be a barrier but nccl barrier assumes
|
||||
# device_index=rank which is not the case for model
|
||||
# parallel case
|
||||
counts = torch.cuda.LongTensor([1])
|
||||
torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA))
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
assert counts[0].item() == (torch.distributed.get_world_size() //
|
||||
torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)))
|
||||
|
||||
# Load indexed dataset.
|
||||
start_time = time.time()
|
||||
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
|
||||
logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) +
|
||||
'\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) +
|
||||
'\n total number of samples: {}'.format(samples_mapping.shape[0]),
|
||||
ranks=[0])
|
||||
|
||||
return samples_mapping
|
||||
|
||||
|
||||
def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id,
|
||||
sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head):
|
||||
"""Build training sample.
|
||||
|
||||
Arguments:
|
||||
sample: A list of sentences in which each sentence is a list token ids.
|
||||
target_seq_length: Desired sequence length.
|
||||
max_seq_length: Maximum length of the sequence. All values are padded to
|
||||
this length.
|
||||
vocab_id_list: List of vocabulary ids. Used to pick a random id.
|
||||
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
|
||||
cls_id: Start of example id.
|
||||
sep_id: Separator id.
|
||||
mask_id: Mask token id.
|
||||
pad_id: Padding token id.
|
||||
masked_lm_prob: Probability to mask tokens.
|
||||
np_rng: Random number generator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the upper bound whereas the numpy one is exclusive.
|
||||
"""
|
||||
|
||||
if binary_head:
|
||||
# We assume that we have at least two sentences in the sample
|
||||
assert len(sample) > 1
|
||||
assert target_seq_length <= max_seq_length
|
||||
|
||||
# Divide sample into two segments (A and B).
|
||||
if binary_head:
|
||||
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
|
||||
else:
|
||||
tokens_a = []
|
||||
for j in range(len(sample)):
|
||||
tokens_a.extend(sample[j])
|
||||
tokens_b = []
|
||||
is_next_random = False
|
||||
|
||||
# Truncate to `target_sequence_length`.
|
||||
max_num_tokens = target_seq_length
|
||||
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng)
|
||||
|
||||
# Build tokens and tokentypes.
|
||||
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id)
|
||||
|
||||
# Masking.
|
||||
max_predictions_per_seq = masked_lm_prob * max_num_tokens
|
||||
(tokens, masked_positions, masked_labels,
|
||||
_) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id,
|
||||
mask_id, max_predictions_per_seq, np_rng)
|
||||
|
||||
# Padding.
|
||||
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
|
||||
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
||||
masked_labels, pad_id, max_seq_length)
|
||||
|
||||
train_sample = {
|
||||
'text': tokens_np,
|
||||
'types': tokentypes_np,
|
||||
'labels': labels_np,
|
||||
'is_random': int(is_next_random),
|
||||
'loss_mask': loss_mask_np,
|
||||
'padding_mask': padding_mask_np,
|
||||
'truncated': int(truncated)
|
||||
}
|
||||
return train_sample
|
||||
62 examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py Normal file
@@ -0,0 +1,62 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Blendable dataset."""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class BlendableDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, datasets, weights):
|
||||
|
||||
self.datasets = datasets
|
||||
num_datasets = len(datasets)
|
||||
assert num_datasets == len(weights)
|
||||
|
||||
self.size = 0
|
||||
for dataset in self.datasets:
|
||||
self.size += len(dataset)
|
||||
|
||||
# Normalize weights.
|
||||
weights = np.array(weights, dtype=np.float64)
|
||||
sum_weights = np.sum(weights)
|
||||
assert sum_weights > 0.0
|
||||
weights /= sum_weights
|
||||
|
||||
# Build indices.
|
||||
start_time = time.time()
|
||||
assert num_datasets < 255
|
||||
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
|
||||
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
|
||||
|
||||
from . import helpers
|
||||
helpers.build_blending_indices(self.dataset_index,
|
||||
self.dataset_sample_index,
|
||||
weights, num_datasets, self.size,
|
||||
torch.distributed.get_rank() == 0)
|
||||
print('> elapsed time for building blendable dataset indices: '
|
||||
'{:.2f} (sec)'.format(time.time() - start_time))
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def __getitem__(self, idx):
|
||||
dataset_idx = self.dataset_index[idx]
|
||||
sample_idx = self.dataset_sample_index[idx]
|
||||
return self.datasets[dataset_idx][sample_idx]
|
||||
152 examples/tutorial/sequence_parallel/data/datasets/builder.py Normal file
@@ -0,0 +1,152 @@
|
||||
from .blendable_dataset import BlendableDataset
|
||||
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
|
||||
from .bert_dataset import BertDataset
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
DSET_TYPE_BERT = 'standard_bert'
|
||||
DSET_TYPE_ICT = 'ict'
|
||||
DSET_TYPE_T5 = 't5'
|
||||
|
||||
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
|
||||
|
||||
|
||||
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed, skip_warmup,
|
||||
binary_head,
|
||||
dataset_type='standard_bert'):
|
||||
|
||||
if dataset_type not in DSET_TYPES:
|
||||
raise ValueError("Invalid dataset_type: ", dataset_type)
|
||||
|
||||
# Indexed dataset.
|
||||
indexed_dataset = get_indexed_dataset_(data_prefix,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
|
||||
# Get start and end indices of train/valid/train into doc-idx
|
||||
# Note that doc-idx is designed to be num-docs + 1 so we can
|
||||
# easily iterate over it.
|
||||
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
|
||||
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
# Print stats about the splits.
|
||||
logger.info('\n > dataset split:', ranks=[0])
|
||||
|
||||
def print_split_stats(name, index):
|
||||
start_index = indexed_dataset.doc_idx[splits[index]]
|
||||
end_index = indexed_dataset.doc_idx[splits[index + 1]]
|
||||
logger.info('\n {}:'.format(name) +
|
||||
'\n document indices in [{}, {}) total of {} documents'.format(
|
||||
splits[index], splits[index + 1],
|
||||
splits[index + 1] - splits[index]) +
|
||||
'\n sentence indices in [{}, {}) total of {} sentences'.format(
|
||||
start_index, end_index,
|
||||
end_index - start_index),
|
||||
ranks=[0])
|
||||
print_split_stats('train', 0)
|
||||
print_split_stats('validation', 1)
|
||||
print_split_stats('test', 2)
|
||||
|
||||
def build_dataset(index, name):
|
||||
dataset = None
|
||||
if splits[index + 1] > splits[index]:
|
||||
# Get the pointer to the original doc-idx so we can set it later.
|
||||
doc_idx_ptr = indexed_dataset.get_doc_idx()
|
||||
# Slice the doc-idx
|
||||
start_index = splits[index]
|
||||
# Add +1 so we can index into the dataset to get the upper bound.
|
||||
end_index = splits[index + 1] + 1
|
||||
# New doc_idx view.
|
||||
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
|
||||
# Build the dataset accordingly.
|
||||
kwargs = dict(
|
||||
name=name,
|
||||
data_prefix=data_prefix,
|
||||
num_epochs=None,
|
||||
max_num_samples=train_valid_test_num_samples[index],
|
||||
max_seq_length=max_seq_length,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
if dataset_type != DSET_TYPE_BERT:
|
||||
raise NotImplementedError("Only BERT dataset is supported")
|
||||
else:
|
||||
dataset = BertDataset(
|
||||
indexed_dataset=indexed_dataset,
|
||||
masked_lm_prob=masked_lm_prob,
|
||||
short_seq_prob=short_seq_prob,
|
||||
binary_head=binary_head,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Set the original pointer so dataset remains the main dataset.
|
||||
indexed_dataset.set_doc_idx(doc_idx_ptr)
|
||||
# Checks.
|
||||
assert indexed_dataset.doc_idx[0] == 0
|
||||
assert indexed_dataset.doc_idx.shape[0] == \
|
||||
(total_num_of_documents + 1)
|
||||
return dataset
|
||||
|
||||
train_dataset = build_dataset(0, 'train')
|
||||
valid_dataset = build_dataset(1, 'valid')
|
||||
test_dataset = build_dataset(2, 'test')
|
||||
|
||||
return (train_dataset, valid_dataset, test_dataset)
|
||||
|
||||
|
||||
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed, skip_warmup,
|
||||
binary_head,
|
||||
dataset_type='standard_bert'):
|
||||
|
||||
if len(data_prefix) == 1:
|
||||
return _build_train_valid_test_datasets(data_prefix[0],
|
||||
data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type=dataset_type)
|
||||
# Blending dataset.
|
||||
# Parse the values.
|
||||
output = get_datasets_weights_and_num_samples(data_prefix,
|
||||
train_valid_test_num_samples)
|
||||
prefixes, weights, datasets_train_valid_test_num_samples = output
|
||||
|
||||
# Build individual datasets.
|
||||
train_datasets = []
|
||||
valid_datasets = []
|
||||
test_datasets = []
|
||||
for i in range(len(prefixes)):
|
||||
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
|
||||
prefixes[i], data_impl, splits_string,
|
||||
datasets_train_valid_test_num_samples[i],
|
||||
max_seq_length, masked_lm_prob, short_seq_prob,
|
||||
seed, skip_warmup, binary_head, dataset_type=dataset_type)
|
||||
if train_ds:
|
||||
train_datasets.append(train_ds)
|
||||
if valid_ds:
|
||||
valid_datasets.append(valid_ds)
|
||||
if test_ds:
|
||||
test_datasets.append(test_ds)
|
||||
|
||||
# Blend.
|
||||
blending_train_dataset = None
|
||||
if train_datasets:
|
||||
blending_train_dataset = BlendableDataset(train_datasets, weights)
|
||||
blending_valid_dataset = None
|
||||
if valid_datasets:
|
||||
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
|
||||
blending_test_dataset = None
|
||||
if test_datasets:
|
||||
blending_test_dataset = BlendableDataset(test_datasets, weights)
|
||||
|
||||
return (blending_train_dataset, blending_valid_dataset,
|
||||
blending_test_dataset)
|
||||
153 examples/tutorial/sequence_parallel/data/datasets/data_samplers.py Normal file
@@ -0,0 +1,153 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dataloaders."""
|
||||
|
||||
import torch
|
||||
import random
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
|
||||
|
||||
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0):
|
||||
"""Build dataloader given an input dataset."""
|
||||
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
# Megatron sampler
|
||||
if dataloader_type == 'single':
|
||||
batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=micro_batch_size,
|
||||
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
|
||||
elif dataloader_type == 'cyclic':
|
||||
batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=micro_batch_size,
|
||||
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
|
||||
else:
|
||||
raise Exception('{} dataloader type is not supported.'.format(dataloader_type))
|
||||
|
||||
# Torch dataloader.
|
||||
return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True)
|
||||
|
||||
|
||||
class MegatronPretrainingSampler:
|
||||
|
||||
def __init__(self,
|
||||
total_samples,
|
||||
consumed_samples,
|
||||
micro_batch_size,
|
||||
data_parallel_rank,
|
||||
data_parallel_size,
|
||||
drop_last=True):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
self.consumed_samples = consumed_samples
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.micro_batch_times_data_parallel_size = \
|
||||
self.micro_batch_size * data_parallel_size
|
||||
self.drop_last = drop_last
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, \
|
||||
'no sample to consume: {}'.format(self.total_samples)
|
||||
assert self.consumed_samples < self.total_samples, \
|
||||
'no samples left to consume: {}, {}'.format(self.consumed_samples,
|
||||
self.total_samples)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert self.data_parallel_rank < data_parallel_size, \
|
||||
'data_parallel_rank should be smaller than data size: {}, ' \
|
||||
'{}'.format(self.data_parallel_rank, data_parallel_size)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def get_start_end_idx(self):
|
||||
start_idx = self.data_parallel_rank * self.micro_batch_size
|
||||
end_idx = start_idx + self.micro_batch_size
|
||||
return start_idx, end_idx
|
||||
|
||||
def __iter__(self):
|
||||
batch = []
|
||||
# Last batch will be dropped if drop_last is not set to False
|
||||
for idx in range(self.consumed_samples, self.total_samples):
|
||||
batch.append(idx)
|
||||
if len(batch) == self.micro_batch_times_data_parallel_size:
|
||||
start_idx, end_idx = self.get_start_end_idx()
|
||||
yield batch[start_idx:end_idx]
|
||||
batch = []
|
||||
|
||||
# Check the last partial batch and see if drop_last is set
|
||||
if len(batch) > 0 and not self.drop_last:
|
||||
start_idx, end_idx = self.get_start_end_idx()
|
||||
yield batch[start_idx:end_idx]
|
||||
|
||||
|
||||
class MegatronPretrainingRandomSampler:
|
||||
|
||||
def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
self.consumed_samples = consumed_samples
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.data_parallel_size = data_parallel_size
|
||||
self.micro_batch_times_data_parallel_size = \
|
||||
self.micro_batch_size * data_parallel_size
|
||||
self.last_batch_size = \
|
||||
self.total_samples % self.micro_batch_times_data_parallel_size
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, \
|
||||
'no sample to consume: {}'.format(self.total_samples)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert self.data_parallel_rank < data_parallel_size, \
|
||||
'data_parallel_rank should be smaller than data size: {}, ' \
|
||||
'{}'.format(self.data_parallel_rank, data_parallel_size)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def __iter__(self):
|
||||
active_total_samples = self.total_samples - self.last_batch_size
|
||||
self.epoch = self.consumed_samples // active_total_samples
|
||||
current_epoch_samples = self.consumed_samples % active_total_samples
|
||||
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
|
||||
|
||||
# data sharding and random sampling
|
||||
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
|
||||
* self.micro_batch_size
|
||||
bucket_offset = current_epoch_samples // self.data_parallel_size
|
||||
start_idx = self.data_parallel_rank * bucket_size
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
random_idx = torch.randperm(bucket_size, generator=g).tolist()
|
||||
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
|
||||
|
||||
batch = []
|
||||
# Last batch if not complete will be dropped.
|
||||
for idx in idx_range:
|
||||
batch.append(idx)
|
||||
if len(batch) == self.micro_batch_size:
|
||||
self.consumed_samples += self.micro_batch_times_data_parallel_size
|
||||
yield batch
|
||||
batch = []
|
||||
592 examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py Normal file
@@ -0,0 +1,592 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Most of the code here has been copied from:
|
||||
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
|
||||
# with some modifications.
|
||||
|
||||
import math
|
||||
import time
|
||||
import collections
|
||||
from colossalai.logging import get_dist_logger
|
||||
import numpy as np
|
||||
from .blendable_dataset import BlendableDataset
|
||||
from .indexed_dataset import make_dataset as make_indexed_dataset
|
||||
|
||||
DSET_TYPE_STD = 'standard_bert'
|
||||
DSET_TYPE_ICT = 'ict'
|
||||
|
||||
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
|
||||
|
||||
|
||||
def get_datasets_weights_and_num_samples(data_prefix,
|
||||
train_valid_test_num_samples):
|
||||
|
||||
# The data prefix should be in the format of:
|
||||
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
|
||||
assert len(data_prefix) % 2 == 0
|
||||
num_datasets = len(data_prefix) // 2
|
||||
weights = [0]*num_datasets
|
||||
prefixes = [0]*num_datasets
|
||||
for i in range(num_datasets):
|
||||
weights[i] = float(data_prefix[2*i])
|
||||
prefixes[i] = (data_prefix[2*i+1]).strip()
|
||||
# Normalize weights
|
||||
weight_sum = 0.0
|
||||
for weight in weights:
|
||||
weight_sum += weight
|
||||
assert weight_sum > 0.0
|
||||
weights = [weight / weight_sum for weight in weights]
|
||||
|
||||
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
|
||||
# not uniformly distribute the number of samples, we still have
|
||||
# samples left to feed to the network.
|
||||
datasets_train_valid_test_num_samples = []
|
||||
for weight in weights:
|
||||
datasets_train_valid_test_num_samples.append(
|
||||
[int(math.ceil(val * weight * 1.005))
|
||||
for val in train_valid_test_num_samples])
|
||||
|
||||
return prefixes, weights, datasets_train_valid_test_num_samples
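
# --- worked example (added note, not part of the original file) --------------
# With the interleaved "weight, prefix" format described above, e.g.
#     data_prefix = ['0.3', 'wiki', '0.7', 'books']
#     train_valid_test_num_samples = [1000, 100, 10]
# the function returns prefixes = ['wiki', 'books'], weights = [0.3, 0.7], and
# per-dataset sample counts scaled by the 1.005 safety factor, e.g. for 'wiki':
#     [ceil(1000 * 0.3 * 1.005), ceil(100 * 0.3 * 1.005), ceil(10 * 0.3 * 1.005)]
#       = [302, 31, 4]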
|
||||
|
||||
|
||||
def compile_helper():
|
||||
"""Compile helper function ar runtime. Make sure this
|
||||
is invoked on a single process."""
|
||||
import os
|
||||
import subprocess
|
||||
path = os.path.abspath(os.path.dirname(__file__))
|
||||
ret = subprocess.run(['make', '-C', path])
|
||||
if ret.returncode != 0:
|
||||
print("Making C++ dataset helpers module failed, exiting.")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_a_and_b_segments(sample, np_rng):
|
||||
"""Divide sample into a and b segments."""
|
||||
|
||||
# Number of sentences in the sample.
|
||||
n_sentences = len(sample)
|
||||
# Make sure we always have two sentences.
|
||||
assert n_sentences > 1, 'make sure each sample has at least two sentences.'
|
||||
|
||||
# First part:
|
||||
# `a_end` is how many sentences go into the `A`.
|
||||
a_end = 1
|
||||
if n_sentences >= 3:
|
||||
# Note that randint in numpy is exclusive.
|
||||
a_end = np_rng.randint(1, n_sentences)
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(sample[j])
|
||||
|
||||
# Second part:
|
||||
tokens_b = []
|
||||
for j in range(a_end, n_sentences):
|
||||
tokens_b.extend(sample[j])
|
||||
|
||||
# Random next:
|
||||
is_next_random = False
|
||||
if np_rng.random() < 0.5:
|
||||
is_next_random = True
|
||||
tokens_a, tokens_b = tokens_b, tokens_a
|
||||
|
||||
return tokens_a, tokens_b, is_next_random
|
||||
|
||||
|
||||
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
#print(len_a, len_b, max_num_tokens)
|
||||
assert len_a > 0
|
||||
if len_a + len_b <= max_num_tokens:
|
||||
return False
|
||||
while len_a + len_b > max_num_tokens:
|
||||
if len_a > len_b:
|
||||
len_a -= 1
|
||||
tokens = tokens_a
|
||||
else:
|
||||
len_b -= 1
|
||||
tokens = tokens_b
|
||||
if np_rng.random() < 0.5:
|
||||
del tokens[0]
|
||||
else:
|
||||
tokens.pop()
|
||||
return True
|
||||
|
||||
|
||||
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
|
||||
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
|
||||
|
||||
tokens = []
|
||||
tokentypes = []
|
||||
# [CLS].
|
||||
tokens.append(cls_id)
|
||||
tokentypes.append(0)
|
||||
# Segment A.
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
tokentypes.append(0)
|
||||
# [SEP].
|
||||
tokens.append(sep_id)
|
||||
tokentypes.append(0)
|
||||
# Segment B.
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
tokentypes.append(1)
|
||||
if tokens_b:
|
||||
# [SEP].
|
||||
tokens.append(sep_id)
|
||||
tokentypes.append(1)
|
||||
|
||||
return tokens, tokentypes
|
||||
|
||||
|
||||
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
|
||||
["index", "label"])
|
||||
|
||||
|
||||
def is_start_piece(piece):
|
||||
"""Check if the current word piece is the starting piece (BERT)."""
|
||||
# When a word has been split into
|
||||
# WordPieces, the first token does not have any marker and any subsequent
|
||||
# tokens are prefixed with ##. So whenever we see the ## token, we
|
||||
# append it to the previous set of word indexes.
|
||||
return not piece.startswith("##")
|
||||
|
||||
|
||||
def create_masked_lm_predictions(tokens,
|
||||
vocab_id_list, vocab_id_to_token_dict,
|
||||
masked_lm_prob,
|
||||
cls_id, sep_id, mask_id,
|
||||
max_predictions_per_seq,
|
||||
np_rng,
|
||||
max_ngrams=3,
|
||||
do_whole_word_mask=True,
|
||||
favor_longer_ngram=False,
|
||||
do_permutation=False):
|
||||
"""Creates the predictions for the masked LM objective.
|
||||
Note: Tokens here are vocab ids and not text tokens."""
|
||||
|
||||
cand_indexes = []
|
||||
# Note(mingdachen): We create a list for recording if the piece is
|
||||
# the starting piece of current token, where 1 means true, so that
|
||||
# on-the-fly whole word masking is possible.
|
||||
token_boundary = [0] * len(tokens)
|
||||
|
||||
for (i, token) in enumerate(tokens):
|
||||
if token == cls_id or token == sep_id:
|
||||
token_boundary[i] = 1
|
||||
continue
|
||||
# Whole Word Masking means that if we mask all of the wordpieces
|
||||
# corresponding to an original word.
|
||||
#
|
||||
# Note that Whole Word Masking does *not* change the training code
|
||||
# at all -- we still predict each WordPiece independently, softmaxed
|
||||
# over the entire vocabulary.
|
||||
if (do_whole_word_mask and len(cand_indexes) >= 1 and
|
||||
not is_start_piece(vocab_id_to_token_dict[token])):
|
||||
cand_indexes[-1].append(i)
|
||||
else:
|
||||
cand_indexes.append([i])
|
||||
if is_start_piece(vocab_id_to_token_dict[token]):
|
||||
token_boundary[i] = 1
|
||||
|
||||
output_tokens = list(tokens)
|
||||
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
|
||||
if masked_lm_prob == 0:
|
||||
return (output_tokens, masked_lm_positions,
|
||||
masked_lm_labels, token_boundary)
|
||||
|
||||
num_to_predict = min(max_predictions_per_seq,
|
||||
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||
|
||||
# Note(mingdachen):
|
||||
# By default, we set the probabilities to favor shorter ngram sequences.
|
||||
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
|
||||
pvals = 1. / np.arange(1, max_ngrams + 1)
|
||||
pvals /= pvals.sum(keepdims=True)
|
||||
|
||||
if favor_longer_ngram:
|
||||
pvals = pvals[::-1]
|
||||
|
||||
ngram_indexes = []
|
||||
for idx in range(len(cand_indexes)):
|
||||
ngram_index = []
|
||||
for n in ngrams:
|
||||
ngram_index.append(cand_indexes[idx:idx + n])
|
||||
ngram_indexes.append(ngram_index)
|
||||
|
||||
np_rng.shuffle(ngram_indexes)
|
||||
|
||||
masked_lms = []
|
||||
covered_indexes = set()
|
||||
for cand_index_set in ngram_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
if not cand_index_set:
|
||||
continue
|
||||
# Note(mingdachen):
|
||||
# Skip current piece if they are covered in lm masking or previous ngrams.
|
||||
for index_set in cand_index_set[0]:
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
continue
|
||||
|
||||
n = np_rng.choice(ngrams[:len(cand_index_set)],
|
||||
p=pvals[:len(cand_index_set)] /
|
||||
pvals[:len(cand_index_set)].sum(keepdims=True))
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# Note(mingdachen):
|
||||
# Repeatedly looking for a candidate that does not exceed the
|
||||
# maximum number of predictions by trying shorter ngrams.
|
||||
while len(masked_lms) + len(index_set) > num_to_predict:
|
||||
if n == 0:
|
||||
break
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(masked_lms) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
covered_indexes.add(index)
|
||||
|
||||
masked_token = None
|
||||
# 80% of the time, replace with [MASK]
|
||||
if np_rng.random() < 0.8:
|
||||
masked_token = mask_id
|
||||
else:
|
||||
# 10% of the time, keep original
|
||||
if np_rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
# 10% of the time, replace with random word
|
||||
else:
|
||||
masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
|
||||
|
||||
output_tokens[index] = masked_token
|
||||
|
||||
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
|
||||
assert len(masked_lms) <= num_to_predict
|
||||
|
||||
np_rng.shuffle(ngram_indexes)
|
||||
|
||||
select_indexes = set()
|
||||
if do_permutation:
|
||||
for cand_index_set in ngram_indexes:
|
||||
if len(select_indexes) >= num_to_predict:
|
||||
break
|
||||
if not cand_index_set:
|
||||
continue
|
||||
# Note(mingdachen):
|
||||
# Skip current piece if they are covered in lm masking or previous ngrams.
|
||||
for index_set in cand_index_set[0]:
|
||||
for index in index_set:
|
||||
if index in covered_indexes or index in select_indexes:
|
||||
continue
|
||||
|
||||
n = np.random.choice(ngrams[:len(cand_index_set)],
|
||||
p=pvals[:len(cand_index_set)] /
|
||||
pvals[:len(cand_index_set)].sum(keepdims=True))
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
|
||||
while len(select_indexes) + len(index_set) > num_to_predict:
|
||||
if n == 0:
|
||||
break
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(select_indexes) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes or index in select_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
select_indexes.add(index)
|
||||
assert len(select_indexes) <= num_to_predict
|
||||
|
||||
select_indexes = sorted(select_indexes)
|
||||
permute_indexes = list(select_indexes)
|
||||
np_rng.shuffle(permute_indexes)
|
||||
orig_token = list(output_tokens)
|
||||
|
||||
for src_i, tgt_i in zip(select_indexes, permute_indexes):
|
||||
output_tokens[src_i] = orig_token[tgt_i]
|
||||
masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
|
||||
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p.index)
|
||||
masked_lm_labels.append(p.label)
|
||||
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
|
||||
|
||||
|
||||
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
||||
masked_labels, pad_id, max_seq_length):
|
||||
"""Pad sequences and convert them to numpy."""
|
||||
|
||||
# Some checks.
|
||||
num_tokens = len(tokens)
|
||||
padding_length = max_seq_length - num_tokens
|
||||
assert padding_length >= 0
|
||||
assert len(tokentypes) == num_tokens
|
||||
assert len(masked_positions) == len(masked_labels)
|
||||
|
||||
# Tokens and token types.
|
||||
filler = [pad_id] * padding_length
|
||||
tokens_np = np.array(tokens + filler, dtype=np.int64)
|
||||
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
|
||||
|
||||
# Padding mask.
|
||||
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
|
||||
dtype=np.int64)
|
||||
|
||||
# Labels and loss mask.
|
||||
labels = [-1] * max_seq_length
|
||||
loss_mask = [0] * max_seq_length
|
||||
for i in range(len(masked_positions)):
|
||||
assert masked_positions[i] < num_tokens
|
||||
labels[masked_positions[i]] = masked_labels[i]
|
||||
loss_mask[masked_positions[i]] = 1
|
||||
labels_np = np.array(labels, dtype=np.int64)
|
||||
loss_mask_np = np.array(loss_mask, dtype=np.int64)
|
||||
|
||||
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
|
||||
|
||||
|
||||
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed, skip_warmup,
|
||||
binary_head,
|
||||
dataset_type='standard_bert'):
|
||||
|
||||
if len(data_prefix) == 1:
|
||||
return _build_train_valid_test_datasets(data_prefix[0],
|
||||
data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type=dataset_type)
|
||||
# Blending dataset.
|
||||
# Parse the values.
|
||||
output = get_datasets_weights_and_num_samples(data_prefix,
|
||||
train_valid_test_num_samples)
|
||||
prefixes, weights, datasets_train_valid_test_num_samples = output
|
||||
|
||||
# Build individual datasets.
|
||||
train_datasets = []
|
||||
valid_datasets = []
|
||||
test_datasets = []
|
||||
for i in range(len(prefixes)):
|
||||
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
|
||||
prefixes[i], data_impl, splits_string,
|
||||
datasets_train_valid_test_num_samples[i],
|
||||
max_seq_length, masked_lm_prob, short_seq_prob,
|
||||
seed, skip_warmup, binary_head, dataset_type=dataset_type)
|
||||
if train_ds:
|
||||
train_datasets.append(train_ds)
|
||||
if valid_ds:
|
||||
valid_datasets.append(valid_ds)
|
||||
if test_ds:
|
||||
test_datasets.append(test_ds)
|
||||
|
||||
# Blend.
|
||||
blending_train_dataset = None
|
||||
if train_datasets:
|
||||
blending_train_dataset = BlendableDataset(train_datasets, weights)
|
||||
blending_valid_dataset = None
|
||||
if valid_datasets:
|
||||
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
|
||||
blending_test_dataset = None
|
||||
if test_datasets:
|
||||
blending_test_dataset = BlendableDataset(test_datasets, weights)
|
||||
|
||||
return (blending_train_dataset, blending_valid_dataset,
|
||||
blending_test_dataset)
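# When several data prefixes are given, each split is wrapped in a
# BlendableDataset that draws samples from the per-prefix datasets in
# proportion to the parsed weights (see build_blending_indices in
# helpers.cpp); e.g. weights [0.7, 0.3] yield roughly a 7:3 sample mix.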
|
||||
|
||||
|
||||
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed, skip_warmup,
|
||||
binary_head,
|
||||
dataset_type='standard_bert'):
|
||||
logger = get_dist_logger()
|
||||
|
||||
if dataset_type not in DSET_TYPES:
|
||||
raise ValueError("Invalid dataset_type: {}".format(dataset_type))
|
||||
|
||||
# Indexed dataset.
|
||||
indexed_dataset = get_indexed_dataset_(data_prefix,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
|
||||
if dataset_type == DSET_TYPE_ICT:
|
||||
args = get_args()
|
||||
title_dataset = get_indexed_dataset_(args.titles_data_path,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
|
||||
# Get start and end indices of train/valid/test into doc-idx
|
||||
# Note that doc-idx is designed to be num-docs + 1 so we can
|
||||
# easily iterate over it.
|
||||
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
|
||||
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
|
||||
|
||||
# Print stats about the splits.
|
||||
logger.info('\n > dataset split:')
|
||||
|
||||
def print_split_stats(name, index):
|
||||
start_index = indexed_dataset.doc_idx[splits[index]]
|
||||
end_index = indexed_dataset.doc_idx[splits[index + 1]]
|
||||
logger.info('\n {}:'.format(name) +
|
||||
'\n document indices in [{}, {}) total of {} documents'.format(
|
||||
splits[index],
|
||||
splits[index + 1],
|
||||
splits[index + 1] - splits[index]) +
|
||||
'\n sentence indices in [{}, {}) total of {} sentences'.format(
|
||||
start_index,
|
||||
end_index,
|
||||
end_index - start_index),
|
||||
ranks=[0])
|
||||
print_split_stats('train', 0)
|
||||
print_split_stats('validation', 1)
|
||||
print_split_stats('test', 2)
|
||||
|
||||
def build_dataset(index, name):
|
||||
from .bert_dataset import BertDataset
|
||||
dataset = None
|
||||
if splits[index + 1] > splits[index]:
|
||||
# Get the pointer to the original doc-idx so we can set it later.
|
||||
doc_idx_ptr = indexed_dataset.get_doc_idx()
|
||||
# Slice the doc-idx
|
||||
start_index = splits[index]
|
||||
# Add +1 so we can index into the dataset to get the upper bound.
|
||||
end_index = splits[index + 1] + 1
|
||||
# New doc_idx view.
|
||||
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
|
||||
# Build the dataset accordingly.
|
||||
kwargs = dict(
|
||||
name=name,
|
||||
data_prefix=data_prefix,
|
||||
num_epochs=None,
|
||||
max_num_samples=train_valid_test_num_samples[index],
|
||||
max_seq_length=max_seq_length,
|
||||
seed=seed,
|
||||
binary_head=binary_head
|
||||
)
|
||||
|
||||
if dataset_type == DSET_TYPE_ICT:
|
||||
args = get_args()
|
||||
dataset = ICTDataset(
|
||||
block_dataset=indexed_dataset,
|
||||
title_dataset=title_dataset,
|
||||
query_in_block_prob=args.query_in_block_prob,
|
||||
use_one_sent_docs=args.use_one_sent_docs,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
dataset = BertDataset(
|
||||
indexed_dataset=indexed_dataset,
|
||||
masked_lm_prob=masked_lm_prob,
|
||||
short_seq_prob=short_seq_prob,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Set the original pointer so dataset remains the main dataset.
|
||||
indexed_dataset.set_doc_idx(doc_idx_ptr)
|
||||
# Checks.
|
||||
assert indexed_dataset.doc_idx[0] == 0
|
||||
assert indexed_dataset.doc_idx.shape[0] == \
|
||||
(total_num_of_documents + 1)
|
||||
return dataset
|
||||
|
||||
train_dataset = build_dataset(0, 'train')
|
||||
valid_dataset = build_dataset(1, 'valid')
|
||||
test_dataset = build_dataset(2, 'test')
|
||||
|
||||
return (train_dataset, valid_dataset, test_dataset)
|
||||
|
||||
|
||||
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
|
||||
logger = get_dist_logger()
|
||||
start_time = time.time()
|
||||
indexed_dataset = make_indexed_dataset(data_prefix,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
|
||||
logger.info('\n > building dataset index ...', ranks=[0])
|
||||
logger.info('\n > finished creating indexed dataset in {:4f} '
|
||||
'seconds'.format(time.time() - start_time), ranks=[0])
|
||||
logger.info('\n > indexed dataset stats:' +
|
||||
'\n number of documents: {}'.format(
|
||||
indexed_dataset.doc_idx.shape[0] - 1) +
|
||||
'\n number of sentences: {}'.format(
|
||||
indexed_dataset.sizes.shape[0]),
|
||||
ranks=[0]
|
||||
)
|
||||
|
||||
return indexed_dataset
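# The indexed dataset stores one entry per sentence: sizes[i] is the token
# count of sentence i, and doc_idx (length num-docs + 1) maps document d to
# the index of its first sentence, so the assert above checks that the last
# document boundary equals the total number of sentences.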
|
||||
|
||||
|
||||
def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from a comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
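# Worked example with hypothetical arguments: splits_string='949,50,1' and
# size=1000 normalise to [0.949, 0.05, 0.001] and produce
# splits_index=[0, 949, 999, 1000], i.e. documents [0, 949) for training,
# [949, 999) for validation and [999, 1000) for test.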
|
||||
717
examples/tutorial/sequence_parallel/data/datasets/helpers.cpp
Normal file
@@ -0,0 +1,717 @@
|
||||
/*
|
||||
coding=utf-8
|
||||
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
/* Helper methods for fast index mapping builds */
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <math.h>
|
||||
#include <stdexcept>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <random>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace std;
|
||||
|
||||
const int32_t LONG_SENTENCE_LEN = 512;
|
||||
|
||||
|
||||
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
|
||||
py::array_t<int64_t>& dataset_sample_index,
|
||||
const py::array_t<double>& weights,
|
||||
const int32_t num_datasets,
|
||||
const int64_t size, const bool verbose) {
|
||||
/* Given multiple datasets and a weighting array, build samples
   such that it follows those weights.*/
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "> building indices for blendable datasets ..." << std::endl;
|
||||
}
|
||||
|
||||
// Get the pointer access without the checks.
|
||||
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
|
||||
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
|
||||
auto weights_ptr = weights.unchecked<1>();
|
||||
|
||||
// Initialize buffer for number of samples used for each dataset.
|
||||
int64_t current_samples[num_datasets];
|
||||
for(int64_t i = 0; i < num_datasets; ++i) {
|
||||
current_samples[i] = 0;
|
||||
}
|
||||
|
||||
// For each sample:
|
||||
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
|
||||
|
||||
// Determine where the max error in sampling is happening.
|
||||
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
|
||||
int64_t max_error_index = 0;
|
||||
double max_error = weights_ptr[0] * sample_idx_double -
|
||||
static_cast<double>(current_samples[0]);
|
||||
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
|
||||
double error = weights_ptr[dataset_idx] * sample_idx_double -
|
||||
static_cast<double>(current_samples[dataset_idx]);
|
||||
if (error > max_error) {
|
||||
max_error = error;
|
||||
max_error_index = dataset_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Populate the indices.
|
||||
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
|
||||
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
|
||||
|
||||
// Update the total samples.
|
||||
current_samples[max_error_index] += 1;
|
||||
|
||||
}
|
||||
|
||||
// print info
|
||||
if (verbose) {
|
||||
std::cout << " > sample ratios:" << std::endl;
|
||||
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
|
||||
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
|
||||
static_cast<double>(size);
|
||||
std::cout << " dataset " << dataset_idx << ", input: " <<
|
||||
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
}
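// Illustration: with two datasets and weights {0.75, 0.25}, the greedy
// error-minimisation above emits dataset indices 0, 1, 0, 0, 0, 1, ... so
// that after N samples each dataset has received approximately weight * N
// samples; dataset_sample_index records the per-dataset running count.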
|
||||
|
||||
|
||||
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
|
||||
const py::array_t<int32_t>& doc_idx_,
|
||||
const int32_t seq_length,
|
||||
const int32_t num_epochs,
|
||||
const int64_t tokens_per_epoch) {
|
||||
/* Sample index (sample_idx) is used for GPT-2-style datasets, for which
the documents are flattened and the samples are built based on this
1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2]
|
||||
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
|
||||
starting offset in that document.*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(seq_length > 1);
|
||||
assert(num_epochs > 0);
|
||||
assert(tokens_per_epoch > 1);
|
||||
|
||||
// Remove bound checks.
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
auto doc_idx = doc_idx_.unchecked<1>();
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
|
||||
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
|
||||
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " <<
|
||||
doc_idx_.shape(0) / num_epochs << endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " sequence length: " << seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " total number of samples: " << num_samples <<
|
||||
endl << std::flush;
|
||||
|
||||
// Index into sample_idx.
|
||||
int64_t sample_index = 0;
|
||||
// Index into doc_idx.
|
||||
int64_t doc_idx_index = 0;
|
||||
// Beginning offset for each document.
|
||||
int32_t doc_offset = 0;
|
||||
// Start with first document and no offset.
|
||||
sample_idx[2 * sample_index] = doc_idx_index;
|
||||
sample_idx[2 * sample_index + 1] = doc_offset;
|
||||
++sample_index;
|
||||
|
||||
while (sample_index <= num_samples) {
|
||||
// Start with a fresh sequence.
|
||||
int32_t remaining_seq_length = seq_length + 1;
|
||||
while (remaining_seq_length != 0) {
|
||||
// Get the document length.
|
||||
auto doc_id = doc_idx[doc_idx_index];
|
||||
auto doc_length = sizes[doc_id] - doc_offset;
|
||||
// And add it to the current sequence.
|
||||
remaining_seq_length -= doc_length;
|
||||
// If we have more than a full sequence, adjust offset and set
|
||||
// remaining length to zero so we return from the while loop.
|
||||
// Note that -1 here is for the same reason we have -1 in
|
||||
// `_num_epochs` calculations.
|
||||
if (remaining_seq_length <= 0) {
|
||||
doc_offset += (remaining_seq_length + doc_length - 1);
|
||||
remaining_seq_length = 0;
|
||||
} else {
|
||||
// Otherwise, start from the beginning of the next document.
|
||||
++doc_idx_index;
|
||||
doc_offset = 0;
|
||||
}
|
||||
}
|
||||
// Record the sequence.
|
||||
sample_idx[2 * sample_index] = doc_idx_index;
|
||||
sample_idx[2 * sample_index + 1] = doc_offset;
|
||||
++sample_index;
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(sample_idx, [](void *mem_) {
|
||||
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(int32_t);
|
||||
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
|
||||
{2*byte_size, byte_size}, // C-style contiguous strides
|
||||
sample_idx, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
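// Each row i of the returned [num_samples + 1, 2] array marks where sample i
// starts: column 0 is an index into doc_idx and column 1 an offset inside
// that document; sample i then spans the tokens between row i and row i + 1,
// which is how a fixed seq_length + 1 tokens are carved out of the
// flattened document stream.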
|
||||
|
||||
|
||||
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
|
||||
const int32_t max_length,
|
||||
std::mt19937& rand32_gen) {
|
||||
/* Training sample length. */
|
||||
if (short_seq_ratio == 0) {
|
||||
return max_length;
|
||||
}
|
||||
const auto random_number = rand32_gen();
|
||||
if ((random_number % short_seq_ratio) == 0) {
|
||||
return 2 + random_number % (max_length - 1);
|
||||
}
|
||||
return max_length;
|
||||
}
|
||||
|
||||
|
||||
template<typename DocIdx>
|
||||
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int32_t>& sizes_,
|
||||
const int32_t num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int32_t max_seq_length,
|
||||
const double short_seq_prob,
|
||||
const int32_t seed,
|
||||
const bool verbose,
|
||||
const int32_t min_num_sent) {
|
||||
/* Build a mapping of (start-index, end-index, sequence-length) where
|
||||
start and end index are the indices of the sentences in the sample
|
||||
and sequence-length is the target sequence length.
|
||||
*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(num_epochs > 0);
|
||||
assert(max_seq_length > 1);
|
||||
assert(short_seq_prob >= 0.0);
|
||||
assert(short_seq_prob <= 1.0);
|
||||
assert(seed > 0);
|
||||
|
||||
// Remove bound checks.
|
||||
auto docs = docs_.unchecked<1>();
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
|
||||
// For efficiency, convert probability to ratio. Note: rand() generates int.
|
||||
int32_t short_seq_ratio = 0;
|
||||
if (short_seq_prob > 0) {
|
||||
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
const auto sent_start_index = docs[0];
|
||||
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
||||
const auto num_sentences = sent_end_index - sent_start_index;
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " << docs_.shape(0) - 1 <<
|
||||
endl << std::flush;
|
||||
cout << " sentences range: [" << sent_start_index <<
|
||||
", " << sent_end_index << ")" << endl << std::flush;
|
||||
cout << " total number of sentences: " << num_sentences <<
|
||||
endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " maximum number of samples: " << max_num_samples <<
|
||||
endl << std::flush;
|
||||
cout << " maximum sequence length: " << max_seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " short sequence probability: " << short_seq_prob <<
|
||||
endl << std::flush;
|
||||
cout << "     short sequence ratio (1/prob):        " << short_seq_ratio <<
|
||||
endl << std::flush;
|
||||
cout << " seed: " << seed << endl <<
|
||||
std::flush;
|
||||
}
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = -1;
|
||||
DocIdx* maps = NULL;
|
||||
|
||||
// Perform two iterations, in the first iteration get the size
|
||||
// and allocate memory and in the second iteration populate the map.
|
||||
bool second = false;
|
||||
for (int32_t iteration=0; iteration<2; ++iteration) {
|
||||
|
||||
// Set the seed so both iterations produce the same results.
|
||||
std::mt19937 rand32_gen(seed);
|
||||
|
||||
// Set the flag on second iteration.
|
||||
second = (iteration == 1);
|
||||
|
||||
// Counters:
|
||||
uint64_t empty_docs = 0;
|
||||
uint64_t one_sent_docs = 0;
|
||||
uint64_t long_sent_docs = 0;
|
||||
|
||||
// Current map index.
|
||||
uint64_t map_index = 0;
|
||||
|
||||
// For each epoch:
|
||||
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
|
||||
if (map_index >= max_num_samples) {
|
||||
if (verbose && (!second)) {
|
||||
cout << " reached " << max_num_samples << " samples after "
|
||||
<< epoch << " epochs ..." << endl << std::flush;
|
||||
}
|
||||
break;
|
||||
}
|
||||
// For each document:
|
||||
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
|
||||
|
||||
// Document sentences are in [sent_index_first, sent_index_last)
|
||||
const auto sent_index_first = docs[doc];
|
||||
const auto sent_index_last = docs[doc + 1];
|
||||
|
||||
// At the beginning of the document previous index is the
|
||||
// start index.
|
||||
auto prev_start_index = sent_index_first;
|
||||
|
||||
// Remaining sentences in the document.
|
||||
auto num_remain_sent = sent_index_last - sent_index_first;
|
||||
|
||||
// Some bookkeeping
|
||||
if ((epoch == 0) && (!second)) {
|
||||
if (num_remain_sent == 0) {
|
||||
++empty_docs;
|
||||
}
|
||||
if (num_remain_sent == 1) {
|
||||
++one_sent_docs;
|
||||
}
|
||||
}
|
||||
|
||||
// Detect documents with long sentences.
|
||||
bool contains_long_sentence = false;
|
||||
if (num_remain_sent > 1) {
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
if (sizes[sent_index] > LONG_SENTENCE_LEN){
|
||||
if ((epoch == 0) && (!second)) {
|
||||
++long_sent_docs;
|
||||
}
|
||||
contains_long_sentence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have more than two sentences.
|
||||
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
|
||||
// Set values.
|
||||
auto seq_len = int32_t{0};
|
||||
auto num_sent = int32_t{0};
|
||||
auto target_seq_len = get_target_sample_len(short_seq_ratio,
|
||||
max_seq_length,
|
||||
rand32_gen);
|
||||
|
||||
// Loop through sentences.
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
|
||||
// Add the size and number of sentences.
|
||||
seq_len += sizes[sent_index];
|
||||
++num_sent;
|
||||
--num_remain_sent;
|
||||
|
||||
// If we have reached the target length.
|
||||
// and if not only one sentence is left in the document.
|
||||
// and if we have at least two sentences.
|
||||
// and if we have reached end of the document.
|
||||
if (((seq_len >= target_seq_len) &&
|
||||
(num_remain_sent > 1) &&
|
||||
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
|
||||
|
||||
// Check for overflow.
|
||||
if ((3 * map_index + 2) >
|
||||
std::numeric_limits<int64_t>::max()) {
|
||||
cout << "number of samples exceeded maximum "
|
||||
<< "allowed by type int64: "
|
||||
<< std::numeric_limits<int64_t>::max()
|
||||
<< endl;
|
||||
throw std::overflow_error("Number of samples");
|
||||
}
|
||||
|
||||
// Populate the map.
|
||||
if (second) {
|
||||
const auto map_index_0 = 3 * map_index;
|
||||
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
||||
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
||||
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
|
||||
}
|
||||
|
||||
// Update indices / counters.
|
||||
++map_index;
|
||||
prev_start_index = sent_index + 1;
|
||||
target_seq_len = get_target_sample_len(short_seq_ratio,
|
||||
max_seq_length,
|
||||
rand32_gen);
|
||||
seq_len = 0;
|
||||
num_sent = 0;
|
||||
}
|
||||
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
cout << " number of empty documents: " << empty_docs <<
|
||||
endl << std::flush;
|
||||
cout << " number of documents with one sentence: " <<
|
||||
one_sent_docs << endl << std::flush;
|
||||
cout << " number of documents with long sentences: " <<
|
||||
long_sent_docs << endl << std::flush;
|
||||
cout << " will create mapping for " << map_index <<
|
||||
" samples" << endl << std::flush;
|
||||
}
|
||||
assert(maps == NULL);
|
||||
assert(num_samples < 0);
|
||||
maps = new DocIdx[3*map_index];
|
||||
num_samples = static_cast<int64_t>(map_index);
|
||||
}
|
||||
|
||||
} // for (int iteration=0; iteration < 2; ++iteration) {
|
||||
|
||||
// Shuffle.
|
||||
// We need a 64 bit random number generator as we might have more
|
||||
// than 2 billion samples.
|
||||
std::mt19937_64 rand64_gen(seed + 1);
|
||||
for (auto i=(num_samples - 1); i > 0; --i) {
|
||||
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
||||
const auto i0 = 3 * i;
|
||||
const auto j0 = 3 * j;
|
||||
// Swap values.
|
||||
swap(maps[i0], maps[j0]);
|
||||
swap(maps[i0 + 1], maps[j0 + 1]);
|
||||
swap(maps[i0 + 2], maps[j0 + 2]);
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(maps, [](void *mem_) {
|
||||
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(DocIdx);
|
||||
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
|
||||
{3*byte_size, byte_size}, // C-style contiguous strides
|
||||
maps, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
|
||||
py::array build_mapping(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int>& sizes_,
|
||||
const int num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int max_seq_length,
|
||||
const double short_seq_prob,
|
||||
const int seed,
|
||||
const bool verbose,
|
||||
const int32_t min_num_sent) {
|
||||
|
||||
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
|
||||
if (verbose) {
|
||||
cout << " using uint64 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
|
||||
max_num_samples, max_seq_length,
|
||||
short_seq_prob, seed, verbose,
|
||||
min_num_sent);
|
||||
} else {
|
||||
if (verbose) {
|
||||
cout << " using uint32 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
|
||||
max_num_samples, max_seq_length,
|
||||
short_seq_prob, seed, verbose,
|
||||
min_num_sent);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename DocIdx>
|
||||
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int32_t>& sizes_,
|
||||
const py::array_t<int32_t>& titles_sizes_,
|
||||
const int32_t num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int32_t max_seq_length,
|
||||
const int32_t seed,
|
||||
const bool verbose,
|
||||
const bool use_one_sent_blocks) {
|
||||
/* Build a mapping of (start-index, end-index, sequence-length) where
|
||||
start and end index are the indices of the sentences in the sample
|
||||
and sequence-length is the target sequence length.
|
||||
*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(num_epochs > 0);
|
||||
assert(max_seq_length > 1);
|
||||
assert(seed > 0);
|
||||
|
||||
// Remove bound checks.
|
||||
auto docs = docs_.unchecked<1>();
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
auto titles_sizes = titles_sizes_.unchecked<1>();
|
||||
|
||||
if (verbose) {
|
||||
const auto sent_start_index = docs[0];
|
||||
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
||||
const auto num_sentences = sent_end_index - sent_start_index;
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " << docs_.shape(0) - 1 <<
|
||||
endl << std::flush;
|
||||
cout << " sentences range: [" << sent_start_index <<
|
||||
", " << sent_end_index << ")" << endl << std::flush;
|
||||
cout << " total number of sentences: " << num_sentences <<
|
||||
endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " maximum number of samples: " << max_num_samples <<
|
||||
endl << std::flush;
|
||||
cout << " maximum sequence length: " << max_seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " seed: " << seed << endl <<
|
||||
std::flush;
|
||||
}
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = -1;
|
||||
DocIdx* maps = NULL;
|
||||
|
||||
// Acceptable number of sentences per block.
|
||||
int min_num_sent = 2;
|
||||
if (use_one_sent_blocks) {
|
||||
min_num_sent = 1;
|
||||
}
|
||||
|
||||
// Perform two iterations, in the first iteration get the size
|
||||
// and allocate memory and in the second iteration populate the map.
|
||||
bool second = false;
|
||||
for (int32_t iteration=0; iteration<2; ++iteration) {
|
||||
|
||||
// Set the flag on second iteration.
|
||||
second = (iteration == 1);
|
||||
|
||||
// Current map index.
|
||||
uint64_t map_index = 0;
|
||||
|
||||
uint64_t empty_docs = 0;
|
||||
uint64_t one_sent_docs = 0;
|
||||
uint64_t long_sent_docs = 0;
|
||||
// For each epoch:
|
||||
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
|
||||
// assign every block a unique id
|
||||
int32_t block_id = 0;
|
||||
|
||||
if (map_index >= max_num_samples) {
|
||||
if (verbose && (!second)) {
|
||||
cout << " reached " << max_num_samples << " samples after "
|
||||
<< epoch << " epochs ..." << endl << std::flush;
|
||||
}
|
||||
break;
|
||||
}
|
||||
// For each document:
|
||||
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
|
||||
|
||||
// Document sentences are in [sent_index_first, sent_index_last)
|
||||
const auto sent_index_first = docs[doc];
|
||||
const auto sent_index_last = docs[doc + 1];
|
||||
const auto target_seq_len = max_seq_length - titles_sizes[doc];
|
||||
|
||||
// At the beginning of the document previous index is the
|
||||
// start index.
|
||||
auto prev_start_index = sent_index_first;
|
||||
|
||||
// Remaining sentences in the document.
|
||||
auto num_remain_sent = sent_index_last - sent_index_first;
|
||||
|
||||
// Some bookkeeping
|
||||
if ((epoch == 0) && (!second)) {
|
||||
if (num_remain_sent == 0) {
|
||||
++empty_docs;
|
||||
}
|
||||
if (num_remain_sent == 1) {
|
||||
++one_sent_docs;
|
||||
}
|
||||
}
|
||||
// Detect documents with long sentences.
|
||||
bool contains_long_sentence = false;
|
||||
if (num_remain_sent >= min_num_sent) {
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
if (sizes[sent_index] > LONG_SENTENCE_LEN){
|
||||
if ((epoch == 0) && (!second)) {
|
||||
++long_sent_docs;
|
||||
}
|
||||
contains_long_sentence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we have enough sentences and no long sentences.
|
||||
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
|
||||
// Set values.
|
||||
auto seq_len = int32_t{0};
|
||||
auto num_sent = int32_t{0};
|
||||
|
||||
// Loop through sentences.
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
|
||||
// Add the size and number of sentences.
|
||||
seq_len += sizes[sent_index];
|
||||
++num_sent;
|
||||
--num_remain_sent;
|
||||
|
||||
// If we have reached the target length.
|
||||
// and there are an acceptable number of sentences left
|
||||
// and if we have at least the minimum number of sentences.
|
||||
// or if we have reached end of the document.
|
||||
if (((seq_len >= target_seq_len) &&
|
||||
(num_remain_sent >= min_num_sent) &&
|
||||
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
|
||||
|
||||
// Populate the map.
|
||||
if (second) {
|
||||
const auto map_index_0 = 4 * map_index;
|
||||
// Each sample has 4 items: the starting sentence index, ending sentence index,
|
||||
// the index of the document from which the block comes (used for fetching titles)
|
||||
// and the unique id of the block (used for creating block indexes)
|
||||
|
||||
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
||||
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
||||
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
|
||||
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
|
||||
}
|
||||
|
||||
// Update indices / counters.
|
||||
++map_index;
|
||||
++block_id;
|
||||
prev_start_index = sent_index + 1;
|
||||
seq_len = 0;
|
||||
num_sent = 0;
|
||||
}
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
cout << " number of empty documents: " << empty_docs <<
|
||||
endl << std::flush;
|
||||
cout << " number of documents with one sentence: " <<
|
||||
one_sent_docs << endl << std::flush;
|
||||
cout << " number of documents with long sentences: " <<
|
||||
long_sent_docs << endl << std::flush;
|
||||
cout << " will create mapping for " << map_index <<
|
||||
" samples" << endl << std::flush;
|
||||
}
|
||||
assert(maps == NULL);
|
||||
assert(num_samples < 0);
|
||||
maps = new DocIdx[4*map_index];
|
||||
num_samples = static_cast<int64_t>(map_index);
|
||||
}
|
||||
|
||||
} // for (int iteration=0; iteration < 2; ++iteration) {
|
||||
|
||||
// Shuffle.
|
||||
// We need a 64 bit random number generator as we might have more
|
||||
// than 2 billion samples.
|
||||
std::mt19937_64 rand64_gen(seed + 1);
|
||||
for (auto i=(num_samples - 1); i > 0; --i) {
|
||||
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
||||
const auto i0 = 4 * i;
|
||||
const auto j0 = 4 * j;
|
||||
// Swap values.
|
||||
swap(maps[i0], maps[j0]);
|
||||
swap(maps[i0 + 1], maps[j0 + 1]);
|
||||
swap(maps[i0 + 2], maps[j0 + 2]);
|
||||
swap(maps[i0 + 3], maps[j0 + 3]);
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(maps, [](void *mem_) {
|
||||
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(DocIdx);
|
||||
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
|
||||
{4*byte_size, byte_size}, // C-style contiguous strides
|
||||
maps, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int>& sizes_,
|
||||
const py::array_t<int>& titles_sizes_,
|
||||
const int num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int max_seq_length,
|
||||
const int seed,
|
||||
const bool verbose,
|
||||
const bool use_one_sent_blocks) {
|
||||
|
||||
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
|
||||
if (verbose) {
|
||||
cout << " using uint64 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
|
||||
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
||||
} else {
|
||||
if (verbose) {
|
||||
cout << " using uint32 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
|
||||
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(helpers, m) {
|
||||
m.def("build_mapping", &build_mapping);
|
||||
m.def("build_blocks_mapping", &build_blocks_mapping);
|
||||
m.def("build_sample_idx", &build_sample_idx);
|
||||
m.def("build_blending_indices", &build_blending_indices);
|
||||
}
|
||||
156
examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import itertools
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from megatron import get_tokenizer
|
||||
from megatron import get_args
|
||||
from megatron.data.dataset_utils import get_indexed_dataset_
|
||||
from megatron.data.realm_dataset_utils import get_block_samples_mapping
|
||||
|
||||
def make_attention_mask(source_block, target_block):
|
||||
"""
|
||||
Returns a 2-dimensional (2-D) attention mask
|
||||
:param source_block: 1-D array
|
||||
:param target_block: 1-D array
|
||||
"""
|
||||
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
|
||||
mask = mask.astype(np.int64)
|
||||
# (source_length, target_length)
|
||||
return mask
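# Example with hypothetical token ids: source_block=[101, 2023, 0] and
# target_block=[101, 102] give a (3, 2) mask of
#   [[1, 1],
#    [1, 1],
#    [0, 0]]
# since padding positions (id 0) are excluded from attention.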
|
||||
|
||||
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
|
||||
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
|
||||
rather than for training, since it is only built with a single epoch sample mapping.
|
||||
"""
|
||||
args = get_args()
|
||||
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
|
||||
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
|
||||
|
||||
kwargs = dict(
|
||||
name='full',
|
||||
block_dataset=block_dataset,
|
||||
title_dataset=titles_dataset,
|
||||
data_prefix=args.data_path,
|
||||
num_epochs=1,
|
||||
max_num_samples=None,
|
||||
max_seq_length=args.seq_length,
|
||||
seed=1,
|
||||
query_in_block_prob=query_in_block_prob,
|
||||
use_titles=use_titles,
|
||||
use_one_sent_docs=args.use_one_sent_docs
|
||||
)
|
||||
dataset = ICTDataset(**kwargs)
|
||||
return dataset
|
||||
|
||||
|
||||
class ICTDataset(Dataset):
|
||||
"""Dataset containing sentences and their blocks for an inverse cloze task."""
|
||||
def __init__(self, name, block_dataset, title_dataset, data_prefix,
|
||||
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
|
||||
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.max_seq_length = max_seq_length
|
||||
self.query_in_block_prob = query_in_block_prob
|
||||
self.block_dataset = block_dataset
|
||||
self.title_dataset = title_dataset
|
||||
self.rng = random.Random(self.seed)
|
||||
self.use_titles = use_titles
|
||||
self.use_one_sent_docs = use_one_sent_docs
|
||||
|
||||
self.samples_mapping = get_block_samples_mapping(
|
||||
block_dataset, title_dataset, data_prefix, num_epochs,
|
||||
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
|
||||
self.tokenizer = get_tokenizer()
|
||||
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
|
||||
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
|
||||
self.cls_id = self.tokenizer.cls
|
||||
self.sep_id = self.tokenizer.sep
|
||||
self.mask_id = self.tokenizer.mask
|
||||
self.pad_id = self.tokenizer.pad
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples_mapping)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
|
||||
sample_data = self.samples_mapping[idx]
|
||||
start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
|
||||
|
||||
if self.use_titles:
|
||||
title = self.title_dataset[int(doc_idx)]
|
||||
title_pad_offset = 3 + len(title)
|
||||
else:
|
||||
title = None
|
||||
title_pad_offset = 2
|
||||
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
|
||||
assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
|
||||
|
||||
# randint() is inclusive for Python rng
|
||||
rand_sent_idx = self.rng.randint(0, len(block) - 1)
|
||||
|
||||
# keep the query in the context query_in_block_prob fraction of the time.
|
||||
if self.rng.random() < self.query_in_block_prob:
|
||||
query = block[rand_sent_idx].copy()
|
||||
else:
|
||||
query = block.pop(rand_sent_idx)
|
||||
|
||||
# still need to truncate because blocks are concluded when
|
||||
# the sentence lengths have exceeded max_seq_length.
|
||||
query = query[:self.max_seq_length - 2]
|
||||
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
|
||||
|
||||
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
|
||||
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
|
||||
|
||||
query_mask = make_attention_mask(query_tokens, query_tokens)
|
||||
context_mask = make_attention_mask(context_tokens, context_tokens)
|
||||
|
||||
block_data = sample_data.as_array()
|
||||
|
||||
sample = {
|
||||
'query_tokens': query_tokens,
|
||||
'query_mask': query_mask,
|
||||
'query_pad_mask': query_pad_mask,
|
||||
'context_tokens': context_tokens,
|
||||
'context_mask': context_mask,
|
||||
'context_pad_mask': context_pad_mask,
|
||||
'block_data': block_data,
|
||||
}
|
||||
|
||||
return sample
|
||||
|
||||
def get_block(self, start_idx, end_idx, doc_idx):
|
||||
"""Get the IDs for an evidence block plus the title of the corresponding document"""
|
||||
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
|
||||
title = self.title_dataset[int(doc_idx)]
|
||||
|
||||
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
|
||||
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
|
||||
|
||||
return block_tokens, block_pad_mask
|
||||
|
||||
def get_null_block(self):
|
||||
"""Get empty block and title - used in REALM pretraining"""
|
||||
block, title = [], []
|
||||
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
|
||||
|
||||
return block_tokens, block_pad_mask
|
||||
|
||||
def concat_and_pad_tokens(self, tokens, title=None):
|
||||
"""Concat with special tokens and pad sequence to self.max_seq_length"""
|
||||
tokens = list(tokens)
|
||||
if title is None:
|
||||
tokens = [self.cls_id] + tokens + [self.sep_id]
|
||||
else:
|
||||
title = list(title)
|
||||
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
|
||||
assert len(tokens) <= self.max_seq_length
|
||||
|
||||
num_pad = self.max_seq_length - len(tokens)
|
||||
pad_mask = [1] * len(tokens) + [0] * num_pad
|
||||
tokens += [self.pad_id] * num_pad
|
||||
|
||||
return np.array(tokens), np.array(pad_mask)
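# Sketch with hypothetical ids (cls=101, sep=102, pad=0, max_seq_length=8):
# concat_and_pad_tokens([7592, 2088], title=[3231]) returns
#   tokens   -> [101, 3231, 102, 7592, 2088, 102, 0, 0]
#   pad_mask -> [1, 1, 1, 1, 1, 1, 0, 0]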
|
||||
@@ -0,0 +1,569 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# copied from fairseq/fairseq/data/indexed_dataset.py
|
||||
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
|
||||
# other slight modifications to remove fairseq dependencies
|
||||
# Added document index to index file and made it accessible.
|
||||
# An empty sentence no longer separates documents.
|
||||
|
||||
from functools import lru_cache
|
||||
import os
|
||||
import shutil
|
||||
import struct
|
||||
from itertools import accumulate
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def __best_fitting_dtype(vocab_size=None):
|
||||
if vocab_size is not None and vocab_size < 65500:
|
||||
return np.uint16
|
||||
else:
|
||||
return np.int32
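# With a vocabulary below 65500 entries, token ids fit into uint16 and the
# .bin file is half the size of an int32 encoding; BERT-style vocabularies
# (~30k entries) therefore use uint16 here.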
|
||||
|
||||
|
||||
def get_available_dataset_impl():
|
||||
return ['lazy', 'cached', 'mmap']
|
||||
|
||||
|
||||
def infer_dataset_impl(path):
|
||||
if IndexedDataset.exists(path):
|
||||
with open(index_file_path(path), 'rb') as f:
|
||||
magic = f.read(8)
|
||||
if magic == IndexedDataset._HDR_MAGIC:
|
||||
return 'cached'
|
||||
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
|
||||
return 'mmap'
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
|
||||
return None
|
||||
|
||||
|
||||
def make_builder(out_file, impl, vocab_size=None):
|
||||
if impl == 'mmap':
|
||||
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
|
||||
else:
|
||||
return IndexedDatasetBuilder(out_file)
|
||||
|
||||
|
||||
def make_dataset(path, impl, skip_warmup=False):
|
||||
if not IndexedDataset.exists(path):
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
|
||||
return None
|
||||
if impl == 'infer':
|
||||
impl = infer_dataset_impl(path)
|
||||
if impl == 'lazy' and IndexedDataset.exists(path):
|
||||
return IndexedDataset(path)
|
||||
elif impl == 'cached' and IndexedDataset.exists(path):
|
||||
return IndexedCachedDataset(path)
|
||||
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
|
||||
return MMapIndexedDataset(path, skip_warmup)
|
||||
print(f"Unknown dataset implementation: {impl}")
|
||||
return None
|
||||
|
||||
|
||||
def dataset_exists(path, impl):
|
||||
if impl == 'mmap':
|
||||
return MMapIndexedDataset.exists(path)
|
||||
else:
|
||||
return IndexedDataset.exists(path)
|
||||
|
||||
|
||||
def read_longs(f, n):
|
||||
a = np.empty(n, dtype=np.int64)
|
||||
f.readinto(a)
|
||||
return a
|
||||
|
||||
|
||||
def write_longs(f, a):
|
||||
f.write(np.array(a, dtype=np.int64))
|
||||
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
3: np.int16,
|
||||
4: np.int32,
|
||||
5: np.int64,
|
||||
6: np.float64,
|
||||
7: np.double,
|
||||
8: np.uint16
|
||||
}
|
||||
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + '.idx'
|
||||
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + '.bin'
|
||||
|
||||
|
||||
def create_doc_idx(sizes):
|
||||
doc_idx = [0]
|
||||
for i, s in enumerate(sizes):
|
||||
if s == 0:
|
||||
doc_idx.append(i + 1)
|
||||
return doc_idx
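# Example: sizes=[3, 4, 0, 2, 0] (a zero-length entry marks a document end)
# yields doc_idx=[0, 3, 5], i.e. document 0 owns sentences [0, 3) and
# document 1 owns sentences [3, 5).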
|
||||
|
||||
|
||||
class IndexedDataset(torch.utils.data.Dataset):
|
||||
"""Loader for IndexedDataset"""
|
||||
_HDR_MAGIC = b'TNTIDX\x00\x00'
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.data_file = None
|
||||
self.read_index(path)
|
||||
|
||||
def read_index(self, path):
|
||||
with open(index_file_path(path), 'rb') as f:
|
||||
magic = f.read(8)
|
||||
assert magic == self._HDR_MAGIC, (
|
||||
'Index file doesn\'t match expected format. '
|
||||
'Make sure that --dataset-impl is configured properly.'
|
||||
)
|
||||
version = f.read(8)
|
||||
assert struct.unpack('<Q', version) == (1,)
|
||||
code, self.element_size = struct.unpack('<QQ', f.read(16))
|
||||
self.dtype = dtypes[code]
|
||||
self._len, self.s = struct.unpack('<QQ', f.read(16))
|
||||
self.doc_count = struct.unpack('<Q', f.read(8))
|
||||
self.dim_offsets = read_longs(f, self._len + 1)
|
||||
self.data_offsets = read_longs(f, self._len + 1)
|
||||
self.sizes = read_longs(f, self.s)
|
||||
self.doc_idx = read_longs(f, self.doc_count)
|
||||
|
||||
def read_data(self, path):
|
||||
self.data_file = open(data_file_path(path), 'rb', buffering=0)
|
||||
|
||||
def check_index(self, i):
|
||||
if i < 0 or i >= self._len:
|
||||
raise IndexError('index out of range')
|
||||
|
||||
def __del__(self):
|
||||
if self.data_file:
|
||||
self.data_file.close()
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if not self.data_file:
|
||||
self.read_data(self.path)
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
|
||||
size = sum(sizes)
|
||||
a = np.empty(size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[start] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
offsets = list(accumulate(sizes))
|
||||
sents = np.split(a, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.sizes[index]
|
||||
|
||||
def size(self, index):
|
||||
return self.sizes[index]
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return (
|
||||
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False # avoid prefetching to save memory
|
||||
|
||||
|
||||
class IndexedCachedDataset(IndexedDataset):
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__(path)
|
||||
self.cache = None
|
||||
self.cache_index = {}
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return True
|
||||
|
||||
def prefetch(self, indices):
|
||||
if all(i in self.cache_index for i in indices):
|
||||
return
|
||||
if not self.data_file:
|
||||
self.read_data(self.path)
|
||||
indices = sorted(set(indices))
|
||||
total_size = 0
|
||||
for i in indices:
|
||||
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
self.cache = np.empty(total_size, dtype=self.dtype)
|
||||
ptx = 0
|
||||
self.cache_index.clear()
|
||||
for i in indices:
|
||||
self.cache_index[i] = ptx
|
||||
size = self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
a = self.cache[ptx: ptx + size]
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
ptx += size
|
||||
if self.data_file:
|
||||
# close and delete data file after prefetch so we can pickle
|
||||
self.data_file.close()
|
||||
self.data_file = None
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
ptx = self.cache_index[i]
|
||||
np.copyto(a, self.cache[ptx: ptx + a.size])
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
# Hack just to make this work, can optimize later if necessary
|
||||
sents = []
|
||||
for i in range(*idx.indices(len(self))):
|
||||
sents.append(self[i])
|
||||
return sents
|
||||
|
||||
|
||||
class IndexedDatasetBuilder(object):
|
||||
element_sizes = {
|
||||
np.uint8: 1,
|
||||
np.int8: 1,
|
||||
np.int16: 2,
|
||||
np.int32: 4,
|
||||
np.int64: 8,
|
||||
np.float64: 8,
|
||||
np.double: 8
|
||||
}
|
||||
|
||||
def __init__(self, out_file, dtype=np.int32):
|
||||
self.out_file = open(out_file, 'wb')
|
||||
self.dtype = dtype
|
||||
self.data_offsets = [0]
|
||||
self.dim_offsets = [0]
|
||||
self.sizes = []
|
||||
self.element_size = self.element_sizes[self.dtype]
|
||||
self.doc_idx = [0]
|
||||
|
||||
def add_item(self, tensor):
|
||||
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
|
||||
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
|
||||
for s in tensor.size():
|
||||
self.sizes.append(s)
|
||||
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
|
||||
|
||||
def end_document(self):
|
||||
self.doc_idx.append(len(self.sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
index = IndexedDataset(another_file)
|
||||
assert index.dtype == self.dtype
|
||||
|
||||
begin = self.data_offsets[-1]
|
||||
for offset in index.data_offsets[1:]:
|
||||
self.data_offsets.append(begin + offset)
|
||||
self.sizes.extend(index.sizes)
|
||||
begin = self.dim_offsets[-1]
|
||||
for dim_offset in index.dim_offsets[1:]:
|
||||
self.dim_offsets.append(begin + dim_offset)
|
||||
|
||||
with open(data_file_path(another_file), 'rb') as f:
|
||||
while True:
|
||||
data = f.read(1024)
|
||||
if data:
|
||||
self.out_file.write(data)
|
||||
else:
|
||||
break
|
||||
|
||||
def finalize(self, index_file):
|
||||
self.out_file.close()
|
||||
index = open(index_file, 'wb')
|
||||
index.write(b'TNTIDX\x00\x00')
|
||||
index.write(struct.pack('<Q', 1))
|
||||
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
|
||||
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
|
||||
index.write(struct.pack('<Q', len(self.doc_idx)))
|
||||
write_longs(index, self.dim_offsets)
|
||||
write_longs(index, self.data_offsets)
|
||||
write_longs(index, self.sizes)
|
||||
write_longs(index, self.doc_idx)
|
||||
index.close()
|
||||
|
||||
|
||||
def _warmup_mmap_file(path):
|
||||
with open(path, 'rb') as stream:
|
||||
while stream.read(100 * 1024 * 1024):
|
||||
pass
|
||||
|
||||
|
||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b'MMIDIDX\x00\x00'
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path, dtype):
|
||||
class _Writer(object):
|
||||
def __enter__(self):
|
||||
self._file = open(path, 'wb')
|
||||
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
self._file.write(struct.pack('<Q', 1))
|
||||
self._file.write(struct.pack('<B', code(dtype)))
|
||||
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _get_pointers(sizes):
|
||||
dtype_size = dtype().itemsize
|
||||
address = 0
|
||||
pointers = []
|
||||
|
||||
for size in sizes:
|
||||
pointers.append(address)
|
||||
address += size * dtype_size
|
||||
|
||||
return pointers
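# Example: sizes=[3, 5, 2] with a 4-byte dtype give byte pointers
# [0, 12, 32] -- each entry is the offset of that item's first token in the
# flat .bin buffer.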
|
||||
|
||||
def write(self, sizes, doc_idx):
|
||||
pointers = self._get_pointers(sizes)
|
||||
|
||||
self._file.write(struct.pack('<Q', len(sizes)))
|
||||
self._file.write(struct.pack('<Q', len(doc_idx)))
|
||||
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order='C'))
|
||||
del sizes
|
||||
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order='C'))
|
||||
del pointers
|
||||
|
||||
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||
self._file.write(doc_idx.tobytes(order='C'))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, 'rb') as stream:
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
'Index file doesn\'t match expected format. '
|
||||
'Make sure that --dataset-impl is configured properly.'
|
||||
)
|
||||
version = struct.unpack('<Q', stream.read(8))
|
||||
assert (1,) == version
|
||||
|
||||
dtype_code, = struct.unpack('<B', stream.read(1))
|
||||
self._dtype = dtypes[dtype_code]
|
||||
self._dtype_size = self._dtype().itemsize
|
||||
|
||||
self._len = struct.unpack('<Q', stream.read(8))[0]
|
||||
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
|
||||
offset = stream.tell()
|
||||
|
||||
if not skip_warmup:
|
||||
print(" warming up index mmap file...")
|
||||
_warmup_mmap_file(path)
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
print(" reading sizes...")
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int32,
|
||||
count=self._len,
|
||||
offset=offset)
|
||||
print(" reading pointers...")
|
||||
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
|
||||
offset=offset + self._sizes.nbytes)
|
||||
print(" reading document index...")
|
||||
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
|
||||
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._doc_idx
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
super().__init__()
|
||||
|
||||
self._path = None
|
||||
self._index = None
|
||||
self._bin_buffer = None
|
||||
|
||||
self._do_init(path, skip_warmup)
|
||||
|
||||
def __getstate__(self):
|
||||
return self._path
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._do_init(state)
|
||||
|
||||
def _do_init(self, path, skip_warmup):
|
||||
self._path = path
|
||||
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||
|
||||
if not skip_warmup:
|
||||
print(" warming up data mmap file...")
|
||||
_warmup_mmap_file(data_file_path(self._path))
|
||||
print(" creating numpy buffer of mmap...")
|
||||
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
|
||||
print(" creating memory view of numpy buffer...")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
del self._index
|
||||
|
||||
def __len__(self):
|
||||
return len(self._index)
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
ptr, size = self._index[idx]
|
||||
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
||||
count=size, offset=ptr)
|
||||
return np_array
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
ptr = self._index._pointers[start]
|
||||
sizes = self._index._sizes[idx]
|
||||
offsets = list(accumulate(sizes))
|
||||
total_size = sum(sizes)
|
||||
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
||||
count=total_size, offset=ptr)
|
||||
sents = np.split(np_array, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def get(self, idx, offset=0, length=None):
|
||||
""" Retrieves a single item from the dataset with the option to only
|
||||
return a portion of the item.
|
||||
|
||||
get(idx) is the same as [idx] but get() does not support slicing.
|
||||
"""
|
||||
ptr, size = self._index[idx]
|
||||
if length is None:
|
||||
length = size - offset
|
||||
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
||||
count=length, offset=ptr)
|
||||
return np_array
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._index.sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._index.doc_idx
|
||||
|
||||
def get_doc_idx(self):
|
||||
return self._index._doc_idx
|
||||
|
||||
def set_doc_idx(self, doc_idx_):
|
||||
self._index._doc_idx = doc_idx_
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return (
|
||||
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
|
||||
)
|
||||
|
||||
|
||||
class MMapIndexedDatasetBuilder(object):
|
||||
def __init__(self, out_file, dtype=np.int64):
|
||||
self._data_file = open(out_file, 'wb')
|
||||
self._dtype = dtype
|
||||
self._sizes = []
|
||||
self._doc_idx = [0]
|
||||
|
||||
def add_item(self, tensor):
|
||||
np_array = np.array(tensor.numpy(), dtype=self._dtype)
|
||||
self._data_file.write(np_array.tobytes(order='C'))
|
||||
self._sizes.append(np_array.size)
|
||||
|
||||
def end_document(self):
|
||||
self._doc_idx.append(len(self._sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
# Concatenate index
|
||||
index = MMapIndexedDataset.Index(index_file_path(another_file))
|
||||
assert index.dtype == self._dtype
|
||||
|
||||
for size in index.sizes:
|
||||
self._sizes.append(size)
|
||||
|
||||
# Concatenate data
|
||||
with open(data_file_path(another_file), 'rb') as f:
|
||||
shutil.copyfileobj(f, self._data_file)
|
||||
|
||||
def finalize(self, index_file):
|
||||
self._data_file.close()
|
||||
|
||||
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
|
||||
index.write(self._sizes, self._doc_idx)
|
||||
@@ -0,0 +1,125 @@
|
||||
# This file isn't really a formal automated test, it's just a place to
|
||||
# put some code used during development and manual testing of
|
||||
# indexed_dataset.
|
||||
|
||||
from megatron.data import indexed_dataset
|
||||
from megatron.tokenizer import build_tokenizer
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.append(os.path.join(script_dir, "../../../"))
|
||||
|
||||
|
||||
def test_indexed_dataset(args):
|
||||
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
|
||||
tokenizer = build_tokenizer(args)
|
||||
print(len(ds.doc_idx))
|
||||
print(len(ds))
|
||||
print(ds.doc_idx[-1])
|
||||
if ds.supports_prefetch:
|
||||
# just prefetch the whole thing in test (so assume it is small)
|
||||
ds.prefetch(range(len(ds)))
|
||||
if args.count > len(ds.doc_idx) - 1:
|
||||
args.count = len(ds.doc_idx) - 1
|
||||
|
||||
for i in range(args.count):
|
||||
start = ds.doc_idx[i]
|
||||
end = ds.doc_idx[i + 1]
|
||||
ids = ds[start:end]
|
||||
print(f"Document {i}:")
|
||||
print("--------------")
|
||||
for s in ids:
|
||||
assert len(s) > 0
|
||||
l = s.data.tolist()
|
||||
text = tokenizer.detokenize(l)
|
||||
print(text)
|
||||
print("---")
|
||||
|
||||
|
||||
def test_indexed_dataset_get(args):
|
||||
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
|
||||
tokenizer = build_tokenizer(args)
|
||||
size = ds.sizes[0]
|
||||
print(f"size: {size}")
|
||||
full = ds.get(0)
|
||||
print(full)
|
||||
# print(tokenizer.detokenize(full.data.tolist()))
|
||||
print("---")
|
||||
end = ds.get(0, offset=size - 10)
|
||||
print(end)
|
||||
# print(tokenizer.detokenize(end.data.tolist()))
|
||||
|
||||
start = ds.get(0, length=10)
|
||||
print(start)
|
||||
# print(tokenizer.detokenize(start.data.tolist()))
|
||||
|
||||
part = ds.get(0, offset=2, length=8)
|
||||
print(part)
|
||||
# print(tokenizer.detokenize(part.data.tolist()))
|
||||
|
||||
# def test_albert_dataset(args):
|
||||
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
|
||||
# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
|
||||
# # ds = AlbertDataset(idataset, tokenizer)
|
||||
# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
|
||||
# args.epochs, args.max_num_samples,
|
||||
# args.masked_lm_prob, args.seq_length,
|
||||
# args.short_seq_prob, args.seed)
|
||||
# truncated = 0
|
||||
# total = 0
|
||||
# for i, s in enumerate(ds):
|
||||
# ids = s['text']
|
||||
# tokens = ds.tokenizer.convert_ids_to_tokens(ids)
|
||||
# print(tokens)
|
||||
# if i >= args.count-1:
|
||||
# exit()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data', type=str, help='prefix to data files')
|
||||
parser.add_argument('--dataset-impl', type=str, default='infer',
|
||||
choices=['lazy', 'cached', 'mmap', 'infer'])
|
||||
parser.add_argument('--count', type=int, default=10,
|
||||
help='Number of samples/documents to print')
|
||||
|
||||
group = parser.add_argument_group(title='tokenizer')
|
||||
group.add_argument('--tokenizer-type', type=str, required=True,
|
||||
choices=['BertWordPieceLowerCase',
|
||||
'GPT2BPETokenizer'],
|
||||
help='What type of tokenizer to use.')
|
||||
group.add_argument('--vocab-file', type=str, default=None,
|
||||
help='Path to the vocab file')
|
||||
group.add_argument('--merge-file', type=str, default=None,
|
||||
help='Path to the BPE merge file (if necessary).')
|
||||
|
||||
parser.add_argument('--epochs', type=int, default=5,
|
||||
help='Number of epochs to plan for')
|
||||
parser.add_argument('--max-num-samples', type=int, default=None,
|
||||
help='Maximum number of samples to plan for')
|
||||
parser.add_argument('--masked-lm-prob', type=float, default=0.15,
|
||||
help='probability of masking tokens')
|
||||
parser.add_argument('--seq-length', type=int, default=512,
|
||||
help='maximum sequence length')
|
||||
parser.add_argument('--short-seq-prob', type=float, default=0.1,
|
||||
help='probability of creating a short sequence')
|
||||
parser.add_argument('--seed', type=int, default=1234,
|
||||
help='random seed')
|
||||
args = parser.parse_args()
|
||||
args.rank = 0
|
||||
args.make_vocab_size_divisible_by = 128
|
||||
args.tensor_model_parallel_size = 1
|
||||
|
||||
if args.dataset_impl == "infer":
|
||||
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
|
||||
|
||||
# test_albert_dataset(args)
|
||||
test_indexed_dataset_get(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,10 @@
#!/bin/bash

IMPL=cached
python ../preprocess_data.py \
       --input test_samples.json \
       --vocab vocab.txt \
       --dataset-impl ${IMPL} \
       --output-prefix test_samples_${IMPL} \
       --workers 1 \
       --log-interval 2
@@ -0,0 +1,38 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .tokenizer import build_tokenizer


_TOKENIZER = None
_PADDED_VOCAB_SIZE = -1


def initialize_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
    tokenizer, padded_vocab_size = build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids)
    global _TOKENIZER, _PADDED_VOCAB_SIZE
    _TOKENIZER = tokenizer
    _PADDED_VOCAB_SIZE = padded_vocab_size


def get_tokenizer():
    global _TOKENIZER
    return _TOKENIZER


def get_padded_vocab_size():
    global _PADDED_VOCAB_SIZE
    return _PADDED_VOCAB_SIZE
@@ -0,0 +1,431 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import unicodedata
|
||||
import six
|
||||
|
||||
|
||||
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||
"""Checks whether the casing config is consistent with the checkpoint name."""
|
||||
|
||||
# The casing has to be passed in by the user and there is no explicit check
|
||||
# as to whether it matches the checkpoint. The casing information probably
|
||||
# should have been stored in the bert_config.json file, but it's not, so
|
||||
# we have to heuristically detect it to validate.
|
||||
|
||||
if not init_checkpoint:
|
||||
return
|
||||
|
||||
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
||||
if m is None:
|
||||
return
|
||||
|
||||
model_name = m.group(1)
|
||||
|
||||
lower_models = [
|
||||
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
||||
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
cased_models = [
|
||||
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
||||
"multi_cased_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
is_bad_config = False
|
||||
if model_name in lower_models and not do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "False"
|
||||
case_name = "lowercased"
|
||||
opposite_flag = "True"
|
||||
|
||||
if model_name in cased_models and do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "True"
|
||||
case_name = "cased"
|
||||
opposite_flag = "False"
|
||||
|
||||
if is_bad_config:
|
||||
raise ValueError(
|
||||
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
||||
"However, `%s` seems to be a %s model, so you "
|
||||
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
||||
"how the model was pre-training. If this error is wrong, please "
|
||||
"just comment out this check." % (actual_flag, init_checkpoint,
|
||||
model_name, case_name, opposite_flag))
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
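# Example (illustrative): a vocab file whose first three lines are "[PAD]",
# "[UNK]" and "the" yields OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('the', 2)]),
# i.e. the id of a token is simply its line number in the file.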
|
||||
|
||||
def convert_by_vocab(vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
output.append(vocab[item])
|
||||
return output
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
return convert_by_vocab(vocab, tokens)
|
||||
|
||||
|
||||
def convert_ids_to_tokens(inv_vocab, ids):
|
||||
return convert_by_vocab(inv_vocab, ids)
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenization."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return convert_by_vocab(self.inv_vocab, ids)
|
||||
|
||||
@staticmethod
|
||||
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
|
||||
def clean_up_tokenization(out_string):
|
||||
""" Clean up a list of simple English tokenization artifacts
|
||||
like spaces before punctuations and abbreviated forms.
|
||||
"""
|
||||
out_string = (
|
||||
out_string.replace(" .", ".")
|
||||
.replace(" ?", "?")
|
||||
.replace(" !", "!")
|
||||
.replace(" ,", ",")
|
||||
.replace(" ' ", "'")
|
||||
.replace(" n't", "n't")
|
||||
.replace(" 'm", "'m")
|
||||
.replace(" 's", "'s")
|
||||
.replace(" 've", "'ve")
|
||||
.replace(" 're", "'re")
|
||||
)
|
||||
return out_string
|
||||
|
||||
text = ' '.join(tokens).replace(' ##', '').strip()
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = clean_up_tokenization(text)
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenization."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer`.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat in ("Cc", "Cf"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
256
examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Megatron tokenizers."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
|
||||
from .bert_tokenization import FullTokenizer as FullBertTokenizer
|
||||
|
||||
|
||||
def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
|
||||
"""Initialize tokenizer."""
|
||||
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
|
||||
print('> building {} tokenizer ...'.format(tokenizer_type),
|
||||
flush=True)
|
||||
|
||||
# Select and instantiate the tokenizer.
|
||||
if tokenizer_type == 'BertWordPieceLowerCase':
|
||||
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file,
|
||||
lower_case=True,
|
||||
vocab_extra_ids=vocab_extra_ids)
|
||||
elif tokenizer_type == 'BertWordPieceCase':
|
||||
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file,
|
||||
lower_case=False,
|
||||
vocab_extra_ids=vocab_extra_ids)
|
||||
else:
|
||||
raise NotImplementedError('{} tokenizer is not '
|
||||
'implemented.'.format(tokenizer_type))
|
||||
|
||||
# Add vocab size.
|
||||
padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size)
|
||||
|
||||
return tokenizer, padded_vocab_size
|
||||
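# Usage sketch (illustrative; the vocab path is a placeholder, not a file shipped
# with the example):
#
#   tokenizer, padded_vocab_size = build_tokenizer(
#       '/path/to/bert-large-uncased-vocab.txt', 'BertWordPieceLowerCase')
#   ids = tokenizer.tokenize('sequence parallelism')   # list of wordpiece ids
#   text = tokenizer.decode(ids)
#
# `padded_vocab_size` (not the raw vocab size) is what the model's embedding and
# output layers should be built with.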
|
||||
|
||||
def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128):
|
||||
"""Pad vocab size so it is divisible by model parallel size and
|
||||
still having GPU friendly size."""
|
||||
|
||||
after = orig_vocab_size
|
||||
|
||||
if gpc.is_initialized(ParallelMode.TENSOR):
|
||||
multiple = make_vocab_size_divisible_by * gpc.get_world_size(ParallelMode.TENSOR)
|
||||
else:
|
||||
multiple = make_vocab_size_divisible_by
|
||||
while (after % multiple) != 0:
|
||||
after += 1
|
||||
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
|
||||
print(' > padded vocab (size: {}) with {} dummy tokens '
|
||||
'(new size: {})'.format(
|
||||
orig_vocab_size, after - orig_vocab_size, after), flush=True)
|
||||
return after
|
||||
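# Worked example (illustrative): with the default make_vocab_size_divisible_by=128
# and no tensor parallelism, a BERT vocab of 30522 tokens is padded up to 30592
# (70 dummy tokens); with a tensor parallel world size of 2 the multiple becomes
# 256 and the padded size is 30720.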
|
||||
|
||||
class AbstractTokenizer(ABC):
|
||||
"""Abstract class for tokenizer."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab_size(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab(self):
|
||||
"""Dictionary from vocab text token to id token."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def inv_vocab(self):
|
||||
"""Dictionary from vocab id token to text token."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def tokenize(self, text):
|
||||
pass
|
||||
|
||||
def detokenize(self, token_ids):
|
||||
raise NotImplementedError('detokenizer is not implemented for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
raise NotImplementedError('CLS is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
|
||||
@property
|
||||
def sep(self):
|
||||
raise NotImplementedError('SEP is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
raise NotImplementedError('PAD is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
raise NotImplementedError('EOD is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
raise NotImplementedError('MASK is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
|
||||
|
||||
class _BertWordPieceTokenizer(AbstractTokenizer):
|
||||
"""Original BERT wordpiece tokenizer."""
|
||||
|
||||
def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
|
||||
if lower_case:
|
||||
name = 'BERT Lower Case'
|
||||
else:
|
||||
name = 'BERT Upper Case'
|
||||
super().__init__(name)
|
||||
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
|
||||
self.cls_id = self.tokenizer.vocab['[CLS]']
|
||||
self.sep_id = self.tokenizer.vocab['[SEP]']
|
||||
self.pad_id = self.tokenizer.vocab['[PAD]']
|
||||
self.mask_id = self.tokenizer.vocab['[MASK]']
|
||||
self._additional_special_tokens = []
|
||||
|
||||
# (dsachan) Add BOS and EOS tokens
|
||||
SPECIAL_TOKENS = {'eos_token': '[EOS]',
|
||||
'bos_token': '[BOS]'}
|
||||
self._bos_token = '[BOS]'
|
||||
self.add_token(self._bos_token)
|
||||
self._bos_token_id = self.vocab.get(self._bos_token)
|
||||
|
||||
self._eos_token = '[EOS]'
|
||||
self.add_token(self._eos_token)
|
||||
self._eos_token_id = self.vocab.get(self._eos_token)
|
||||
|
||||
# (dsachan) Add additional special tokens
|
||||
# These can be used as sentinel tokens in T5 model inputs
|
||||
additional_special_tokens = []
|
||||
additional_special_tokens.extend(
|
||||
["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
|
||||
self.add_additional_special_tokens(additional_special_tokens)
|
||||
|
||||
def add_token(self, token):
|
||||
if token not in self.vocab:
|
||||
self.inv_vocab[self.vocab_size] = token
|
||||
# self.vocab_size comes from len(vocab)
|
||||
# and it will increase as we add elements
|
||||
self.vocab[token] = self.vocab_size
|
||||
|
||||
def add_additional_special_tokens(self, tokens_list):
|
||||
setattr(self, "additional_special_tokens", tokens_list)
|
||||
for value in tokens_list:
|
||||
self.add_token(value)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.vocab_size()
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self.tokenizer.vocab
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self.tokenizer.inv_vocab
|
||||
|
||||
def tokenize(self, text):
|
||||
text_tokens = self.tokenizer.tokenize(text)
|
||||
return self.tokenizer.convert_tokens_to_ids(text_tokens)
|
||||
|
||||
def decode(self, ids):
|
||||
tokens = self.tokenizer.convert_ids_to_tokens(ids)
|
||||
return self.tokenizer.convert_tokens_to_string(tokens)
|
||||
|
||||
def decode_token_ids(self, token_ids):
|
||||
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
|
||||
exclude_list = ['[PAD]', '[CLS]']
|
||||
non_pads = [t for t in tokens if t not in exclude_list]
|
||||
|
||||
result = ""
|
||||
for s in non_pads:
|
||||
if s.startswith("##"):
|
||||
result += s[2:]
|
||||
else:
|
||||
result += " " + s
|
||||
|
||||
return result
|
||||
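# Example (illustrative): for ids corresponding to
# ['[CLS]', 'se', '##quen', '##ce', '[SEP]', '[PAD]'] this returns ' sequence [SEP]'
# -- '[CLS]' and '[PAD]' are dropped, '##' continuation pieces are glued onto the
# previous piece, and every remaining token is prefixed with a space.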
|
||||
@property
|
||||
def cls(self):
|
||||
return self.cls_id
|
||||
|
||||
@property
|
||||
def sep(self):
|
||||
return self.sep_id
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
return self.pad_id
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
return self.mask_id
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
""" Beginning of sentence token id """
|
||||
return self._bos_token
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
""" End of sentence token id """
|
||||
return self._eos_token
|
||||
|
||||
@property
|
||||
def additional_special_tokens(self):
|
||||
""" All the additional special tokens you may want to use (list of strings)."""
|
||||
return self._additional_special_tokens
|
||||
|
||||
@property
|
||||
def bos_token_id(self):
|
||||
""" Id of the beginning of sentence token in the vocabulary."""
|
||||
return self._bos_token_id
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
""" Id of the end of sentence token in the vocabulary."""
|
||||
return self._eos_token_id
|
||||
|
||||
@property
|
||||
def additional_special_tokens_ids(self):
|
||||
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
|
||||
return [self.vocab.get(token) for token in self._additional_special_tokens]
|
||||
|
||||
@additional_special_tokens.setter
|
||||
def additional_special_tokens(self, value):
|
||||
self._additional_special_tokens = value
|
||||
41
examples/tutorial/sequence_parallel/loss_func/bert_loss.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from .cross_entropy import vocab_cross_entropy
|
||||
|
||||
|
||||
class BertLoss(nn.Module):
|
||||
|
||||
def forward(self,
|
||||
lm_loss,
|
||||
sop_logits,
|
||||
loss_mask,
|
||||
sentence_order):
|
||||
lm_loss_ = lm_loss.float()
|
||||
loss_mask = loss_mask.float()
|
||||
loss_mask_sum = loss_mask.sum()
|
||||
lm_loss = torch.sum(
|
||||
lm_loss_.view(-1) * loss_mask.reshape(-1))
|
||||
|
||||
lm_loss /= loss_mask_sum
|
||||
|
||||
torch.distributed.all_reduce(
|
||||
lm_loss,
|
||||
group=gpc.get_group(ParallelMode.SEQUENCE)
|
||||
)
|
||||
|
||||
if sop_logits is not None:
|
||||
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
|
||||
sentence_order.view(-1),
|
||||
ignore_index=-1)
|
||||
sop_loss = sop_loss.float()
|
||||
loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
else:
|
||||
sop_loss = None
|
||||
loss = lm_loss
|
||||
|
||||
return loss
|
||||
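# Note on the scaling above (illustrative reading of the code, not original
# documentation): lm_loss is a masked mean over the local sub-sequence and is then
# summed across the sequence-parallel group via all_reduce, while sop_loss is
# computed once from the full-sequence SOP logits; multiplying sop_loss by the
# sequence-parallel world size keeps the two terms on a comparable scale.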
@@ -0,0 +1,75 @@
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
import torch
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
class _VocabCrossEntropy(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, vocab_parallel_logits, target):
|
||||
# Maximum value along vocab dimension across all GPUs.
|
||||
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
|
||||
|
||||
# Subtract the maximum value.
|
||||
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
|
||||
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
target_mask = target < 0
|
||||
masked_target = target.clone()
|
||||
masked_target[target_mask] = 0
|
||||
|
||||
# Get predicted-logits = logits[target].
|
||||
# For Simplicity, we convert logits to a 2-D tensor with size
|
||||
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
|
||||
logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1))
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
|
||||
device=logits_2d.device)
|
||||
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
|
||||
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
|
||||
predicted_logits = predicted_logits_1d.view_as(target)
|
||||
predicted_logits[target_mask] = 0.0
|
||||
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
exp_logits = vocab_parallel_logits
|
||||
torch.exp(vocab_parallel_logits, out=exp_logits)
|
||||
sum_exp_logits = exp_logits.sum(dim=-1)
|
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit.
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits
|
||||
|
||||
# Store softmax, target-mask and masked-target for backward pass.
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
# Retrieve the tensors saved in the forward pass.
|
||||
softmax, target_mask, masked_target_1d = ctx.saved_tensors
|
||||
|
||||
# All the inputs have softmax as their gradient.
|
||||
grad_input = softmax
|
||||
# For simplicity, work with the 2D gradient.
|
||||
partition_vocab_size = softmax.size()[-1]
|
||||
grad_2d = grad_input.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
|
||||
device=grad_2d.device)
|
||||
grad_2d[arange_1d, masked_target_1d] -= (
|
||||
1.0 - target_mask.view(-1).float())
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(grad_output.unsqueeze(dim=-1))
|
||||
|
||||
return grad_input, None
|
||||
|
||||
|
||||
def vocab_cross_entropy(vocab_logits, target):
|
||||
"""helper function for the cross entropy."""
|
||||
|
||||
return _VocabCrossEntropy.apply(vocab_logits, target)
|
||||
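# Math recap (illustrative, not original documentation): for a token with logits z
# and label y >= 0 the forward pass computes
#     loss = log(sum_v exp(z_v)) - z_y
# and the backward pass returns softmax(z) - one_hot(y), scaled by the incoming
# gradient. Labels < 0 are treated as masked: their one-hot term is skipped, and
# the caller is expected to zero their loss contribution (see BertLoss above,
# which multiplies by loss_mask).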
55
examples/tutorial/sequence_parallel/loss_func/utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
|
||||
numerator, denominator)
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator and return
|
||||
the division value."""
|
||||
ensure_divisibility(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(tensor, num_partitions,
|
||||
contiguous_split_chunks=False):
|
||||
"""Split a tensor along its last dimension.
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
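# Usage sketch (illustrative): splitting a [batch, seq, 3 * hidden] QKV projection
# into 3 partitions yields three [batch, seq, hidden] views of the same storage;
# pass contiguous_split_chunks=True when the chunks feed kernels that require
# contiguous memory.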
|
||||
|
||||
class VocabUtility:
|
||||
"""Split the vocabulary into `world_size` chunks amd return the
|
||||
first and last index of the vocabulary belonging to the `rank`
|
||||
partition: Note that indices in [fist, last)"""
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
|
||||
rank, world_size):
|
||||
index_f = rank * per_partition_vocab_size
|
||||
index_l = index_f + per_partition_vocab_size
|
||||
return index_f, index_l
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
|
||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
||||
return VocabUtility.vocab_range_from_per_partition_vocab_size(
|
||||
per_partition_vocab_size, rank, world_size)
|
||||
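# Worked example (illustrative): a padded vocabulary of 30592 tokens split over a
# world size of 4 gives per_partition_vocab_size = 7648, so rank 1 owns the
# half-open id range [7648, 15296).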
@@ -0,0 +1 @@
from .annealing_lr import AnnealingLR
158
examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Learning rate decay functions."""
|
||||
|
||||
import math
|
||||
|
||||
|
||||
class AnnealingLR(object):
|
||||
"""Anneals the learning rate."""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
max_lr,
|
||||
min_lr,
|
||||
warmup_steps,
|
||||
decay_steps,
|
||||
decay_style,
|
||||
use_checkpoint_lr_scheduler=True,
|
||||
override_lr_scheduler=False):
|
||||
|
||||
# Class values.
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.max_lr = float(max_lr)
|
||||
self.min_lr = min_lr
|
||||
assert self.min_lr >= 0.0
|
||||
assert self.max_lr >= self.min_lr
|
||||
|
||||
self.warmup_steps = warmup_steps
|
||||
self.num_steps = 0
|
||||
self.decay_steps = decay_steps
|
||||
assert self.decay_steps > 0
|
||||
assert self.warmup_steps < self.decay_steps
|
||||
|
||||
self.decay_style = decay_style
|
||||
|
||||
self.override_lr_scheduler = override_lr_scheduler
|
||||
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
|
||||
if self.override_lr_scheduler:
|
||||
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
|
||||
'use-checkpoint are set.'
|
||||
|
||||
# Set the learning rate
|
||||
self.step(0)
|
||||
|
||||
def get_lr(self):
|
||||
"""Learning rate decay functions from:
|
||||
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
|
||||
|
||||
# Use linear warmup for the initial part.
|
||||
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
|
||||
return self.max_lr * float(self.num_steps) / \
|
||||
float(self.warmup_steps)
|
||||
|
||||
# If the learning rate is constant, just return the initial value.
|
||||
if self.decay_style == 'constant':
|
||||
return self.max_lr
|
||||
|
||||
# For any steps larger than `self.decay_steps`, use `self.min_lr`.
|
||||
if self.num_steps > self.decay_steps:
|
||||
return self.min_lr
|
||||
|
||||
# If we are done with the warmup period, use the decay style.
|
||||
num_steps_ = self.num_steps - self.warmup_steps
|
||||
decay_steps_ = self.decay_steps - self.warmup_steps
|
||||
decay_ratio = float(num_steps_) / float(decay_steps_)
|
||||
assert decay_ratio >= 0.0
|
||||
assert decay_ratio <= 1.0
|
||||
delta_lr = self.max_lr - self.min_lr
|
||||
|
||||
if self.decay_style == 'linear':
|
||||
coeff = (1.0 - decay_ratio)
|
||||
elif self.decay_style == 'cosine':
|
||||
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
|
||||
else:
|
||||
raise Exception('{} decay style is not supported.'.format(
|
||||
self.decay_style))
|
||||
|
||||
return self.min_lr + coeff * delta_lr
|
||||
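# Worked example (illustrative hyper-parameters): with max_lr=1e-4, min_lr=1e-5,
# warmup_steps=1000, decay_steps=10000 and decay_style='linear':
#   step   500  -> 5.0e-5   (linear warmup: max_lr * 500 / 1000)
#   step  5500  -> 5.5e-5   (decay_ratio = 0.5, so min_lr + 0.5 * (max_lr - min_lr))
#   step 10000+ -> 1.0e-5   (clamped to min_lr)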
|
||||
def step(self, increment=1):
|
||||
"""Set lr for all parameters groups."""
|
||||
self.num_steps += increment
|
||||
new_lr = self.get_lr()
|
||||
for group in self.optimizer.param_groups:
|
||||
group['lr'] = new_lr
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {
|
||||
'max_lr': self.max_lr,
|
||||
'warmup_steps': self.warmup_steps,
|
||||
'num_steps': self.num_steps,
|
||||
'decay_style': self.decay_style,
|
||||
'decay_steps': self.decay_steps,
|
||||
'min_lr': self.min_lr
|
||||
}
|
||||
return state_dict
|
||||
|
||||
def _check_and_set(self, cls_value, sd_value, name):
|
||||
"""Auxiliary function for checking the values in the checkpoint and
|
||||
setting them."""
|
||||
if self.override_lr_scheduler:
|
||||
return cls_value
|
||||
|
||||
if not self.use_checkpoint_lr_scheduler:
|
||||
assert cls_value == sd_value, \
|
||||
f'AnnealingLR: class input value {cls_value} and checkpoint' \
|
||||
f'value {sd_value} for {name} do not match'
|
||||
return sd_value
|
||||
|
||||
def load_state_dict(self, sd):
|
||||
|
||||
if 'start_lr' in sd:
|
||||
max_lr_ = sd['start_lr']
|
||||
else:
|
||||
max_lr_ = sd['max_lr']
|
||||
self.max_lr = self._check_and_set(self.max_lr, max_lr_,
|
||||
'learning rate')
|
||||
|
||||
self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
|
||||
'minimum learning rate')
|
||||
|
||||
if 'warmup_iter' in sd:
|
||||
warmup_steps_ = sd['warmup_iter']
|
||||
else:
|
||||
warmup_steps_ = sd['warmup_steps']
|
||||
self.warmup_steps = self._check_and_set(self.warmup_steps,
|
||||
warmup_steps_,
|
||||
'warmup iterations')
|
||||
|
||||
if 'end_iter' in sd:
|
||||
decay_steps_ = sd['end_iter']
|
||||
else:
|
||||
decay_steps_ = sd['decay_steps']
|
||||
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
|
||||
'total number of iterations')
|
||||
self.decay_style = self._check_and_set(self.decay_style,
|
||||
sd['decay_style'],
|
||||
'decay style')
|
||||
|
||||
if 'num_iters' in sd:
|
||||
num_steps = sd['num_iters']
|
||||
else:
|
||||
num_steps = sd['num_steps']
|
||||
self.step(increment=num_steps)
|
||||
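# Usage sketch (illustrative; the optimizer and step counts are assumptions, not
# values from the tutorial configs):
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   lr_scheduler = AnnealingLR(optimizer, max_lr=1e-4, min_lr=1e-5,
#                              warmup_steps=1000, decay_steps=10000,
#                              decay_style='cosine')
#   for _ in range(total_steps):
#       ...                    # forward / backward / optimizer.step()
#       lr_scheduler.step()    # advances num_steps and rewrites group['lr']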
2
examples/tutorial/sequence_parallel/model/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
|
||||
|
||||
282
examples/tutorial/sequence_parallel/model/bert.py
Normal file
@@ -0,0 +1,282 @@
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import inspect
|
||||
from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding
|
||||
from .layers.init_method import init_normal, output_init_normal
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.kernel import LayerNorm
|
||||
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.utils import partition_uniform
|
||||
|
||||
|
||||
class BertForPretrain(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
hidden_size,
|
||||
max_sequence_length,
|
||||
num_attention_heads,
|
||||
num_layers,
|
||||
add_binary_head,
|
||||
is_naive_fp16,
|
||||
num_tokentypes=2,
|
||||
dropout_prob=0.1,
|
||||
mlp_ratio=4,
|
||||
init_std=0.02,
|
||||
convert_fp16_to_fp32_in_softmax=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
|
||||
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
|
||||
self.init_std = init_std
|
||||
self.num_layers = num_layers
|
||||
|
||||
if not add_binary_head:
|
||||
num_tokentypes = 0
|
||||
|
||||
self.preprocessor = PreProcessor(self.sub_seq_length)
|
||||
self.embedding = Embedding(hidden_size=hidden_size,
|
||||
vocab_size=vocab_size,
|
||||
max_sequence_length=max_sequence_length,
|
||||
embedding_dropout_prob=dropout_prob,
|
||||
num_tokentypes=num_tokentypes)
|
||||
self.bert_layers = nn.ModuleList()
|
||||
|
||||
for i in range(num_layers):
|
||||
bert_layer = BertLayer(layer_number=i+1,
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_dropout=dropout_prob,
|
||||
mlp_ratio=mlp_ratio,
|
||||
hidden_dropout=dropout_prob,
|
||||
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
|
||||
is_naive_fp16=is_naive_fp16
|
||||
)
|
||||
self.bert_layers.append(bert_layer)
|
||||
|
||||
self.layer_norm = LayerNorm(hidden_size)
|
||||
self.head = BertDualHead(hidden_size, self.embedding.word_embedding_weight.size(0),
|
||||
add_binary_head=add_binary_head)
|
||||
self.reset_parameters()
|
||||
|
||||
def _init_normal(self, tensor):
|
||||
init_normal(tensor, sigma=self.init_std)
|
||||
|
||||
def _output_init_normal(self, tensor):
|
||||
output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)
|
||||
|
||||
def reset_parameters(self):
|
||||
# initialize embedding
|
||||
self._init_normal(self.embedding.word_embedding_weight)
|
||||
self._init_normal(self.embedding.position_embeddings.weight)
|
||||
if self.embedding.tokentype_embeddings:
|
||||
self._init_normal(self.embedding.tokentype_embeddings.weight)
|
||||
|
||||
# initialize bert layer
|
||||
for layer in self.bert_layers:
|
||||
# initialize self attention
|
||||
self._init_normal(layer.self_attention.query_key_value.weight)
|
||||
self._output_init_normal(layer.self_attention.dense.weight)
|
||||
self._init_normal(layer.mlp.dense_h_to_4h.weight)
|
||||
self._output_init_normal(layer.mlp.dense_4h_to_h.weight)
|
||||
|
||||
# initialize head
|
||||
self._init_normal(self.head.lm_head.dense.weight)
|
||||
if self.head.binary_head is not None:
|
||||
self._init_normal(self.head.binary_head.pooler.dense.weight)
|
||||
self._init_normal(self.head.binary_head.dense.weight)
|
||||
|
||||
def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
|
||||
# inputs of the forward function
|
||||
# input_ids: [batch_size, sub_seq_len]
|
||||
# attention_mask: [batch_size, seq_len]
|
||||
# tokentype_ids: [batch_size, sub_seq_len]
|
||||
# outputs of preprocessor
|
||||
# pos_ids: [batch_size, sub_seq_len]
|
||||
# attention_masks: [batch_size, 1, sub_seq_len, seq_len]
|
||||
pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)
|
||||
|
||||
hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)
|
||||
|
||||
# hidden_states shape change:
|
||||
# [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous()
|
||||
|
||||
for idx, layer in enumerate(self.bert_layers):
|
||||
hidden_states = layer(hidden_states, attention_masks)
|
||||
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous()
|
||||
output = self.layer_norm(hidden_states)
|
||||
|
||||
# hidden_states: [sub_seq_len, batch_size, hidden_size]
|
||||
# word_embedding: [vocab_size, hidden_size]
|
||||
return self.head(output, self.embedding.word_embedding_weight, lm_labels)
|
||||
|
||||
|
||||
class PipelineBertForPretrain(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
hidden_size,
|
||||
max_sequence_length,
|
||||
num_attention_heads,
|
||||
num_layers,
|
||||
add_binary_head,
|
||||
is_naive_fp16,
|
||||
num_tokentypes=2,
|
||||
dropout_prob=0.1,
|
||||
mlp_ratio=4,
|
||||
init_std=0.02,
|
||||
convert_fp16_to_fp32_in_softmax=False,
|
||||
first_stage=True,
|
||||
last_stage=True,
|
||||
start_idx=None,
|
||||
end_idx=None):
|
||||
super().__init__()
|
||||
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
|
||||
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
|
||||
self.init_std = init_std
|
||||
self.num_layers = num_layers
|
||||
|
||||
if not add_binary_head:
|
||||
num_tokentypes = 0
|
||||
|
||||
self.first_stage = first_stage
|
||||
self.last_stage = last_stage
|
||||
|
||||
self.preprocessor = PreProcessor(self.sub_seq_length)
|
||||
|
||||
if self.first_stage:
|
||||
self.embedding = Embedding(hidden_size=hidden_size,
|
||||
vocab_size=vocab_size,
|
||||
max_sequence_length=max_sequence_length,
|
||||
embedding_dropout_prob=dropout_prob,
|
||||
num_tokentypes=num_tokentypes)
|
||||
|
||||
# transformer layers
|
||||
self.bert_layers = nn.ModuleList()
|
||||
|
||||
if start_idx is None and end_idx is None:
|
||||
start_idx = 0
|
||||
end_idx = num_layers
|
||||
|
||||
for i in range(start_idx, end_idx):
|
||||
bert_layer = BertLayer(layer_number=i+1,
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_dropout=dropout_prob,
|
||||
mlp_ratio=mlp_ratio,
|
||||
hidden_dropout=dropout_prob,
|
||||
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
|
||||
is_naive_fp16=is_naive_fp16
|
||||
)
|
||||
self.bert_layers.append(bert_layer)
|
||||
|
||||
if self.last_stage:
|
||||
self.word_embeddings = VocabEmbedding(vocab_size, hidden_size)
|
||||
self.layer_norm = LayerNorm(hidden_size)
|
||||
self.head = BertDualHead(hidden_size, vocab_size,
|
||||
add_binary_head=add_binary_head)
|
||||
self.reset_parameters()
|
||||
|
||||
def _init_normal(self, tensor):
|
||||
init_normal(tensor, sigma=self.init_std)
|
||||
|
||||
def _output_init_normal(self, tensor):
|
||||
output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)
|
||||
|
||||
def reset_parameters(self):
|
||||
# initialize embedding
|
||||
if self.first_stage:
|
||||
self._init_normal(self.embedding.word_embedding_weight)
|
||||
self._init_normal(self.embedding.position_embeddings.weight)
|
||||
if self.embedding.tokentype_embeddings:
|
||||
self._init_normal(self.embedding.tokentype_embeddings.weight)
|
||||
|
||||
# initialize bert layer
|
||||
for layer in self.bert_layers:
|
||||
# initialize self attention
|
||||
self._init_normal(layer.self_attention.query_key_value.weight)
|
||||
self._output_init_normal(layer.self_attention.dense.weight)
|
||||
self._init_normal(layer.mlp.dense_h_to_4h.weight)
|
||||
self._output_init_normal(layer.mlp.dense_4h_to_h.weight)
|
||||
|
||||
# initialize head
|
||||
if self.last_stage:
|
||||
self._init_normal(self.head.lm_head.dense.weight)
|
||||
if self.head.binary_head is not None:
|
||||
self._init_normal(self.head.binary_head.pooler.dense.weight)
|
||||
self._init_normal(self.head.binary_head.dense.weight)
|
||||
|
||||
def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
|
||||
# inputs of the forward function
|
||||
# input_ids: [batch_size, sub_seq_len]
|
||||
# attention_mask: [batch_size, seq_len]
|
||||
# tokentype_ids: [batch_size, sub_seq_len]
|
||||
# outputs of preprocessor
|
||||
# pos_ids: [batch_size, sub_seq_len]
|
||||
# attention_masks: [batch_size, 1, sub_seq_len, seq_len]
|
||||
if self.first_stage:
|
||||
pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)
|
||||
else:
|
||||
_, attention_masks = self.preprocessor(None, attention_masks)
|
||||
|
||||
if self.first_stage:
|
||||
hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous()
|
||||
else:
|
||||
hidden_states = input_ids
|
||||
|
||||
# hidden_states shape change:
|
||||
# [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
|
||||
for idx, layer in enumerate(self.bert_layers):
|
||||
hidden_states = layer(hidden_states, attention_masks)
|
||||
|
||||
if self.last_stage:
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous()
|
||||
output = self.layer_norm(hidden_states)
|
||||
output = self.head(output, self.word_embeddings.weight, lm_labels)
|
||||
else:
|
||||
output = hidden_states
|
||||
|
||||
# hidden_states: [sub_seq_len, batch_size, hidden_size]
|
||||
# word_embedding: [vocab_size, hidden_size]
|
||||
return output
|
||||
|
||||
|
||||
def _filter_kwargs(func, kwargs):
|
||||
sig = inspect.signature(func)
|
||||
return {k: v for k, v in kwargs.items() if k in sig.parameters}
|
||||
|
||||
|
||||
def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
|
||||
logger = get_dist_logger()
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
rank = gpc.get_global_rank()
|
||||
wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
|
||||
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
|
||||
models = []
|
||||
for start, end in parts:
|
||||
kwargs['num_layers'] = num_layers
|
||||
kwargs['start_idx'] = start
|
||||
kwargs['end_idx'] = end
|
||||
kwargs['first_stage'] = start == 0
|
||||
kwargs['last_stage'] = end == num_layers
|
||||
logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
|
||||
chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device)
|
||||
if start == 0:
|
||||
wrapper.register_module(chunk.embedding.word_embeddings)
|
||||
elif end == num_layers:
|
||||
wrapper.register_module(chunk.word_embeddings)
|
||||
models.append(chunk)
|
||||
if len(models) == 1:
|
||||
model = models[0]
|
||||
else:
|
||||
model = nn.ModuleList(models)
|
||||
return model
|
||||
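# Usage sketch (illustrative; the hyper-parameters are assumptions, not the
# tutorial's config):
#
#   model = build_pipeline_bert(num_layers=24, num_chunks=1,
#                               vocab_size=30720, hidden_size=1024,
#                               max_sequence_length=512, num_attention_heads=16,
#                               add_binary_head=True, is_naive_fp16=False)
#
# Each pipeline rank only builds the BertLayers in its own partition, and the
# word-embedding modules of the first and last stages are registered with
# PipelineSharedModuleWrapper so that their weights stay synchronized.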
@@ -0,0 +1,4 @@
from .embedding import VocabEmbedding, Embedding
from .bert_layer import BertLayer
from .head import BertDualHead
from .preprocess import PreProcessor
118
examples/tutorial/sequence_parallel/model/layers/bert_layer.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing
|
||||
from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
|
||||
from colossalai.kernel.cuda_native import LayerNorm
|
||||
from .mlp import TransformerMLP
|
||||
from .dropout import get_bias_dropout_add
|
||||
|
||||
|
||||
def attention_mask_func(attention_scores, attention_mask):
|
||||
attention_scores.masked_fill_(attention_mask, -10000.0)
|
||||
return attention_scores
|
||||
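# Positions where attention_mask is True are filled with -10000.0, so after the
# softmax inside the attention kernel they contribute effectively zero weight.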
|
||||
|
||||
class BertLayer(nn.Module):
|
||||
"""A single transformer layer.
|
||||
Transformer layer takes input with size [b, s, h] and returns an
|
||||
output of the same size.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
layer_number,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout,
|
||||
mlp_ratio,
|
||||
hidden_dropout,
|
||||
is_naive_fp16,
|
||||
apply_residual_connection_post_layernorm=False,
|
||||
fp32_residual_connection=False,
|
||||
bias_dropout_fusion: bool = True,
|
||||
convert_fp16_to_fp32_in_softmax: bool = False):
|
||||
super().__init__()
|
||||
self.layer_number = layer_number
|
||||
|
||||
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
||||
self.fp32_residual_connection = fp32_residual_connection
|
||||
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = LayerNorm(hidden_size)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = TransformerSelfAttentionRing(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_dropout=attention_dropout,
|
||||
attention_mask_func=attention_mask_func,
|
||||
layer_number=layer_number,
|
||||
apply_query_key_layer_scaling=True,
|
||||
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
|
||||
fp16=is_naive_fp16
|
||||
)
|
||||
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.bias_dropout_fusion = bias_dropout_fusion
|
||||
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = LayerNorm(hidden_size)
|
||||
|
||||
self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
# hidden_states: [batch_size, sub_seq_len, hidden_size]
|
||||
# attention_mask: [batch_size, 1, sub_seq_len, seq_len]
|
||||
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self attention.
|
||||
attention_output, attention_bias = self.self_attention(layernorm_output, attention_mask)
|
||||
|
||||
# Residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
# jit scripting for a nn.module (with dropout) is not
|
||||
# trigerring the fusion kernel. For now, we use two
|
||||
# different nn.functional routines to account for varying
|
||||
# dropout semantics during training and inference phases.
|
||||
if self.bias_dropout_fusion:
|
||||
if self.training:
|
||||
bias_dropout_add_func = bias_dropout_add_fused_train
|
||||
else:
|
||||
bias_dropout_add_func = bias_dropout_add_fused_inference
|
||||
else:
|
||||
bias_dropout_add_func = get_bias_dropout_add(self.training)
|
||||
|
||||
# re-enable torch grad to enable fused optimization.
|
||||
with torch.enable_grad():
|
||||
layernorm_input = bias_dropout_add_func(
|
||||
attention_output,
|
||||
attention_bias.expand_as(residual),
|
||||
residual,
|
||||
self.hidden_dropout)
|
||||
|
||||
# Layer norm post the self attention.
|
||||
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
||||
|
||||
# MLP.
|
||||
mlp_output, mlp_bias = self.mlp(layernorm_output)
|
||||
|
||||
# Second residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = layernorm_input
|
||||
|
||||
# re-enable torch grad to enable fused optimization.
|
||||
with torch.enable_grad():
|
||||
output = bias_dropout_add_func(
|
||||
mlp_output,
|
||||
mlp_bias.expand_as(residual),
|
||||
residual,
|
||||
self.hidden_dropout)
|
||||
|
||||
return output
|
||||
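# Shape walk-through of one BertLayer forward pass under sequence parallelism
# (all sizes are illustrative assumptions): with batch_size b=4, a full
# sequence of s=512 tokens split over D=4 ranks, and hidden_size h=1024, each
# rank holds hidden_states of shape [4, 128, 1024] and an attention_mask of
# shape [4, 1, 128, 512]. TransformerSelfAttentionRing circulates the key and
# value sub-sequences between ranks via ring communication so every rank can
# attend over the full sequence, and the layer output keeps the local shape
# [4, 128, 1024].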
13
examples/tutorial/sequence_parallel/model/layers/dropout.py
Normal file
@@ -0,0 +1,13 @@
import torch


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add
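# Minimal usage sketch (tensor sizes are illustrative): this is the unfused
# fallback that BertLayer selects when bias_dropout_fusion is disabled.
#   fn = get_bias_dropout_add(training=True)
#   out = fn(attention_output,                    # [b, s/D, h]
#            attention_bias.expand_as(residual),  # bias broadcast to the same shape
#            residual,                            # [b, s/D, h]
#            0.1)                                 # hidden dropout probability
# i.e. out = residual + dropout(attention_output + bias).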
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class VocabEmbedding(torch.nn.Module):

    def __init__(self, num_embeddings, embedding_dim):
        super(VocabEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None

        # Allocate weights and initialize.
        self.weight = nn.Parameter(torch.empty(
            self.num_embeddings, self.embedding_dim))
        init.xavier_uniform_(self.weight)

    def forward(self, hidden_state):
        output = F.embedding(hidden_state, self.weight,
                             self.padding_idx, self.max_norm,
                             self.norm_type, self.scale_grad_by_freq,
                             self.sparse)
        return output

    def __repr__(self):
        return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \
               f'embedding_dim={self.embedding_dim})'


class Embedding(nn.Module):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 num_tokentypes):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.num_tokentypes = num_tokentypes

        self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size)

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrained model without
        # token types and add them as needed.
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    @property
    def word_embedding_weight(self):
        return self.word_embeddings.weight

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None and self.tokentype_embeddings is not None:
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings
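# Illustrative shapes (batch b=4, local sub-sequence length s/D=128, hidden
# h=1024; all values are assumptions): input_ids and position_ids are [4, 128]
# integer tensors, the word, position and optional token-type lookups each
# produce [4, 128, 1024], and their sum is returned with the same
# [4, 128, 1024] shape after embedding dropout.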
78
examples/tutorial/sequence_parallel/model/layers/head.py
Normal file
@@ -0,0 +1,78 @@
import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pooler import Pooler
from .linear import Linear
from .embedding import VocabEmbedding
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.kernel import LayerNorm
from loss_func.cross_entropy import vocab_cross_entropy


class BertLMHead(nn.Module):
    """Masked LM head for Bert

    Arguments:
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 ):

        super(BertLMHead, self).__init__()
        self.bias = torch.nn.Parameter(torch.zeros(vocab_size))

        self.dense = Linear(hidden_size, hidden_size)
        self.layernorm = LayerNorm(hidden_size)
        self.gelu = torch.nn.functional.gelu

    def forward(self, hidden_states, word_embeddings_weight, lm_labels):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)

        output = F.linear(hidden_states, word_embeddings_weight, self.bias)
        lm_loss = vocab_cross_entropy(output, lm_labels)

        return lm_loss


class BertBinaryHead(nn.Module):

    def __init__(self, hidden_size):
        super().__init__()
        self.pooler = Pooler(hidden_size)
        self.dense = Linear(hidden_size, 2)

    def forward(self, hidden_states):
        if gpc.get_local_rank(ParallelMode.SEQUENCE) == 0:
            output = self.pooler(hidden_states)
            output = self.dense(output)
        else:
            output = None
        return output


class BertDualHead(nn.Module):

    def __init__(self, hidden_size, vocab_size, add_binary_head):
        super().__init__()
        self.lm_head = BertLMHead(vocab_size, hidden_size)
        self.add_binary_head = add_binary_head
        if add_binary_head:
            self.binary_head = BertBinaryHead(hidden_size)
        else:
            self.binary_head = None

    def forward(self, hidden_states, word_embeddings_weight, lm_labels):
        if self.add_binary_head:
            binary_output = self.binary_head(hidden_states)
        else:
            binary_output = None
        lm_loss = self.lm_head(hidden_states, word_embeddings_weight, lm_labels)
        return lm_loss, binary_output
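# Behaviour sketch (shapes are illustrative assumptions): BertDualHead returns
# (lm_loss, binary_output). BertLMHead projects hidden states back onto the
# vocabulary with the tied word-embedding weight and computes the masked-LM
# loss, while BertBinaryHead pools the first token only on sequence-parallel
# rank 0, since the sequence dimension is split and rank 0 holds the [CLS]
# position; the sentence-order logits therefore come back as [b, 2] on rank 0
# and None on every other rank.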
@@ -0,0 +1,12 @@
import torch
import math


def init_normal(tensor, sigma):
    """Init method based on N(0, sigma)."""
    torch.nn.init.normal_(tensor, mean=0.0, std=sigma)


def output_init_normal(tensor, sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)
    torch.nn.init.normal_(tensor, mean=0.0, std=std)
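# Worked example (illustrative values): with sigma = 0.02 and num_layers = 24,
# output_init_normal draws from N(0, 0.02 / sqrt(48)) which is roughly
# N(0, 0.00289). This is the Megatron-style scaled initialization for
# output-facing projections, keeping residual-branch variance from growing
# with depth.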
63
examples/tutorial/sequence_parallel/model/layers/linear.py
Normal file
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn.init as init


class Linear(nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimizations where
                       the bias can be fused with other elementwise operations.
                       We skip adding the bias but instead return it.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 bias=True,
                 skip_bias_add=False):
        super(Linear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        self.weight = Parameter(torch.empty(self.output_size,
                                            self.input_size,
                                            ))
        init.normal_(self.weight)
        if bias:
            self.bias = Parameter(torch.empty(self.output_size))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output = F.linear(input_, self.weight, bias)

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output

    def __repr__(self):
        return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
               f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
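# Usage sketch of skip_bias_add (sizes are illustrative): when the bias is
# meant to be fused into a later kernel, the layer returns it instead of
# adding it.
#   layer = Linear(1024, 4096, skip_bias_add=True)
#   out, bias = layer(x)           # out = x @ W.T without bias; bias has shape [4096]
#   y = bias_gelu_impl(out, bias)  # bias add fused with GeLU, as in TransformerMLP
# With skip_bias_add=False (the default) the layer simply returns x @ W.T + b.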
50
examples/tutorial/sequence_parallel/model/layers/mlp.py
Normal file
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl


class TransformerMLP(nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """

    def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True):
        super(TransformerMLP, self).__init__()

        # Project to 4h.
        self.dense_h_to_4h = Linear(
            hidden_size,
            int(hidden_size*mlp_ratio),
            skip_bias_add=True)

        self.bias_gelu_fusion = fuse_gelu
        self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = Linear(
            int(hidden_size*mlp_ratio),
            hidden_size,
            skip_bias_add=True)

    def forward(self, hidden_states):
        # hidden states should be in the shape of [s, b, h]
        # it will be projected into [s, b, 4h]
        # and projected back to [s, b, h]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
28
examples/tutorial/sequence_parallel/model/layers/pooler.py
Normal file
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
from .linear import Linear


class Pooler(nn.Module):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
                     bias is set to zero.
    """

    def __init__(self, hidden_size):
        super(Pooler, self).__init__()
        self.dense = Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
@@ -0,0 +1,58 @@
from colossalai.context.parallel_mode import ParallelMode
import torch
import torch.nn as nn
from colossalai.core import global_context as gpc


class PreProcessor(nn.Module):

    def __init__(self, sub_seq_length):
        super().__init__()
        self.sub_seq_length = sub_seq_length

    def bert_position_ids(self, token_ids):
        # Create position ids
        seq_length = token_ids.size(1)
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        position_ids = torch.arange(seq_length*local_rank,
                                    seq_length * (local_rank+1),
                                    dtype=torch.long,
                                    device=token_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

        return position_ids

    def bert_extended_attention_mask(self, attention_mask):
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        start_index = local_rank * self.sub_seq_length
        end_index = (local_rank + 1) * self.sub_seq_length

        # We create a 3D attention mask from a 2D tensor mask.
        # [b, 1, s]
        attention_mask_b1s = attention_mask.unsqueeze(1)
        # [b, s, 1]
        attention_mask_bs1 = attention_mask.unsqueeze(2)
        # [b, s, s]
        attention_mask_bss = attention_mask_b1s * attention_mask_bs1

        # keep only the rows of this rank's sub-sequence: [b, s/D, s]
        attention_mask_bss = attention_mask_bss[:, start_index:end_index, :]

        # [b, 1, s/D, s]
        extended_attention_mask = attention_mask_bss.unsqueeze(1)

        # Convert attention mask to binary:
        extended_attention_mask = (extended_attention_mask < 0.5)

        return extended_attention_mask

    def forward(self, input_ids=None, attention_mask=None):
        if attention_mask is not None:
            extended_attention_mask = self.bert_extended_attention_mask(attention_mask)
        else:
            extended_attention_mask = None

        if input_ids is not None:
            position_ids = self.bert_position_ids(input_ids)
        else:
            position_ids = None
        return position_ids, extended_attention_mask
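# Worked example (all sizes are illustrative assumptions): with 4 sequence-
# parallel ranks, sub_seq_length=128 and a full sequence of 512 tokens, rank 1
# receives token_ids of shape [b, 128] and bert_position_ids returns positions
# 128..255 for every sample, i.e. the global offsets of its local sub-sequence.
# bert_extended_attention_mask slices rows 128..255 out of the [b, 512, 512]
# padding mask, giving a [b, 1, 128, 512] boolean mask in which True marks the
# positions that attention_mask_func later fills with -10000.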
210
examples/tutorial/sequence_parallel/train.py
Normal file
@@ -0,0 +1,210 @@
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from data import build_train_valid_test_data_iterators
from data.tokenizer import initialize_tokenizer, get_padded_vocab_size
from data.bert_helper import get_batch_for_sequence_parallel, SequenceParallelDataIterator
from colossalai.amp import AMP_TYPE
from colossalai.logging import get_dist_logger
from colossalai.utils import MultiTimer, is_using_pp
from model.bert import BertForPretrain
from lr_scheduler import AnnealingLR
from loss_func.bert_loss import BertLoss
import torch
from colossalai.engine.schedule import PipelineSchedule
from colossalai.nn.optimizer import FusedAdam
from colossalai.kernel import LayerNorm
from model.bert import build_pipeline_bert


def process_batch_data(batch_data):
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch_data
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        data = dict(input_ids=tokens, attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels)
    else:
        data = dict(attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels)
    label = dict(loss_mask=loss_mask, sentence_order=sentence_order)
    return data, label


def main():
    # initialize
    colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')

    logger = get_dist_logger()

    # build dataloader
    initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
    VOCAB_SIZE = get_padded_vocab_size()
    trainloader, validloader, testloader = build_train_valid_test_data_iterators(
        train_iters=gpc.config.TRAIN_ITERS,
        global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
        eval_interval=gpc.config.EVAL_INTERVAL,
        eval_iters=gpc.config.EVAL_ITERS,
        data_prefix=[gpc.config.DATA_PATH],
        data_impl='mmap',
        splits_string='949,50,1',
        max_seq_length=gpc.config.SEQ_LENGTH,
        masked_lm_prob=0.15,
        short_seq_prob=0.1,
        seed=1234,
        skip_warmup=True,
        binary_head=False,
    )

    logger.info("Dataloaders are built", ranks=[0])

    # build model
    if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE:
        is_naive_fp16 = True
    else:
        is_naive_fp16 = False

    use_pipeline = is_using_pp()
    kwargs = dict(vocab_size=VOCAB_SIZE,
                  hidden_size=gpc.config.HIDDEN_SIZE,
                  max_sequence_length=gpc.config.SEQ_LENGTH,
                  num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
                  convert_fp16_to_fp32_in_softmax=True,
                  is_naive_fp16=is_naive_fp16,
                  add_binary_head=gpc.config.ADD_BINARY_HEAD)

    if use_pipeline:
        model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs)
    else:
        model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs)

    model = model.half()
    model.reset_parameters()
    logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0])

    total_numel = 0
    for p in model.parameters():
        total_numel += p.numel()
    logger.info(f"This model has {total_numel} parameters")

    # build criterion
    criterion = BertLoss()
    logger.info("Criterion is built", ranks=[0])

    # layernorm and bias parameters have no weight decay
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in model.modules():
        if isinstance(module_, LayerNorm):
            no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None])
        else:
            weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias'])
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias'])

    logger.info(
        f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}"
    )
    # optimizer
    optimizer = FusedAdam((weight_decay_params, no_weight_decay_params),
                          lr=gpc.config.LR,
                          weight_decay=gpc.config.WEIGHT_DECAY)
    logger.info("Optimizer is built", ranks=[0])

    # lr scheduler
    # follow Megatron-LM setting
    warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION)
    lr_scheduler = AnnealingLR(optimizer=optimizer,
                               max_lr=gpc.config.LR,
                               min_lr=gpc.config.MIN_LR,
                               warmup_steps=warmup_steps,
                               decay_steps=gpc.config.DECAY_ITERS,
                               decay_style='linear')
    logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")

    # init
    engine, *dummy = colossalai.initialize(
        model,
        optimizer,
        criterion,
    )

    # build timer
    timer = MultiTimer()
    skip_iters = 0

    # build loss tracker
    accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda()
    accumulated_eval_loss = torch.zeros(1, dtype=torch.float32).cuda()

    # build data iters for pipeline parallel
    if use_pipeline:
        train_data_iter = SequenceParallelDataIterator(trainloader)
        valid_data_iter = SequenceParallelDataIterator(validloader)

    for step in range(1, gpc.config.TRAIN_ITERS + 1):
        timer.start('train-iterations')
        engine.train()
        if use_pipeline:
            engine.zero_grad()
            _, _, train_loss = engine.execute_schedule(train_data_iter, return_output_label=False)
            engine.step()
        else:
            tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
                trainloader)
            engine.zero_grad()
            lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
            train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
            engine.backward(train_loss)
            engine.step()
        timer.stop('train-iterations', keep_in_history=True)

        if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
            accumulated_train_loss += train_loss

        lr_scheduler.step()

        if step % gpc.config.EVAL_INTERVAL == 0:
            engine.eval()

            for j in range(gpc.config.EVAL_ITERS):
                with torch.no_grad():
                    if use_pipeline:
                        _, _, eval_loss = engine.execute_schedule(valid_data_iter,
                                                                  forward_only=True,
                                                                  return_output_label=False)
                    else:
                        tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
                            validloader)
                        lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
                        eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)

                    if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
                        accumulated_eval_loss += eval_loss

            if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
                accumulated_eval_loss /= gpc.config.EVAL_ITERS
                accumulated_train_loss /= gpc.config.EVAL_INTERVAL

                timer_string = []
                for n, t in timer:
                    timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}")
                timer_string = ' | '.join(timer_string)
                lr = list(engine.optimizer.param_groups)[0]['lr']
                loss_scale = engine.optimizer.optim.loss_scale.item()

                if gpc.is_initialized(ParallelMode.PIPELINE):
                    ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]]
                else:
                    ranks = [0]
                logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' +
                            f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' +
                            f"| Learning rate: {lr} | " + timer_string,
                            ranks=ranks)

            for n, t in timer:
                t.reset()
            accumulated_eval_loss.zero_()
            accumulated_train_loss.zero_()


if __name__ == '__main__':
    main()
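# Typical launch of this script (the GPU count below is an illustrative
# assumption): launch_from_torch expects the rank and world-size environment
# variables set by the torch.distributed launcher, e.g.
#   torchrun --nproc_per_node=8 train.py
# with the parallel settings (sequence and pipeline sizes), SEQ_LENGTH, DEPTH,
# DATA_PATH and VOCAB_FILE_PATH supplied by ./config.py.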