Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 05:49:55 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,13 +1,8 @@
import collections
import logging
import os
import random
import time
from enum import IntEnum
from random import choice

import jieba
import torch

jieba.setLogLevel(logging.CRITICAL)
import re
@@ -23,14 +18,15 @@ def map_to_numpy(data):
return np.asarray(data)


class PreTrainingDataset():

def __init__(self,
tokenizer,
max_seq_length,
backend='python',
max_predictions_per_seq: int = 80,
do_whole_word_mask: bool = True):
class PreTrainingDataset:
def __init__(
self,
tokenizer,
max_seq_length,
backend="python",
max_predictions_per_seq: int = 80,
do_whole_word_mask: bool = True,
):
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
self.masked_lm_prob = 0.15
@@ -38,8 +34,8 @@ class PreTrainingDataset():
self.do_whole_word_mask = do_whole_word_mask
self.max_predictions_per_seq = max_predictions_per_seq
self.vocab_words = list(tokenizer.vocab.keys())
self.rec = re.compile('[\u4E00-\u9FA5]')
self.whole_rec = re.compile('##[\u4E00-\u9FA5]')
self.rec = re.compile("[\u4E00-\u9FA5]")
self.whole_rec = re.compile("##[\u4E00-\u9FA5]")

self.mlm_p = 0.15
self.mlm_mask_p = 0.8
@@ -64,7 +60,7 @@ class PreTrainingDataset():
original_tokens = []
segment_ids = []
tokens.append("[CLS]")
original_tokens.append('[CLS]')
original_tokens.append("[CLS]")
segment_ids.append(0)
for index, token in enumerate(tokens_a):
tokens.append(token)
@@ -72,7 +68,7 @@ class PreTrainingDataset():
segment_ids.append(0)

tokens.append("[SEP]")
original_tokens.append('[SEP]')
original_tokens.append("[SEP]")
segment_ids.append(0)

# for token in tokens_b:
@@ -83,11 +79,16 @@ class PreTrainingDataset():
# segment_ids.append(1)

# Get Masked LM predictions
if self.backend == 'c++':
if self.backend == "c++":
output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq,
self.masked_lm_prob)
elif self.backend == 'python':
tokens,
original_tokens,
self.vocab_words,
self.tokenizer.vocab,
self.max_predictions_per_seq,
self.masked_lm_prob,
)
elif self.backend == "python":
output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)

# Convert to Ids
@@ -99,20 +100,20 @@ class PreTrainingDataset():
segment_ids.append(PAD)
input_mask.append(PAD)
masked_lm_output.append(-1)
return ([
return [
map_to_numpy(input_ids),
map_to_numpy(input_mask),
map_to_numpy(segment_ids),
map_to_numpy(masked_lm_output),
map_to_numpy([is_next])
])
map_to_numpy([is_next]),
]

def create_masked_lm_predictions(self, tokens):
cand_indexes = []
for i, token in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
if self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
@@ -160,7 +161,7 @@ class PreTrainingDataset():
Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word.
:param segment: a sentence
"""
seq_cws = jieba.lcut(''.join(segment))
seq_cws = jieba.lcut("".join(segment))
seq_cws_dict = {x: 1 for x in seq_cws}
new_segment = []
i = 0
@@ -174,10 +175,10 @@ class PreTrainingDataset():
for length in range(3, 0, -1):
if i + length > len(segment):
continue
if ''.join(segment[i:i + length]) in seq_cws_dict:
if "".join(segment[i : i + length]) in seq_cws_dict:
new_segment.append(segment[i])
for l in range(1, length):
new_segment.append('##' + segment[i + l])
new_segment.append("##" + segment[i + l])
i += length
has_add = True
break
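The helper being reformatted above prepares Chinese text for whole word masking: jieba segments the character sequence into words, and every non-leading character of a multi-character word is prefixed with "##" so that later steps know which characters belong to the same word. A minimal illustrative sketch of that idea (not the repository's exact code; the function name is hypothetical):

import jieba

def mark_whole_words(chars):
    """chars: a list of single Chinese characters (BERT-style char tokens)."""
    marked = []
    for word in jieba.lcut("".join(chars)):
        marked.append(word[0])                       # first character keeps its form
        marked.extend("##" + ch for ch in word[1:])  # the rest are tagged as word continuations
    return marked

# mark_whole_words(list("我爱北京天安门"))
# -> ['我', '爱', '北', '##京', '天', '##安', '##门']  (exact split depends on jieba's dictionary)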
@@ -190,7 +191,7 @@ class PreTrainingDataset():
"""Creates the predictions for the masked LM objective."""

cand_indexes = []
for (i, token) in enumerate(tokens):
for i, token in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
# Whole Word Masking means that if we mask all of the wordpieces
@@ -202,14 +203,14 @@ class PreTrainingDataset():
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
if self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])

random.shuffle(cand_indexes)

output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens]  # 去掉"##" (strip the "##")
output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens]  # 去掉"##" (strip the "##")

num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))
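The comments above restate the standard whole word masking rule: WordPieces that continue a word (tokens starting with "##") are grouped with the preceding piece, and masking later selects whole groups rather than individual pieces. A small self-contained sketch of that grouping, mirroring the loop shown in the diff (illustrative only):

def group_whole_words(tokens):
    cand_indexes = []
    for i, token in enumerate(tokens):
        if token in ("[CLS]", "[SEP]"):
            continue
        if cand_indexes and token.startswith("##"):
            cand_indexes[-1].append(i)   # continuation piece joins the previous word
        else:
            cand_indexes.append([i])     # start of a new word
    return cand_indexes

# group_whole_words(["[CLS]", "北", "##京", "欢", "##迎", "你", "[SEP]"])
# -> [[1, 2], [3, 4], [5]]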
@@ -239,8 +240,9 @@ class PreTrainingDataset():
else:
# 10% of the time, keep original
if random.random() < 0.5:
masked_token = tokens[index][2:] if len(self.whole_rec.findall(
tokens[index])) > 0 else tokens[index]  # 去掉"##" (strip the "##")
masked_token = (
tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index]
)  # 去掉"##" (strip the "##")
# 10% of the time, replace with random word
else:
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
@@ -250,7 +252,9 @@ class PreTrainingDataset():
masked_lms.append(
MaskedLMInstance(
index=index,
label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index]))
label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index],
)
)
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
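The masking branch reformatted here follows BERT's usual 80/10/10 rule: a candidate position becomes "[MASK]" 80% of the time, keeps its original token 10% of the time, and is replaced by a random vocabulary word the remaining 10%. A minimal sketch of that decision with illustrative names (not the repository's API):

import random

def choose_masked_token(original_token, vocab_words):
    if random.random() < 0.8:              # 80%: replace with [MASK]
        return "[MASK]"
    if random.random() < 0.5:              # half of the rest (10%): keep the original token
        return original_token
    return random.choice(vocab_words)      # remaining 10%: random vocabulary word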
@@ -14,17 +14,19 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
sent_list = []
try:
if flag == "zh":
document = re.sub('(?P<quotation_mark>([。?!…](?![”’"\'])))', r'\g<quotation_mark>\n', document)
document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document)
document = re.sub("(?P<quotation_mark>([。?!…](?![”’\"'])))", r"\g<quotation_mark>\n", document)
document = re.sub("(?P<quotation_mark>([。?!]|…{1,2})[”’\"'])", r"\g<quotation_mark>\n", document)
elif flag == "en":
document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n',
document)  # Special quotation marks
document = re.sub("(?P<quotation_mark>([.?!](?![”’\"'])))", r"\g<quotation_mark>\n", document)
document = re.sub(
"(?P<quotation_mark>([?!.][\"']))", r"\g<quotation_mark>\n", document
)  # Special quotation marks
else:
document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
document = re.sub("(?P<quotation_mark>([。?!….?!](?![”’\"'])))", r"\g<quotation_mark>\n", document)

document = re.sub('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))', r'\g<quotation_mark>\n',
document)  # Special quotation marks
document = re.sub(
"(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’\"']))", r"\g<quotation_mark>\n", document
)  # Special quotation marks

sent_list_ori = document.splitlines()
for sent in sent_list_ori:
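These substitutions implement the sentence splitter's core trick: insert a newline after sentence-ending punctuation (optionally followed by a closing quote), then split the document on line breaks. A toy illustration of the zh branch, using a simplified pattern (the real regexes above also handle ellipses and quotation marks):

import re

text = "今天天气不错。我们去公园吧!周末见?"
text = re.sub(r"(?P<end>[。?!])", r"\g<end>\n", text)   # newline after each sentence-ending mark
print(text.splitlines())  # ['今天天气不错。', '我们去公园吧!', '周末见?']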
@@ -46,36 +48,35 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s


def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None:

workers = 32

if input_path[-1] == '/':
if input_path[-1] == "/":
input_path = input_path[:-1]

cur_path = os.path.join(output_path, str(host) + '.txt')
cur_path = os.path.join(output_path, str(host) + ".txt")
new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)
with open(cur_path, 'w', encoding='utf-8') as f:
with open(cur_path, "w", encoding="utf-8") as f:
for fi, fin_path in enumerate(fin_list):
if not os.path.exists(os.path.join(input_path, fin_path[0])):
continue
if '.json' not in fin_path[0]:
if ".json" not in fin_path[0]:
continue

print("Processing ", fin_path[0], " ", fi)

with open(os.path.join(input_path, fin_path[0]), 'r') as fin:
f_data = [l['content'] for l in json.load(fin)]
with open(os.path.join(input_path, fin_path[0]), "r") as fin:
f_data = [l["content"] for l in json.load(fin)]

pool = multiprocessing.Pool(workers)
all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
pool.close()
print('finished..')
print("finished..")

cnt = 0
for d in tqdm(all_sent):
for i in d:
f.write(i.strip() + '\n')
f.write(']]' + '\n')
f.write(i.strip() + "\n")
f.write("]]" + "\n")
cnt += 1
# if cnt >= 2:
#    exit()
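get_sent fans the splitting work out to a process pool: the sentence splitter is bound to a fixed limit with functools.partial, and the documents are streamed through Pool.imap_unordered with a chunksize of 32. A reduced sketch of that pattern (stub splitter, file handling omitted; names are illustrative):

import functools
import multiprocessing

def split_stub(document, limit=510):
    return [document[:limit]]  # stand-in for the real split_sentence

if __name__ == "__main__":
    docs = ["first document.", "second document."]
    work = functools.partial(split_stub, limit=510)   # bind the fixed argument
    with multiprocessing.Pool(4) as pool:
        for sentences in pool.imap_unordered(work, docs, 32):
            print(sentences)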
@@ -86,7 +87,7 @@ def getFileSize(filepath, shard):
for i in os.listdir(filepath):
all_data.append(os.path.join(filepath, i))
all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data])
ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data]
ans = [[f.split("/")[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data]
ans = sorted(ans, key=lambda x: x[1], reverse=True)
per_size = all_size / shard
real_shard = []
@@ -106,24 +107,24 @@ def getFileSize(filepath, shard):
return real_shard


def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
def get_start_end(real_shard, base=0, server_num=10, server_name="GPU"):
import socket

host = int(socket.gethostname().split(server_name)[-1])

fin_list = real_shard[server_num * base + host - 1]
print(fin_list)
print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')
print(f"I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}")
return fin_list, host


if __name__ == '__main__':

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100')
parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus')
parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
parser.add_argument("--server_num", type=int, default=10, help="number of servers")
parser.add_argument("--seq_len", type=int, default=512, help="sequence length")
parser.add_argument("--shard", type=int, default=100, help="number of shards, e.g., 10, 50, or 100")
parser.add_argument("--input_path", type=str, required=True, help="input path of original corpus")
parser.add_argument("--output_path", type=str, required=True, help="output path of shard which has split sentence")
args = parser.parse_args()

server_num = args.server_num
@@ -137,7 +138,7 @@ if __name__ == '__main__':
start = time.time()
for index, shard in enumerate(real_shard):
get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len)
print(f'cost {str(time.time() - start)}')
print(f"cost {str(time.time() - start)}")

# if you have multiple server, you can use code below or modify code to openmpi
@@ -1,7 +1,6 @@
import argparse
import multiprocessing
import os
import socket
import time
from random import shuffle
@@ -29,8 +28,7 @@ def get_raw_instance(document, max_sequence_length=512):
curr_seq = []
sz_idx = 0
while sz_idx < len(sizes):

if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed:  # or len(curr_seq)==0:
if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed:  # or len(curr_seq)==0:
curr_seq += document[sz_idx]
sz_idx += 1
elif sizes[sz_idx] >= max_sequence_length_allowed:
@@ -43,7 +41,7 @@ def get_raw_instance(document, max_sequence_length=512):
result_list.append(curr_seq)
curr_seq = []

if len(curr_seq) > max_sequence_length_allowed / 2:  # /2
if len(curr_seq) > max_sequence_length_allowed / 2:  # /2
result_list.append(curr_seq)

# num_instance=int(len(big_list)/max_sequence_length_allowed)+1
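get_raw_instance packs a document's sentences greedily: sentences are appended to the current sequence while they still fit under the length budget, a full sequence is flushed to the result list, and a short tail is kept only if it exceeds half the budget. A simplified sketch of the packing idea (illustrative; the real function also handles over-long sentences and the half-budget tail rule):

def pack_sentences(sentences, budget):
    packed, current = [], []
    for sent in sentences:
        if len(current) + len(sent) <= budget:
            current += sent              # still fits: extend the current sequence
        else:
            packed.append(current)       # flush the full sequence
            current = list(sent)
    if current:
        packed.append(current)
    return packed

# pack_sentences([list("abcd"), list("ef"), list("ghijk")], budget=6)
# -> [['a', 'b', 'c', 'd', 'e', 'f'], ['g', 'h', 'i', 'j', 'k']]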
@@ -58,33 +56,30 @@ def get_raw_instance(document, max_sequence_length=512):


def split_numpy_chunk(path, tokenizer, pretrain_data, host):

documents = []
instances = []

s = time.time()
with open(path, encoding='utf-8') as fd:
with open(path, encoding="utf-8") as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
# document = line
# if len(document.split("<sep>")) <= 3:
#    continue
if len(line) > 0 and line[:2] == "]]":  # This is end of document
if len(line) > 0 and line[:2] == "]]":  # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print('read_file ', time.time() - s)
print("read_file ", time.time() - s)

# documents = [x for x in documents if x]
# print(len(documents))
# print(len(documents[0]))
# print(documents[0][0:10])
import multiprocessing
from typing import List

ans = []
for docs in tqdm(documents):
@@ -98,7 +93,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
instances.extend(raw_ins)
del ans

print('len instance', len(instances))
print("len instance", len(instances))

sen_num = len(instances)
seq_len = 512
@@ -114,7 +109,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]

with h5py.File(f'/output/{host}.h5', 'w') as hf:
with h5py.File(f"/output/{host}.h5", "w") as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_ids)
hf.create_dataset("segment_ids", data=segment_ids)
@@ -124,45 +119,44 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):


def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):

if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
print(f'{file_name}.h5 exists')
if os.path.exists(os.path.join(output_path, f"{file_name}.h5")):
print(f"{file_name}.h5 exists")
return

documents = []
instances = []

s = time.time()
with open(input_path, 'r', encoding='utf-8') as fd:
with open(input_path, "r", encoding="utf-8") as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
if len(line) > 0 and line[:2] == "]]":  # This is end of document
if len(line) > 0 and line[:2] == "]]":  # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print(f'read_file cost {time.time() - s}, length is {len(documents)}')
print(f"read_file cost {time.time() - s}, length is {len(documents)}")

ans = []
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100)
for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'):
for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour="cyan"):
ans.append(res)
pool.close()
print((time.time() - s) / 60)
del documents

instances = []
for a in tqdm(ans, colour='MAGENTA'):
for a in tqdm(ans, colour="MAGENTA"):
raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
instances.extend(raw_ins)
del ans

print('len instance', len(instances))
print("len instance", len(instances))

new_instances = []
for _ in range(dupe_factor):
@@ -171,7 +165,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_

shuffle(new_instances)
instances = new_instances
print('after dupe_factor, len instance', len(instances))
print("after dupe_factor, len instance", len(instances))

sentence_num = len(instances)
input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
@@ -182,7 +176,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32)
for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'):
for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour="blue"):
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
@@ -190,7 +184,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_
pool.close()
print((time.time() - s) / 60)

with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
with h5py.File(os.path.join(output_path, f"{file_name}.h5"), "w") as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_mask)
hf.create_dataset("segment_ids", data=segment_ids)
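Each processed shard ends up as an HDF5 file with one dataset per tensor. An assumed usage sketch for reading such a shard back (dataset names are taken from the create_dataset calls shown above; the file name is only an example):

import h5py

with h5py.File("0.h5", "r") as hf:
    input_ids = hf["input_ids"][:]
    input_mask = hf["input_mask"][:]
    segment_ids = hf["segment_ids"][:]
    print(input_ids.shape)  # (number_of_instances, seq_len)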
@@ -199,50 +193,48 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_
del instances


if __name__ == '__main__':

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--max_predictions_per_seq',
type=int,
default=80,
help='number of shards, e.g., 10, 50, or 100')
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
parser.add_argument('--backend',
type=str,
default='python',
help='backend of mask token, python, c++, numpy respectively')
parser.add_argument("--tokenizer_path", type=str, required=True, default=10, help="path of tokenizer")
parser.add_argument("--seq_len", type=int, default=512, help="sequence length")
parser.add_argument(
'--dupe_factor',
"--max_predictions_per_seq", type=int, default=80, help="number of shards, e.g., 10, 50, or 100"
)
parser.add_argument("--input_path", type=str, required=True, help="input path of shard which has split sentence")
parser.add_argument("--output_path", type=str, required=True, help="output path of h5 contains token id")
parser.add_argument(
"--backend", type=str, default="python", help="backend of mask token, python, c++, numpy respectively"
)
parser.add_argument(
"--dupe_factor",
type=int,
default=1,
help='specifies how many times the preprocessor repeats to create the input from the same article/document')
parser.add_argument('--worker', type=int, default=32, help='number of process')
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
help="specifies how many times the preprocessor repeats to create the input from the same article/document",
)
parser.add_argument("--worker", type=int, default=32, help="number of process")
parser.add_argument("--server_num", type=int, default=10, help="number of servers")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
pretrain_data = PreTrainingDataset(tokenizer,
args.seq_len,
args.backend,
max_predictions_per_seq=args.max_predictions_per_seq)
pretrain_data = PreTrainingDataset(
tokenizer, args.seq_len, args.backend, max_predictions_per_seq=args.max_predictions_per_seq
)

data_len = len(os.listdir(args.input_path))

for i in range(data_len):
input_path = os.path.join(args.input_path, f'{i}.txt')
input_path = os.path.join(args.input_path, f"{i}.txt")
if os.path.exists(input_path):
start = time.time()
print(f'process {input_path}')
split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor,
args.seq_len, i)
print(f"process {input_path}")
split_numpy_chunk_pool(
input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, args.seq_len, i
)
end_ = time.time()
print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
print(f'has cost {(end_ - start) / 60}')
print('-' * 100)
print('')
print("memory:%.4f GB" % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
print(f"has cost {(end_ - start) / 60}")
print("-" * 100)
print("")

# if you have multiple server, you can use code below or modify code to openmpi