mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 13:59:08 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -14,17 +14,19 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
|
||||
sent_list = []
|
||||
try:
|
||||
if flag == "zh":
|
||||
document = re.sub('(?P<quotation_mark>([。?!…](?![”’"\'])))', r'\g<quotation_mark>\n', document)
|
||||
document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document)
|
||||
document = re.sub("(?P<quotation_mark>([。?!…](?![”’\"'])))", r"\g<quotation_mark>\n", document)
|
||||
document = re.sub("(?P<quotation_mark>([。?!]|…{1,2})[”’\"'])", r"\g<quotation_mark>\n", document)
|
||||
elif flag == "en":
|
||||
document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
|
||||
document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n',
|
||||
document) # Special quotation marks
|
||||
document = re.sub("(?P<quotation_mark>([.?!](?![”’\"'])))", r"\g<quotation_mark>\n", document)
|
||||
document = re.sub(
|
||||
"(?P<quotation_mark>([?!.][\"']))", r"\g<quotation_mark>\n", document
|
||||
) # Special quotation marks
|
||||
else:
|
||||
document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
|
||||
document = re.sub("(?P<quotation_mark>([。?!….?!](?![”’\"'])))", r"\g<quotation_mark>\n", document)
|
||||
|
||||
document = re.sub('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))', r'\g<quotation_mark>\n',
|
||||
document) # Special quotation marks
|
||||
document = re.sub(
|
||||
"(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’\"']))", r"\g<quotation_mark>\n", document
|
||||
) # Special quotation marks
|
||||
|
||||
sent_list_ori = document.splitlines()
|
||||
for sent in sent_list_ori:
|
||||
@@ -46,36 +48,35 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
|
||||
|
||||
|
||||
def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None:
|
||||
|
||||
workers = 32
|
||||
|
||||
if input_path[-1] == '/':
|
||||
if input_path[-1] == "/":
|
||||
input_path = input_path[:-1]
|
||||
|
||||
cur_path = os.path.join(output_path, str(host) + '.txt')
|
||||
cur_path = os.path.join(output_path, str(host) + ".txt")
|
||||
new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)
|
||||
with open(cur_path, 'w', encoding='utf-8') as f:
|
||||
with open(cur_path, "w", encoding="utf-8") as f:
|
||||
for fi, fin_path in enumerate(fin_list):
|
||||
if not os.path.exists(os.path.join(input_path, fin_path[0])):
|
||||
continue
|
||||
if '.json' not in fin_path[0]:
|
||||
if ".json" not in fin_path[0]:
|
||||
continue
|
||||
|
||||
print("Processing ", fin_path[0], " ", fi)
|
||||
|
||||
with open(os.path.join(input_path, fin_path[0]), 'r') as fin:
|
||||
f_data = [l['content'] for l in json.load(fin)]
|
||||
with open(os.path.join(input_path, fin_path[0]), "r") as fin:
|
||||
f_data = [l["content"] for l in json.load(fin)]
|
||||
|
||||
pool = multiprocessing.Pool(workers)
|
||||
all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
|
||||
pool.close()
|
||||
print('finished..')
|
||||
print("finished..")
|
||||
|
||||
cnt = 0
|
||||
for d in tqdm(all_sent):
|
||||
for i in d:
|
||||
f.write(i.strip() + '\n')
|
||||
f.write(']]' + '\n')
|
||||
f.write(i.strip() + "\n")
|
||||
f.write("]]" + "\n")
|
||||
cnt += 1
|
||||
# if cnt >= 2:
|
||||
# exit()
|
||||
@@ -86,7 +87,7 @@ def getFileSize(filepath, shard):
|
||||
for i in os.listdir(filepath):
|
||||
all_data.append(os.path.join(filepath, i))
|
||||
all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data])
|
||||
ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data]
|
||||
ans = [[f.split("/")[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data]
|
||||
ans = sorted(ans, key=lambda x: x[1], reverse=True)
|
||||
per_size = all_size / shard
|
||||
real_shard = []
|
||||
@@ -106,24 +107,24 @@ def getFileSize(filepath, shard):
|
||||
return real_shard
|
||||
|
||||
|
||||
def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
|
||||
def get_start_end(real_shard, base=0, server_num=10, server_name="GPU"):
|
||||
import socket
|
||||
|
||||
host = int(socket.gethostname().split(server_name)[-1])
|
||||
|
||||
fin_list = real_shard[server_num * base + host - 1]
|
||||
print(fin_list)
|
||||
print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')
|
||||
print(f"I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}")
|
||||
return fin_list, host
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
|
||||
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
|
||||
parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100')
|
||||
parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus')
|
||||
parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
|
||||
parser.add_argument("--server_num", type=int, default=10, help="number of servers")
|
||||
parser.add_argument("--seq_len", type=int, default=512, help="sequence length")
|
||||
parser.add_argument("--shard", type=int, default=100, help="number of shards, e.g., 10, 50, or 100")
|
||||
parser.add_argument("--input_path", type=str, required=True, help="input path of original corpus")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="output path of shard which has split sentence")
|
||||
args = parser.parse_args()
|
||||
|
||||
server_num = args.server_num
|
||||
@@ -137,7 +138,7 @@ if __name__ == '__main__':
|
||||
start = time.time()
|
||||
for index, shard in enumerate(real_shard):
|
||||
get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len)
|
||||
print(f'cost {str(time.time() - start)}')
|
||||
print(f"cost {str(time.time() - start)}")
|
||||
|
||||
# if you have multiple server, you can use code below or modify code to openmpi
|
||||
|
||||
|
Reference in New Issue
Block a user