Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-04 00:25:54 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
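Most of the diff below is mechanical reformatting from the updated hooks (the style is black's): string literals are normalized to double quotes, non-trivial slice bounds get PEP 8 spacing, and long signatures and calls are exploded to one argument per line. A minimal before/after illustration of those rules (not code from the repo):

x = {'a': 1}           # before: single-quoted strings
y = items[:n - 2]      # before: no space around the slice colon
x = {"a": 1}           # after: double quotes
y = items[: n - 2]     # after: black spaces out a complex slice bound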
@@ -2,12 +2,11 @@ import itertools
 import random
 
 import numpy as np
-from torch.utils.data import Dataset
 
-from megatron import get_tokenizer
-from megatron import get_args
+from megatron import get_args, get_tokenizer
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping
+from torch.utils.data import Dataset
 
 
 def make_attention_mask(source_block, target_block):
     """
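The unchanged context opening the next hunk is the tail of make_attention_mask. For reference, the upstream Megatron helper builds the 2-D mask as an outer product of the two blocks' non-padding indicators; a sketch reconstructed along those lines (not part of this diff):

import numpy as np

def make_attention_mask(source_block, target_block):
    # Positions holding real tokens (id >= 1) attend to each other;
    # padding (id 0) is masked out along both axes.
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    # (source_length, target_length)
    return mask.astype(np.int64)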
@@ -20,16 +19,17 @@ def make_attention_mask(source_block, target_block):
     # (source_length, target_length)
     return mask
 
 
 def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
     rather than for training, since it is only built with a single epoch sample mapping.
     """
     args = get_args()
-    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
-    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
+    block_dataset = get_indexed_dataset_(args.data_path, "mmap", True)
+    titles_dataset = get_indexed_dataset_(args.titles_data_path, "mmap", True)
 
     kwargs = dict(
-        name='full',
+        name="full",
         block_dataset=block_dataset,
         title_dataset=titles_dataset,
         data_prefix=args.data_path,
@@ -39,7 +39,7 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
         seed=1,
         query_in_block_prob=query_in_block_prob,
         use_titles=use_titles,
-        use_one_sent_docs=args.use_one_sent_docs
+        use_one_sent_docs=args.use_one_sent_docs,
     )
     dataset = ICTDataset(**kwargs)
     return dataset
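Hypothetical usage of get_ict_dataset (assumes initialize_megatron has populated args, including data_path and titles_data_path, and that the indexed datasets exist on disk):

dataset = get_ict_dataset(use_titles=True, query_in_block_prob=1)
sample = dataset[0]  # dict of query/context token arrays and masks; see the __getitem__ hunk below
print(sample["query_tokens"].shape)  # (max_seq_length,)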
@@ -47,9 +47,22 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
 
 class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
-    def __init__(self, name, block_dataset, title_dataset, data_prefix,
-                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
-                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
+
+    def __init__(
+        self,
+        name,
+        block_dataset,
+        title_dataset,
+        data_prefix,
+        num_epochs,
+        max_num_samples,
+        max_seq_length,
+        query_in_block_prob,
+        seed,
+        use_titles=True,
+        use_one_sent_docs=False,
+        binary_head=False,
+    ):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
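The one-parameter-per-line explosion above is standard black behavior: black does not keep hand-wrapped continuation lines, so a signature either fits within the line-length limit on a single line or is exploded with a trailing comma (the "magic trailing comma", which keeps it exploded on later runs). A minimal illustration (not code from the repo):

def fits(a, b): ...  # short enough: black collapses any manual wrapping

def too_long(
    first_parameter,
    second_parameter,
): ...  # over the limit (or trailing comma present): one per line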
@@ -61,8 +74,16 @@ class ICTDataset(Dataset):
         self.use_one_sent_docs = use_one_sent_docs
 
         self.samples_mapping = get_block_samples_mapping(
-            block_dataset, title_dataset, data_prefix, num_epochs,
-            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
+            block_dataset,
+            title_dataset,
+            data_prefix,
+            num_epochs,
+            max_num_samples,
+            max_seq_length,
+            seed,
+            name,
+            use_one_sent_docs,
+        )
         self.tokenizer = get_tokenizer()
         self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
         self.vocab_id_to_token_list = self.tokenizer.inv_vocab
@@ -99,8 +120,8 @@ class ICTDataset(Dataset):
 
         # still need to truncate because blocks are concluded when
         # the sentence lengths have exceeded max_seq_length.
-        query = query[:self.max_seq_length - 2]
-        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
+        query = query[: self.max_seq_length - 2]
+        block = list(itertools.chain(*block))[: self.max_seq_length - title_pad_offset]
 
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
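The truncation offsets in this hunk reserve room for special tokens: a bare query keeps max_seq_length - 2 ids ([CLS] ... [SEP]), while a titled block keeps max_seq_length - (3 + len(title)) ids ([CLS] title [SEP] block [SEP]). A standalone sketch of concat_and_pad_tokens under those assumptions (the default special-token ids and length are hypothetical placeholders):

import numpy as np

def concat_and_pad_tokens(tokens, title=None, max_seq_length=288, cls_id=101, sep_id=102, pad_id=0):
    # Wrap with special tokens, then right-pad to a fixed length.
    if title is None:
        tokens = [cls_id] + list(tokens) + [sep_id]
    else:
        tokens = [cls_id] + list(title) + [sep_id] + list(tokens) + [sep_id]
    assert len(tokens) <= max_seq_length
    num_pad = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad  # 1 = real token, 0 = padding
    return np.array(tokens + [pad_id] * num_pad), np.array(pad_mask)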
@@ -111,13 +132,13 @@ class ICTDataset(Dataset):
         block_data = sample_data.as_array()
 
         sample = {
-            'query_tokens': query_tokens,
-            'query_mask': query_mask,
-            'query_pad_mask': query_pad_mask,
-            'context_tokens': context_tokens,
-            'context_mask': context_mask,
-            'context_pad_mask': context_pad_mask,
-            'block_data': block_data,
+            "query_tokens": query_tokens,
+            "query_mask": query_mask,
+            "query_pad_mask": query_pad_mask,
+            "context_tokens": context_tokens,
+            "context_mask": context_mask,
+            "context_pad_mask": context_pad_mask,
+            "block_data": block_data,
         }
 
         return sample
@@ -127,7 +148,7 @@ class ICTDataset(Dataset):
         block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
         title = self.title_dataset[int(doc_idx)]
 
-        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
+        block = list(itertools.chain(*block))[: self.max_seq_length - (3 + len(title))]
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
 
         return block_tokens, block_pad_mask
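A hypothetical end-to-end sketch (again assuming an initialized Megatron environment): the per-sample numpy arrays returned by __getitem__ are stacked per key into tensors by PyTorch's default collate function.

from torch.utils.data import DataLoader

loader = DataLoader(get_ict_dataset(), batch_size=32)
batch = next(iter(loader))
print(batch["query_tokens"].shape)  # (32, max_seq_length)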