[example] reorganize for community examples (#3557)

binmakeswell
2023-04-14 16:27:48 +08:00
committed by GitHub
parent 1a809eddaa
commit f1b3d60cae
31 changed files with 785 additions and 844 deletions


@@ -0,0 +1,9 @@
CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = mask
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)
%$(LIBEXT): %.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@


@@ -0,0 +1,105 @@
# Data Preprocessing for Chinese Whole Word Masking

<span id='all_catalogue'/>

## Catalogue:

* <a href='#introduction'>1. Introduction</a>
* <a href='#Quick Start Guide'>2. Quick Start Guide</a>
  * <a href='#Split Sentence'>2.1. Split Sentences</a>
  * <a href='#Tokenizer & Whole Word Masked'>2.2. Tokenizer & Whole Word Masking</a>

<span id='introduction'/>

## 1. Introduction: <a href='#all_catalogue'>[Back to Top]</a>

This folder preprocesses a Chinese corpus for Whole Word Masking. You can obtain a corpus from [WuDao](https://resource.wudaoai.cn/home?ind&name=WuDaoCorpora%202.0&id=1394901288847716352). The preprocessing is flexible: you can adapt the code to your needs, hardware, or parallel framework (Open MPI, Spark, Dask).

<span id='Quick Start Guide'/>

## 2. Quick Start Guide: <a href='#all_catalogue'>[Back to Top]</a>

<span id='Split Sentence'/>

### 2.1. Split Sentences & Split Data into Multiple Shards:

First, each file holds multiple documents, and each document holds multiple sentences. Sentences are split on punctuation such as `。!`. **Second, split the data into multiple shards based on the server hardware (CPU, CPU memory, hard disk) and the corpus size.** Each shard contains part of the corpus, and the model has to train on all shards to finish one epoch.

In this example, a 200 GB corpus is split into 100 shards of about 2 GB each. The shard size is memory-dependent: take into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training to read the shards (n data-parallel processes require n\*shard_size memory). **To sum up, data preprocessing and model pretraining is a fight with the whole hardware stack, not just the GPU.**
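
For instance, with 2 GB shards and 8 data-parallel processes per server, reading the shards alone already takes roughly 8 × 2 GB = 16 GB of host memory, before counting the tokenizer and the preprocessing worker processes.
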
```shell
python sentence_split.py --input_path /original_corpus --output_path /shard --shard 100
# This step takes a short time.
```
* `--input_path`: directory of the original corpus, e.g., /original_corpus/0.json, /original_corpus/1.json ...
* `--output_path`: directory of the shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--shard`: number of shards, e.g., 10, 50, or 100

**Input JSON:**

```
[
  {
    "id": 0,
    "title": "打篮球",
    "content": "我今天去打篮球。不回来吃饭。"
  },
  {
    "id": 1,
    "title": "旅游",
    "content": "我后天去旅游。下周请假。"
  }
]
```

**Output txt:**

```
我今天去打篮球。
不回来吃饭。
]]
我后天去旅游。
下周请假。
```

<span id='Tokenizer & Whole Word Masked'/>

### 2.2. Tokenizer & Whole Word Masking:

```shell
python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python
# This step is time-consuming; most of the time is spent on masking.
```

**[Optional but recommended]**: the C++ backend built with `pybind11` provides faster masking. Build it with:
```shell
make
```
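
To check that the extension built correctly, you can try importing it from the directory that contains the generated `mask` shared object (a minimal sketch; the two function names are the ones bound in `PYBIND11_MODULE` in this commit):

```python
# Run from the directory that contains the mask.*.so produced by `make`.
import mask

# The module should expose the two functions bound via pybind11.
print(mask.get_new_segment)
print(mask.create_whole_masked_lm_predictions)
```
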
* `--input_path`: directory of the shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--output_path`: directory of the h5 files containing token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ...
* `--tokenizer_path`: tokenizer directory containing the Hugging Face tokenizer.json; download config.json, special_tokens_map.json, vocab.txt and tokenizer.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) (a loading sketch follows this list)
* `--backend`: python or c++; **the c++ backend gives faster preprocessing**
* `--dupe_factor`: how many times the preprocessor repeats creating inputs from the same article/document
* `--worker`: number of worker processes
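
For reference, `tokenize_mask.py` loads this directory with Hugging Face `transformers`. A minimal sketch of what `--tokenizer_path` is expected to contain (`/roberta` is just the example path used above):

```python
from transformers import AutoTokenizer

# /roberta should contain config.json, vocab.txt, tokenizer.json and
# special_tokens_map.json downloaded from hfl/chinese-roberta-wwm-ext-large.
tokenizer = AutoTokenizer.from_pretrained("/roberta")
print(tokenizer.tokenize("我今天去打篮球。"))
```
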

**Input txt:**

```
我今天去打篮球。
不回来吃饭。
]]
我后天去旅游。
下周请假。
```

**Output (h5 + numpy):**

```
'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..],
...]
'input_mask': [[1,1,1,1,1,1,0,0..],
...]
'segment_ids': [[0,0,0,0,0,...],
...]
'masked_lm_positions': [[label1,-1,-1,label2,-1...],
...]
```
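
To sanity-check a generated shard, you can open it with `h5py` (a minimal sketch, assuming the `/h5` output path used above):

```python
import h5py

# Each shard holds four integer arrays of shape [num_instances, seq_len].
with h5py.File("/h5/0.h5", "r") as hf:
    for name in ("input_ids", "input_mask", "segment_ids", "masked_lm_positions"):
        print(name, hf[name].shape, hf[name].dtype)
```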


@@ -0,0 +1,260 @@
import collections
import logging
import os
import random
import time
from enum import IntEnum
from random import choice
import jieba
import torch
jieba.setLogLevel(logging.CRITICAL)
import re
import mask
import numpy as np
PAD = 0
MaskedLMInstance = collections.namedtuple("MaskedLMInstance", ["index", "label"])
def map_to_numpy(data):
return np.asarray(data)
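# Builds BERT-style pretraining instances (input ids, attention mask, segment ids and
# masked-LM labels) with Chinese whole word masking; the masking step can run either in
# pure Python or in the pybind11 C++ backend (the `mask` extension).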
class PreTrainingDataset():
def __init__(self,
tokenizer,
max_seq_length,
backend='python',
max_predictions_per_seq: int = 80,
do_whole_word_mask: bool = True):
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
self.masked_lm_prob = 0.15
self.backend = backend
self.do_whole_word_mask = do_whole_word_mask
self.max_predictions_per_seq = max_predictions_per_seq
self.vocab_words = list(tokenizer.vocab.keys())
self.rec = re.compile('[\u4E00-\u9FA5]')
self.whole_rec = re.compile('##[\u4E00-\u9FA5]')
self.mlm_p = 0.15
self.mlm_mask_p = 0.8
self.mlm_tamper_p = 0.05
self.mlm_maintain_p = 0.1
def tokenize(self, doc):
temp = []
for d in doc:
temp.append(self.tokenizer.tokenize(d))
return temp
def create_training_instance(self, instance):
is_next = 1
raw_text_list = self.get_new_segment(instance)
tokens_a = raw_text_list
assert len(tokens_a) == len(instance)
# tokens_a, tokens_b, is_next = instance.get_values()
# print(f'is_next label:{is_next}')
# Create mapper
tokens = []
original_tokens = []
segment_ids = []
tokens.append("[CLS]")
original_tokens.append('[CLS]')
segment_ids.append(0)
for index, token in enumerate(tokens_a):
tokens.append(token)
original_tokens.append(instance[index])
segment_ids.append(0)
tokens.append("[SEP]")
original_tokens.append('[SEP]')
segment_ids.append(0)
# for token in tokens_b:
# tokens.append(token)
# segment_ids.append(1)
# tokens.append("[SEP]")
# segment_ids.append(1)
# Get Masked LM predictions
if self.backend == 'c++':
output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq,
self.masked_lm_prob)
elif self.backend == 'python':
output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)
# Convert to Ids
input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < self.max_seq_length:
input_ids.append(PAD)
segment_ids.append(PAD)
input_mask.append(PAD)
masked_lm_output.append(-1)
return ([
map_to_numpy(input_ids),
map_to_numpy(input_mask),
map_to_numpy(segment_ids),
map_to_numpy(masked_lm_output),
map_to_numpy([is_next])
])
def create_masked_lm_predictions(self, tokens):
cand_indexes = []
for i, token in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
            # Plain (per-token) masking: every WordPiece is its own candidate here;
            # whole-word grouping is handled in create_whole_masked_lm_predictions,
            # and the loop below expects integer indexes.
            cand_indexes.append(i)
random.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indexes:
continue
covered_indexes.add(index)
masked_token = None
# 80% mask
if random.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% Keep Original
if random.random() < 0.5:
masked_token = tokens[index]
# 10% replace w/ random word
else:
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLMInstance(index=index, label=tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
for p in masked_lms:
masked_lm_output[p.index] = self.tokenizer.vocab[p.label]
return (output_tokens, masked_lm_output)
def get_new_segment(self, segment):
"""
Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word.
:param segment: a sentence
"""
seq_cws = jieba.lcut(''.join(segment))
seq_cws_dict = {x: 1 for x in seq_cws}
new_segment = []
i = 0
while i < len(segment):
if len(self.rec.findall(segment[i])) == 0:
new_segment.append(segment[i])
i += 1
continue
has_add = False
for length in range(3, 0, -1):
if i + length > len(segment):
continue
if ''.join(segment[i:i + length]) in seq_cws_dict:
new_segment.append(segment[i])
for l in range(1, length):
new_segment.append('##' + segment[i + l])
i += length
has_add = True
break
if not has_add:
new_segment.append(segment[i])
i += 1
return new_segment
def create_whole_masked_lm_predictions(self, tokens):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
random.shuffle(cand_indexes)
        output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens]    # strip the "##" prefix
num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if random.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if random.random() < 0.5:
                        masked_token = tokens[index][2:] if len(self.whole_rec.findall(
                            tokens[index])) > 0 else tokens[index]    # strip the "##" prefix
# 10% of the time, replace with random word
else:
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(
MaskedLMInstance(
index=index,
label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
for p in masked_lms:
masked_lm_output[p.index] = self.tokenizer.vocab[p.label]
return (output_tokens, masked_lm_output)


@@ -0,0 +1,190 @@
#include <math.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
#include <chrono>
#include <iostream>
#include <limits>
#include <map>    // needed for std::map (the vocab argument)
#include <random>
#include <set>    // needed for std::set (covered_indexes)
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace py = pybind11;
const int32_t LONG_SENTENCE_LEN = 512;
struct MaskedLMInstance {
int index;
std::string label;
MaskedLMInstance(int index, std::string label) {
this->index = index;
this->label = label;
}
};
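// C++ counterpart of PreTrainingDataset.get_new_segment: prefixes the continuation
// characters of every jieba-segmented word with "##" so whole words can be masked together.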
auto get_new_segment(
std::vector<std::string> segment, std::vector<std::string> segment_jieba,
const std::vector<bool> chinese_vocab) { // const
// std::unordered_set<std::string>
// &chinese_vocab
std::unordered_set<std::string> seq_cws_dict;
for (auto word : segment_jieba) {
seq_cws_dict.insert(word);
}
int i = 0;
std::vector<std::string> new_segment;
int segment_size = segment.size();
while (i < segment_size) {
if (!chinese_vocab[i]) { // chinese_vocab.find(segment[i]) ==
// chinese_vocab.end()
new_segment.emplace_back(segment[i]);
i += 1;
continue;
}
bool has_add = false;
for (int length = 3; length >= 1; length--) {
if (i + length > segment_size) {
continue;
}
std::string chinese_word = "";
for (int j = i; j < i + length; j++) {
chinese_word += segment[j];
}
if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
new_segment.emplace_back(segment[i]);
for (int j = i + 1; j < i + length; j++) {
new_segment.emplace_back("##" + segment[j]);
}
i += length;
has_add = true;
break;
}
}
if (!has_add) {
new_segment.emplace_back(segment[i]);
i += 1;
}
}
return new_segment;
}
bool startsWith(const std::string &s, const std::string &sub) {
return s.find(sub) == 0 ? true : false;
}
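// C++ counterpart of PreTrainingDataset.create_whole_masked_lm_predictions, exposed to
// Python through pybind11 for faster whole-word masking.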
auto create_whole_masked_lm_predictions(
std::vector<std::string> &tokens,
const std::vector<std::string> &original_tokens,
const std::vector<std::string> &vocab_words,
std::map<std::string, int> &vocab, const int max_predictions_per_seq,
const double masked_lm_prob) {
// for (auto item : vocab) {
// std::cout << "key=" << std::string(py::str(item.first)) << ", "
// << "value=" << std::string(py::str(item.second)) <<
// std::endl;
// }
std::vector<std::vector<int> > cand_indexes;
std::vector<int> cand_temp;
int tokens_size = tokens.size();
std::string prefix = "##";
bool do_whole_masked = true;
for (int i = 0; i < tokens_size; i++) {
if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
continue;
}
if (do_whole_masked && (cand_indexes.size() > 0) &&
(tokens[i].rfind(prefix, 0) == 0)) {
cand_temp.emplace_back(i);
} else {
if (cand_temp.size() > 0) {
cand_indexes.emplace_back(cand_temp);
}
cand_temp.clear();
cand_temp.emplace_back(i);
}
  }
  // Flush the last accumulated word; otherwise the final whole word in the
  // sequence could never be selected for masking.
  if (cand_temp.size() > 0) {
    cand_indexes.emplace_back(cand_temp);
  }
auto seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(cand_indexes.begin(), cand_indexes.end(),
std::default_random_engine(seed));
// for (auto i : cand_indexes) {
// for (auto j : i) {
// std::cout << tokens[j] << " ";
// }
// std::cout << std::endl;
// }
// for (auto i : output_tokens) {
// std::cout << i;
// }
// std::cout << std::endl;
int num_to_predict = std::min(max_predictions_per_seq,
std::max(1, int(tokens_size * masked_lm_prob)));
// std::cout << num_to_predict << std::endl;
std::set<int> covered_indexes;
std::vector<int> masked_lm_output(tokens_size, -1);
int vocab_words_len = vocab_words.size();
std::default_random_engine e(seed);
std::uniform_real_distribution<double> u1(0.0, 1.0);
std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
int mask_cnt = 0;
std::vector<std::string> output_tokens;
output_tokens = original_tokens;
for (auto index_set : cand_indexes) {
    if (mask_cnt >= num_to_predict) {
break;
}
int index_set_size = index_set.size();
if (mask_cnt + index_set_size > num_to_predict) {
continue;
}
bool is_any_index_covered = false;
for (auto index : index_set) {
if (covered_indexes.find(index) != covered_indexes.end()) {
is_any_index_covered = true;
break;
}
}
if (is_any_index_covered) {
continue;
}
for (auto index : index_set) {
covered_indexes.insert(index);
std::string masked_token;
if (u1(e) < 0.8) {
masked_token = "[MASK]";
} else {
if (u1(e) < 0.5) {
masked_token = output_tokens[index];
} else {
int random_index = u2(e);
masked_token = vocab_words[random_index];
}
}
// masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
masked_lm_output[index] = vocab[output_tokens[index]];
output_tokens[index] = masked_token;
mask_cnt++;
}
}
// for (auto p : masked_lms) {
// masked_lm_output[p.index] = vocab[p.label];
// }
return std::make_tuple(output_tokens, masked_lm_output);
}
PYBIND11_MODULE(mask, m) {
m.def("create_whole_masked_lm_predictions",
&create_whole_masked_lm_predictions);
m.def("get_new_segment", &get_new_segment);
}


@@ -0,0 +1,152 @@
import argparse
import functools
import json
import multiprocessing
import os
import re
import time
from typing import List
from tqdm import tqdm
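# Split a document into sentences on Chinese/English end-of-sentence punctuation and
# hard-wrap any sentence longer than `limit` characters.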
def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
sent_list = []
try:
if flag == "zh":
document = re.sub('(?P<quotation_mark>([。?!…](?![”’"\'])))', r'\g<quotation_mark>\n', document)
document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document)
elif flag == "en":
document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n',
document) # Special quotation marks
else:
document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
document = re.sub('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))', r'\g<quotation_mark>\n',
document) # Special quotation marks
sent_list_ori = document.splitlines()
for sent in sent_list_ori:
sent = sent.strip()
if not sent:
continue
elif len(sent) <= 2:
continue
else:
while len(sent) > limit:
temp = sent[0:limit]
sent_list.append(temp)
sent = sent[limit:]
sent_list.append(sent)
except:
sent_list.clear()
sent_list.append(document)
return sent_list
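# Write the split sentences of every file in fin_list into one shard file (<host>.txt),
# terminating each document with a ']]' line.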
def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None:
workers = 32
if input_path[-1] == '/':
input_path = input_path[:-1]
cur_path = os.path.join(output_path, str(host) + '.txt')
new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)
with open(cur_path, 'w', encoding='utf-8') as f:
for fi, fin_path in enumerate(fin_list):
if not os.path.exists(os.path.join(input_path, fin_path[0])):
continue
if '.json' not in fin_path[0]:
continue
print("Processing ", fin_path[0], " ", fi)
with open(os.path.join(input_path, fin_path[0]), 'r') as fin:
f_data = [l['content'] for l in json.load(fin)]
pool = multiprocessing.Pool(workers)
all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
pool.close()
print('finished..')
cnt = 0
for d in tqdm(all_sent):
for i in d:
f.write(i.strip() + '\n')
f.write(']]' + '\n')
cnt += 1
# if cnt >= 2:
# exit()
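# Group the input files into roughly `shard` buckets of approximately equal total size
# (largest files first); each bucket becomes one shard.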
def getFileSize(filepath, shard):
all_data = []
for i in os.listdir(filepath):
        # store only the file name; it is joined with filepath again below
        all_data.append(i)
all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data])
ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data]
ans = sorted(ans, key=lambda x: x[1], reverse=True)
per_size = all_size / shard
real_shard = []
temp = []
accu_size = 0
for i in ans:
accu_size += i[1]
temp.append(i)
if accu_size > per_size:
real_shard.append(temp)
accu_size = 0
temp = []
if len(temp) > 0:
real_shard.append(temp)
return real_shard
def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
import socket
host = int(socket.gethostname().split(server_name)[-1])
fin_list = real_shard[server_num * base + host - 1]
print(fin_list)
print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')
return fin_list, host
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100')
parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus')
parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
args = parser.parse_args()
server_num = args.server_num
seq_len = args.seq_len
shard = args.shard
input_path = args.input_path
output_path = args.output_path
real_shard = getFileSize(input_path, shard)
start = time.time()
for index, shard in enumerate(real_shard):
get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len)
print(f'cost {str(time.time() - start)}')
    # if you have multiple servers, you can use the code below or adapt it to Open MPI
# for i in range(len(real_shard) // server_num + 1):
# fin_list, host = get_start_end(real_shard, i)
# start = time.time()
# get_sent(output_path,
# input_path,
# fin_list=fin_list, host= 10 * i + host - 1)
# print(f'cost {str(time.time() - start)}')


@@ -0,0 +1,267 @@
import argparse
import multiprocessing
import os
import socket
import time
from random import shuffle
import h5py
import numpy as np
import psutil
from get_mask import PreTrainingDataset
from tqdm import tqdm
from transformers import AutoTokenizer
def get_raw_instance(document, max_sequence_length=512):
"""
    Get the initial training instances: split the whole document into multiple chunks
    according to max_sequence_length and return them as processed instances.
    :param document: a tokenized document (a list of sentences)
    :param max_sequence_length: maximum sequence length including [CLS] and [SEP]
    :return: a list; each element is a sequence of text
"""
# document = self.documents[index]
max_sequence_length_allowed = max_sequence_length - 2
# document = [seq for seq in document if len(seq)<max_sequence_length_allowed]
sizes = [len(seq) for seq in document]
result_list = []
curr_seq = []
sz_idx = 0
while sz_idx < len(sizes):
if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0:
curr_seq += document[sz_idx]
sz_idx += 1
elif sizes[sz_idx] >= max_sequence_length_allowed:
if len(curr_seq) > 0:
result_list.append(curr_seq)
curr_seq = []
result_list.append(document[sz_idx][:max_sequence_length_allowed])
sz_idx += 1
else:
result_list.append(curr_seq)
curr_seq = []
if len(curr_seq) > max_sequence_length_allowed / 2: # /2
result_list.append(curr_seq)
# num_instance=int(len(big_list)/max_sequence_length_allowed)+1
# print("num_instance:",num_instance)
# result_list=[]
# for j in range(num_instance):
# index=j*max_sequence_length_allowed
# end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1
# result_list.append(big_list[index:end_index])
return result_list
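# Single-process variant: tokenize one shard file and dump the arrays to /output/<host>.h5.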
def split_numpy_chunk(path, tokenizer, pretrain_data, host):
documents = []
instances = []
s = time.time()
with open(path, encoding='utf-8') as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
# document = line
# if len(document.split("<sep>")) <= 3:
# continue
if len(line) > 0 and line[:2] == "]]": # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print('read_file ', time.time() - s)
# documents = [x for x in documents if x]
# print(len(documents))
# print(len(documents[0]))
# print(documents[0][0:10])
import multiprocessing
from typing import List
ans = []
for docs in tqdm(documents):
ans.append(pretrain_data.tokenize(docs))
print(time.time() - s)
del documents
instances = []
for a in tqdm(ans):
raw_ins = get_raw_instance(a)
instances.extend(raw_ins)
del ans
print('len instance', len(instances))
sen_num = len(instances)
seq_len = 512
input_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
input_mask = np.zeros([sen_num, seq_len], dtype=np.int32)
segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32)
for index, ins in tqdm(enumerate(instances)):
mask_dict = pretrain_data.create_training_instance(ins)
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]
with h5py.File(f'/output/{host}.h5', 'w') as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_ids)
hf.create_dataset("segment_ids", data=segment_ids)
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
del instances
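# Multiprocess variant: tokenize a shard with `worker` processes, build masked training
# instances, duplicate them `dupe_factor` times, shuffle and write <file_name>.h5.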
def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
print(f'{file_name}.h5 exists')
return
documents = []
instances = []
s = time.time()
with open(input_path, 'r', encoding='utf-8') as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
if len(line) > 0 and line[:2] == "]]": # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print(f'read_file cost {time.time() - s}, length is {len(documents)}')
ans = []
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100)
for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'):
ans.append(res)
pool.close()
print((time.time() - s) / 60)
del documents
instances = []
for a in tqdm(ans, colour='MAGENTA'):
raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
instances.extend(raw_ins)
del ans
print('len instance', len(instances))
new_instances = []
for _ in range(dupe_factor):
for ins in instances:
new_instances.append(ins)
shuffle(new_instances)
instances = new_instances
print('after dupe_factor, len instance', len(instances))
sentence_num = len(instances)
input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32)
segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32)
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32)
for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'):
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]
pool.close()
print((time.time() - s) / 60)
with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_mask)
hf.create_dataset("segment_ids", data=segment_ids)
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
del instances
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--tokenizer_path', type=str, required=True, help='path of the tokenizer')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
    parser.add_argument('--max_predictions_per_seq',
                        type=int,
                        default=80,
                        help='maximum number of masked LM predictions per sequence')
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
    parser.add_argument('--backend',
                        type=str,
                        default='python',
                        help='backend used for masking: python or c++')
parser.add_argument(
'--dupe_factor',
type=int,
default=1,
help='specifies how many times the preprocessor repeats to create the input from the same article/document')
    parser.add_argument('--worker', type=int, default=32, help='number of worker processes')
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
pretrain_data = PreTrainingDataset(tokenizer,
args.seq_len,
args.backend,
max_predictions_per_seq=args.max_predictions_per_seq)
data_len = len(os.listdir(args.input_path))
for i in range(data_len):
input_path = os.path.join(args.input_path, f'{i}.txt')
if os.path.exists(input_path):
start = time.time()
print(f'process {input_path}')
split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor,
args.seq_len, i)
end_ = time.time()
print(u'memory%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
print(f'has cost {(end_ - start) / 60}')
print('-' * 100)
print('')
    # if you have multiple servers, you can use the code below or adapt it to Open MPI
# host = int(socket.gethostname().split('GPU')[-1])
# for i in range(data_len // args.server_num + 1):
# h = args.server_num * i + host - 1
# input_path = os.path.join(args.input_path, f'{h}.txt')
# if os.path.exists(input_path):
# start = time.time()
# print(f'I am server {host}, process {input_path}')
# split_numpy_chunk_pool(input_path,
# args.output_path,
# pretrain_data,
# args.worker,
# args.dupe_factor,
# args.seq_len,
# h)
# end_ = time.time()
# print(u'memory%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
# print(f'has cost {(end_ - start) / 60}')
# print('-' * 100)
# print('')