add RoBERTa (#1980)

* update roberta

* update roberta & readme

* update roberta & readme

* update roberta & readme
This commit is contained in:
mandoxzhang 2022-11-18 14:04:49 +08:00 committed by GitHub
parent 31922110ad
commit 52bd106627
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 5814 additions and 0 deletions

View File

@ -0,0 +1,58 @@
# Introduction
This repo introduce how to pretrain a chinese roberta-large from scratch, including preprocessing, pretraining, finetune. The repo can help you quickly train a high-quality bert.
## 0. Prerequisite
- Install Colossal-AI
- Editing the port from /etc/ssh/sshd_config and /etc/ssh/ssh_config, every host expose the same ssh port of server and client. If you are a root user, you also set the **PermitRootLogin** from /etc/ssh/sshd_config to "yes"
- Ensure that each host can log in to each other without password. If you have n hosts, need to execute n<sup>2</sup> times
```
ssh-keygen
ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination
```
- In all hosts, edit /etc/hosts to record all hosts' name and ip.The example is shown below.
```bash
192.168.2.1 GPU001
192.168.2.2 GPU002
192.168.2.3 GPU003
192.168.2.4 GPU004
192.168.2.5 GPU005
192.168.2.6 GPU006
192.168.2.7 GPU007
...
```
- restart ssh
```
service ssh restart
```
## 1. Corpus Preprocessing
```bash
cd preprocessing
```
following the `README.md`, preprocess orginal corpus to h5py+numpy
## 2. Pretrain
```bash
cd pretraining
```
following the `README.md`, load the h5py generated by preprocess of step 1 to pretrain the model
## 3. Finetune
The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transfomers from HuggingFace to finetune downstream application.
## Contributors
The repo is contributed by AI team from [Moore Threads](https://www.mthreads.com/). If you find any problems for pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. At last, welcome any form of contribution!
```
@misc{
title={A simple Chinese RoBERTa Example for Whole Word Masked},
author={Yehua Zhang, Chen Zhang},
year={2022}
}
```

View File

@ -0,0 +1,4 @@
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam
clip_grad_norm = 1.0

View File

@ -0,0 +1,32 @@
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam
# fp16 = dict(
# mode=AMP_TYPE.TORCH,
# )
# seed = 2
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
reduce_scatter_bucket_size_mb=25,
fp32_reduce_scatter=False,
tensor_placement_policy="cuda",
gradient_predivide_factor=1.0,
reuse_fp16_shard=False),
optimizer_config=dict(gpu_margin_mem_ratio=0.8,
initial_scale=2**5,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=1000,
hysteresis=2,
max_scale=2**32))
# gradient_accumulation = 4
clip_grad_norm = 1.0
optimizer = dict(
type=FusedAdam,
lr=0.00015,
weight_decay=1e-2,
)
# 64433

View File

@ -0,0 +1,9 @@
CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = mask
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)
%$(LIBEXT): %.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@

View File

@ -0,0 +1,105 @@
# Data PreProcessing for chinese Whole Word Masked
<span id='all_catelogue'/>
## Catalogue:
* <a href='#introduction'>1. Introduction</a>
* <a href='#Quick Start Guide'>2. Quick Start Guide:</a>
* <a href='#Split Sentence'>2.1. Split Sentence</a>
* <a href='#Tokenizer & Whole Word Masked'>2.2.Tokenizer & Whole Word Masked</a>
<span id='introduction'/>
## 1. Introduction: <a href='#all_catelogue'>[Back to Top]</a>
This folder is used to preprocess chinese corpus with Whole Word Masked. You can obtain corpus from [WuDao](https://resource.wudaoai.cn/home?ind&name=WuDaoCorpora%202.0&id=1394901288847716352). Moreover, data preprocessing is flexible, and you can modify the code based on your needs, hardware or parallel framework(Open MPI, Spark, Dask).
<span id='Quick Start Guide'/>
## 2. Quick Start Guide: <a href='#all_catelogue'>[Back to Top]</a>
<span id='Split Sentence'/>
### 2.1. Split Sentence & Split data into multiple shard:
Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentence through punctuation, such as `。!`. **Secondly, split data into multiple shard based on server hardware (cpu, cpu memory, hard disk) and corpus size.** Each shard contains a part of corpus, and the model needs to train all the shards as one epoch.
In this example, split 200G Corpus into 100 shard, and each shard is about 2G. The size of the shard is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training to read the shard (n data parallel requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining requires fighting with hardware, not just GPU.**
```python
python sentence_split.py --input_path /orginal_corpus --output_path /shard --shard 100
# This step takes a short time
```
* `--input_path`: all original corpus, e.g., /orginal_corpus/0.json /orginal_corpus/1.json ...
* `--output_path`: all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--shard`: Number of shard, e.g., 10, 50, or 100
<summary><b>Input json:</b></summary>
```
[
{
"id": 0,
"title": "打篮球",
"content": "我今天去打篮球。不回来吃饭。"
}
{
"id": 1,
"title": "旅游",
"content": "我后天去旅游。下周请假。"
}
]
```
<summary><b>Output txt:</b></summary>
```
我今天去打篮球。
不回来吃饭。
]]
我后天去旅游。
下周请假。
```
<span id='Tokenizer & Whole Word Masked'/>
### 2.2. Tokenizer & Whole Word Masked:
```python
python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python
# This step is time consuming and is mainly spent on mask
```
**[optional but recommended]**: the C++ backend with `pybind11` can provide faster speed
```shell
make
```
* `--input_path`: location of all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--output_path`: location of all h5 with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ...
* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenzier.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main)
* `--backend`: python or c++, **specifies c++ can obtain faster preprocess speed**
* `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document
* `--worker`: number of process
<summary><b>Input txt:</b></summary>
```
我今天去打篮球。
不回来吃饭。
]]
我后天去旅游。
下周请假。
```
<summary><b>Output h5+numpy:</b></summary>
```
'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..],
...]
'input_mask': [[1,1,1,1,1,1,0,0..],
...]
'segment_ids': [[0,0,0,0,0,...],
...]
'masked_lm_positions': [[label1,-1,-1,label2,-1...],
...]
```

View File

@ -0,0 +1,266 @@
import torch
import os
from enum import IntEnum
from random import choice
import random
import collections
import time
import logging
import jieba
jieba.setLogLevel(logging.CRITICAL)
import re
import numpy as np
import mask
PAD = 0
MaskedLMInstance = collections.namedtuple("MaskedLMInstance",
["index", "label"])
def map_to_numpy(data):
return np.asarray(data)
class PreTrainingDataset():
def __init__(self,
tokenizer,
max_seq_length,
backend='python',
max_predictions_per_seq: int = 80,
do_whole_word_mask: bool = True):
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
self.masked_lm_prob = 0.15
self.backend = backend
self.do_whole_word_mask = do_whole_word_mask
self.max_predictions_per_seq = max_predictions_per_seq
self.vocab_words = list(tokenizer.vocab.keys())
self.rec = re.compile('[\u4E00-\u9FA5]')
self.whole_rec = re.compile('##[\u4E00-\u9FA5]')
self.mlm_p = 0.15
self.mlm_mask_p = 0.8
self.mlm_tamper_p = 0.05
self.mlm_maintain_p = 0.1
def tokenize(self, doc):
temp = []
for d in doc:
temp.append(self.tokenizer.tokenize(d))
return temp
def create_training_instance(self, instance):
is_next = 1
raw_text_list = self.get_new_segment(instance)
tokens_a = raw_text_list
assert len(tokens_a) == len(instance)
# tokens_a, tokens_b, is_next = instance.get_values()
# print(f'is_next label:{is_next}')
# Create mapper
tokens = []
original_tokens = []
segment_ids = []
tokens.append("[CLS]")
original_tokens.append('[CLS]')
segment_ids.append(0)
for index, token in enumerate(tokens_a):
tokens.append(token)
original_tokens.append(instance[index])
segment_ids.append(0)
tokens.append("[SEP]")
original_tokens.append('[SEP]')
segment_ids.append(0)
# for token in tokens_b:
# tokens.append(token)
# segment_ids.append(1)
# tokens.append("[SEP]")
# segment_ids.append(1)
# Get Masked LM predictions
if self.backend == 'c++':
output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words,
self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob)
elif self.backend == 'python':
output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)
# Convert to Ids
input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < self.max_seq_length:
input_ids.append(PAD)
segment_ids.append(PAD)
input_mask.append(PAD)
masked_lm_output.append(-1)
return ([
map_to_numpy(input_ids),
map_to_numpy(input_mask),
map_to_numpy(segment_ids),
map_to_numpy(masked_lm_output),
map_to_numpy([is_next])
])
def create_masked_lm_predictions(self, tokens):
cand_indexes = []
for i, token in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
# cand_indexes.append(i)
random.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(
self.max_predictions_per_seq,
max(1, int(round(len(tokens) * self.masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indexes:
continue
covered_indexes.add(index)
masked_token = None
# 80% mask
if random.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% Keep Original
if random.random() < 0.5:
masked_token = tokens[index]
# 10% replace w/ random word
else:
masked_token = self.vocab_words[random.randint(
0,
len(self.vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(
MaskedLMInstance(index=index, label=tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
for p in masked_lms:
masked_lm_output[p.index] = self.tokenizer.vocab[p.label]
return (output_tokens, masked_lm_output)
def get_new_segment(self, segment):
"""
输入一句话返回一句经过处理的话: 为了支持中文全称mask将被分开的词将上特殊标记("#")使得后续处理模块能够知道哪些字是属于同一个词的
:param segment: 一句话
:return: 一句处理过的话
"""
seq_cws = jieba.lcut(''.join(segment))
seq_cws_dict = {x: 1 for x in seq_cws}
new_segment = []
i = 0
while i < len(segment):
if len(self.rec.findall(segment[i])) == 0: # 不是中文的,原文加进去。
new_segment.append(segment[i])
i += 1
continue
has_add = False
for length in range(3, 0, -1):
if i + length > len(segment):
continue
if ''.join(segment[i: i+length]) in seq_cws_dict:
new_segment.append(segment[i])
for l in range(1, length):
new_segment.append('##' + segment[i+l])
i += length
has_add = True
break
if not has_add:
new_segment.append(segment[i])
i += 1
return new_segment
def create_whole_masked_lm_predictions(self, tokens):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
random.shuffle(cand_indexes)
output_tokens = [t[2:] if len(self.whole_rec.findall(t))>0 else t for t in tokens] # 去掉"##"
num_to_predict = min(self.max_predictions_per_seq,
max(1, int(round(len(tokens) * self.masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if random.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if random.random() < 0.5:
masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index] # 去掉"##"
# 10% of the time, replace with random word
else:
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
for p in masked_lms:
masked_lm_output[p.index] = self.tokenizer.vocab[p.label]
return (output_tokens, masked_lm_output)

View File

@ -0,0 +1,184 @@
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
#include <vector>
#include <string>
#include <pybind11/stl.h>
#include <chrono>
#include <tuple>
#include <unordered_set>
#include <unordered_map>
namespace py = pybind11;
const int32_t LONG_SENTENCE_LEN = 512;
struct MaskedLMInstance {
int index;
std::string label;
MaskedLMInstance(int index, std::string label) {
this->index = index;
this->label = label;
}
};
auto get_new_segment(std::vector<std::string> segment, std::vector<std::string> segment_jieba, const std::vector<bool> chinese_vocab) { // const std::unordered_set<std::string> &chinese_vocab
std::unordered_set<std::string> seq_cws_dict;
for (auto word : segment_jieba) {
seq_cws_dict.insert(word);
}
int i = 0;
std::vector<std::string> new_segment;
int segment_size = segment.size();
while (i < segment_size) {
if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end()
new_segment.emplace_back(segment[i]);
i += 1;
continue;
}
bool has_add = false;
for (int length = 3; length >= 1; length--) {
if (i + length > segment_size) {
continue;
}
std::string chinese_word = "";
for (int j = i; j < i + length; j++) {
chinese_word += segment[j];
}
if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
new_segment.emplace_back(segment[i]);
for (int j = i + 1; j < i + length; j++) {
new_segment.emplace_back("##" + segment[j]);
}
i += length;
has_add = true;
break;
}
}
if (!has_add) {
new_segment.emplace_back(segment[i]);
i += 1;
}
}
return new_segment;
}
bool startsWith(const std::string& s, const std::string& sub) {
return s.find(sub) == 0 ? true : false;
}
auto create_whole_masked_lm_predictions(std::vector<std::string> &tokens,
const std::vector<std::string> &original_tokens,
const std::vector<std::string> &vocab_words,
std::map<std::string, int> &vocab,
const int max_predictions_per_seq,
const double masked_lm_prob) {
// for (auto item : vocab) {
// std::cout << "key=" << std::string(py::str(item.first)) << ", "
// << "value=" << std::string(py::str(item.second)) << std::endl;
// }
std::vector<std::vector<int> > cand_indexes;
std::vector<int> cand_temp;
int tokens_size = tokens.size();
std::string prefix = "##";
bool do_whole_masked = true;
for (int i = 0; i < tokens_size; i++) {
if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
continue;
}
if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) {
cand_temp.emplace_back(i);
}
else {
if (cand_temp.size() > 0) {
cand_indexes.emplace_back(cand_temp);
}
cand_temp.clear();
cand_temp.emplace_back(i);
}
}
auto seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed));
// for (auto i : cand_indexes) {
// for (auto j : i) {
// std::cout << tokens[j] << " ";
// }
// std::cout << std::endl;
// }
// for (auto i : output_tokens) {
// std::cout << i;
// }
// std::cout << std::endl;
int num_to_predict = std::min(max_predictions_per_seq,
std::max(1, int(tokens_size * masked_lm_prob)));
// std::cout << num_to_predict << std::endl;
std::set<int> covered_indexes;
std::vector<int> masked_lm_output(tokens_size, -1);
int vocab_words_len = vocab_words.size();
std::default_random_engine e(seed);
std::uniform_real_distribution<double> u1(0.0, 1.0);
std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
int mask_cnt = 0;
std::vector<std::string> output_tokens;
output_tokens = original_tokens;
for (auto index_set : cand_indexes) {
if (mask_cnt > num_to_predict) {
break;
}
int index_set_size = index_set.size();
if (mask_cnt + index_set_size > num_to_predict) {
continue;
}
bool is_any_index_covered = false;
for (auto index : index_set) {
if (covered_indexes.find(index) != covered_indexes.end()) {
is_any_index_covered = true;
break;
}
}
if (is_any_index_covered) {
continue;
}
for (auto index : index_set) {
covered_indexes.insert(index);
std::string masked_token;
if (u1(e) < 0.8) {
masked_token = "[MASK]";
}
else {
if (u1(e) < 0.5) {
masked_token = output_tokens[index];
}
else {
int random_index = u2(e);
masked_token = vocab_words[random_index];
}
}
// masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
masked_lm_output[index] = vocab[output_tokens[index]];
output_tokens[index] = masked_token;
mask_cnt++;
}
}
// for (auto p : masked_lms) {
// masked_lm_output[p.index] = vocab[p.label];
// }
return std::make_tuple(output_tokens, masked_lm_output);
}
PYBIND11_MODULE(mask, m) {
m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions);
m.def("get_new_segment", &get_new_segment);
}

View File

@ -0,0 +1,163 @@
import multiprocessing
import os
import re
from tqdm import tqdm
from typing import List
import json
import time
import argparse
import functools
def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
"""
Args:
document:
flag: Type:str, "all" 中英文标点分句"zh" 中文标点分句"en" 英文标点分句
limit: 默认单句最大长度为510个字符
Returns: Type:list
"""
sent_list = []
try:
if flag == "zh":
document = re.sub('(?P<quotation_mark>([。?!…](?![”’"\'])))', r'\g<quotation_mark>\n', document) # 单字符断句符
document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document) # 特殊引号
elif flag == "en":
document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document) # 英文单字符断句符
document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n', document) # 特殊引号
else:
document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document) # 单字符断句符
document = re.sub('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))', r'\g<quotation_mark>\n',
document) # 特殊引号
sent_list_ori = document.splitlines()
for sent in sent_list_ori:
sent = sent.strip()
if not sent:
continue
elif len(sent) <= 2:
continue
else:
while len(sent) > limit:
temp = sent[0:limit]
sent_list.append(temp)
sent = sent[limit:]
sent_list.append(sent)
except:
sent_list.clear()
sent_list.append(document)
return sent_list
def get_sent(output_path,
input_path,
fin_list=[], host=-1, seq_len=512) -> None:
workers = 32
if input_path[-1] == '/':
input_path = input_path[:-1]
cur_path = os.path.join(output_path, str(host) + '.txt')
new_split_sentence = functools.partial(split_sentence, limit=seq_len-2)
with open(cur_path, 'w', encoding='utf-8') as f:
for fi, fin_path in enumerate(fin_list):
if not os.path.exists(os.path.join(input_path, fin_path[0])):
continue
if '.json' not in fin_path[0]:
continue
print("Processing ", fin_path[0], " ", fi)
with open(os.path.join(input_path, fin_path[0]), 'r') as fin:
f_data = [l['content'] for l in json.load(fin)]
pool = multiprocessing.Pool(workers)
all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
pool.close()
print('finished..')
cnt = 0
for d in tqdm(all_sent):
for i in d:
f.write(i.strip() + '\n')
f.write(']]' + '\n')
cnt += 1
# if cnt >= 2:
# exit()
def getFileSize(filepath, shard):
all_data = []
for i in os.listdir(filepath):
all_data.append(os.path.join(filepath, i))
all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data])
ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data]
ans = sorted(ans, key=lambda x: x[1], reverse=True)
per_size = all_size / shard
real_shard = []
temp = []
accu_size = 0
for i in ans:
accu_size += i[1]
temp.append(i)
if accu_size > per_size:
real_shard.append(temp)
accu_size = 0
temp = []
if len(temp) > 0:
real_shard.append(temp)
return real_shard
def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
import socket
host = int(socket.gethostname().split(server_name)[-1])
fin_list = real_shard[server_num * base + host - 1]
print(fin_list)
print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')
return fin_list, host
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100')
parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus')
parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
args = parser.parse_args()
server_num = args.server_num
seq_len = args.seq_len
shard = args.shard
input_path = args.input_path
output_path = args.output_path
real_shard = getFileSize(input_path, shard)
start = time.time()
for index, shard in enumerate(real_shard):
get_sent(output_path,
input_path,
fin_list=shard,
host=index,
seq_len=seq_len)
print(f'cost {str(time.time() - start)}')
# if you have multiple server, you can use code below or modify code to openmpi
# for i in range(len(real_shard) // server_num + 1):
# fin_list, host = get_start_end(real_shard, i)
# start = time.time()
# get_sent(output_path,
# input_path,
# fin_list=fin_list, host= 10 * i + host - 1)
# print(f'cost {str(time.time() - start)}')

View File

@ -0,0 +1,275 @@
import time
import os
import psutil
import h5py
import socket
import argparse
import numpy as np
import multiprocessing
from tqdm import tqdm
from random import shuffle
from transformers import AutoTokenizer
from get_mask import PreTrainingDataset
def get_raw_instance(document, max_sequence_length=512):
"""
获取初步的训练实例将整段按照max_sequence_length切分成多个部分,并以多个处理好的实例的形式返回
:param document: 一整段
:param max_sequence_length:
:return: a list. each element is a sequence of text
"""
# document = self.documents[index]
max_sequence_length_allowed = max_sequence_length - 2
# document = [seq for seq in document if len(seq)<max_sequence_length_allowed]
sizes = [len(seq) for seq in document]
result_list = []
curr_seq = [] # 当前处理的序列
sz_idx = 0
while sz_idx < len(sizes):
# 当前句子加上新的句子,如果长度小于最大限制,则合并当前句子和新句子;否则即超过了最大限制,那么做为一个新的序列加到目标列表中
if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0:
curr_seq += document[sz_idx]
sz_idx += 1
elif sizes[sz_idx] >= max_sequence_length_allowed:
if len(curr_seq) > 0:
result_list.append(curr_seq)
curr_seq = []
result_list.append(document[sz_idx][ : max_sequence_length_allowed])
sz_idx += 1
else:
result_list.append(curr_seq)
curr_seq = []
# 对最后一个序列进行处理,如果太短的话,丢弃掉。
if len(curr_seq) > max_sequence_length_allowed / 2: # /2
result_list.append(curr_seq)
# # 计算总共可以得到多少份
# num_instance=int(len(big_list)/max_sequence_length_allowed)+1
# print("num_instance:",num_instance)
# # 切分成多份,添加到列表中
# result_list=[]
# for j in range(num_instance):
# index=j*max_sequence_length_allowed
# end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1
# result_list.append(big_list[index:end_index])
return result_list
def split_numpy_chunk(path, tokenizer, pretrain_data, host):
documents = []
instances = []
s = time.time()
with open(path, encoding='utf-8') as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
# document = line
# if len(document.split("<sep>")) <= 3:
# continue
if len(line
) > 0 and line[:2] == "]]": # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print('read_file ', time.time() - s)
# documents = [x for x in documents if x]
# print(len(documents))
# print(len(documents[0]))
# print(documents[0][0:10])
from typing import List
import multiprocessing
ans = []
for docs in tqdm(documents):
ans.append(pretrain_data.tokenize(docs))
print(time.time() - s)
del documents
instances = []
for a in tqdm(ans):
raw_ins = get_raw_instance(a)
instances.extend(raw_ins)
del ans
print('len instance', len(instances))
sen_num = len(instances)
seq_len = 512
input_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
input_mask = np.zeros([sen_num, seq_len], dtype=np.int32)
segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32)
for index, ins in tqdm(enumerate(instances)):
mask_dict = pretrain_data.create_training_instance(ins)
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]
with h5py.File(f'/output/{host}.h5', 'w') as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_ids)
hf.create_dataset("segment_ids", data=segment_ids)
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
del instances
def split_numpy_chunk_pool(input_path,
output_path,
pretrain_data,
worker,
dupe_factor,
seq_len,
file_name):
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
print(f'{file_name}.h5 exists')
return
documents = []
instances = []
s = time.time()
with open(input_path, 'r', encoding='utf-8') as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
if len(line
) > 0 and line[:2] == "]]": # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print(f'read_file cost {time.time() - s}, length is {len(documents)}')
ans = []
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100)
for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'):
ans.append(res)
pool.close()
print((time.time() - s) / 60)
del documents
instances = []
for a in tqdm(ans, colour='MAGENTA'):
raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
instances.extend(raw_ins)
del ans
print('len instance', len(instances))
new_instances = []
for _ in range(dupe_factor):
for ins in instances:
new_instances.append(ins)
shuffle(new_instances)
instances = new_instances
print('after dupe_factor, len instance', len(instances))
sentence_num = len(instances)
input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32)
segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32)
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32)
for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'):
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]
pool.close()
print((time.time() - s) / 60)
with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_mask)
hf.create_dataset("segment_ids", data=segment_ids)
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
del instances
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='number of shards, e.g., 10, 50, or 100')
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
parser.add_argument('--backend', type=str, default='python', help='backend of mask token, python, c++, numpy respectively')
parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document')
parser.add_argument('--worker', type=int, default=32, help='number of process')
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
pretrain_data = PreTrainingDataset(tokenizer,
args.seq_len,
args.backend,
max_predictions_per_seq=args.max_predictions_per_seq)
data_len = len(os.listdir(args.input_path))
for i in range(data_len):
input_path = os.path.join(args.input_path, f'{i}.txt')
if os.path.exists(input_path):
start = time.time()
print(f'process {input_path}')
split_numpy_chunk_pool(input_path,
args.output_path,
pretrain_data,
args.worker,
args.dupe_factor,
args.seq_len,
i)
end_ = time.time()
print(u'memory%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
print(f'has cost {(end_ - start) / 60}')
print('-' * 100)
print('')
# if you have multiple server, you can use code below or modify code to openmpi
# host = int(socket.gethostname().split('GPU')[-1])
# for i in range(data_len // args.server_num + 1):
# h = args.server_num * i + host - 1
# input_path = os.path.join(args.input_path, f'{h}.txt')
# if os.path.exists(input_path):
# start = time.time()
# print(f'I am server {host}, process {input_path}')
# split_numpy_chunk_pool(input_path,
# args.output_path,
# pretrain_data,
# args.worker,
# args.dupe_factor,
# args.seq_len,
# h)
# end_ = time.time()
# print(u'memory%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
# print(f'has cost {(end_ - start) / 60}')
# print('-' * 100)
# print('')

View File

@ -0,0 +1,24 @@
# Pretraining
1. Pretraining roberta through running the script below. Detailed parameter descriptions can be found in the arguments.py. `data_path_prefix` is absolute path specifies output of preprocessing. **You have to modify the *hostfile* according to your cluster.**
```bash
bash run_pretrain.sh
```
* `--hostfile`: servers' host name from /etc/hosts
* `--include`: servers which will be used
* `--nproc_per_node`: number of process(GPU) from each server
* `--data_path_prefix`: absolute location of train data, e.g., /h5/0.h5
* `--eval_data_path_prefix`: absolute location of eval data
* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json, e.g./tokenizer/tokenizer.json
* `--bert_config`: config.json which represent model
* `--mlm`: model type of backbone, bert or deberta_v2
2. if resume training from earylier checkpoint, run the script below.
```shell
bash run_pretrain_resume.sh
```
* `--resume_train`: whether to resume training
* `--load_pretrain_model`: absolute path which contains model checkpoint
* `--load_optimizer_lr`: absolute path which contains optimizer checkpoint

View File

@ -0,0 +1,152 @@
import colossalai
from numpy import require
__all__ = ['parse_args']
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
'--lr',
type=float,
required=True,
help='initial learning rate')
parser.add_argument(
'--epoch',
type=int,
required=True,
help='number of epoch')
parser.add_argument(
'--data_path_prefix',
type=str,
required=True,
help="location of the train data corpus")
parser.add_argument(
'--eval_data_path_prefix',
type=str,
required=True,
help='location of the evaluation data corpus')
parser.add_argument(
'--tokenizer_path',
type=str,
required=True,
help='location of the tokenizer')
parser.add_argument(
'--max_seq_length',
type=int,
default=512,
help='sequence length')
parser.add_argument(
'--refresh_bucket_size',
type=int,
default=1,
help=
"This param makes sure that a certain task is repeated for this time steps to \
optimise on the back propogation speed with APEX's DistributedDataParallel")
parser.add_argument(
"--max_predictions_per_seq",
"--max_pred",
default=80,
type=int,
help=
"The maximum number of masked tokens in a sequence to be predicted.")
parser.add_argument(
"--gradient_accumulation_steps",
default=1,
type=int,
help="accumulation_steps")
parser.add_argument(
"--train_micro_batch_size_per_gpu",
default=2,
type=int,
required=True,
help="train batch size")
parser.add_argument(
"--eval_micro_batch_size_per_gpu",
default=2,
type=int,
required=True,
help="eval batch size")
parser.add_argument(
"--num_workers",
default=8,
type=int,
help="")
parser.add_argument(
"--async_worker",
action='store_true',
help="")
parser.add_argument(
"--bert_config",
required=True,
type=str,
help="location of config.json")
parser.add_argument(
"--wandb",
action='store_true',
help="use wandb to watch model")
parser.add_argument(
"--wandb_project_name",
default='roberta',
help="wandb project name")
parser.add_argument(
"--log_interval",
default=100,
type=int,
help="report interval")
parser.add_argument(
"--log_path",
type=str,
required=True,
help="log file which records train step")
parser.add_argument(
"--tensorboard_path",
type=str,
required=True,
help="location of tensorboard file")
parser.add_argument(
"--colossal_config",
type=str,
required=True,
help="colossal config, which contains zero config and so on")
parser.add_argument(
"--ckpt_path",
type=str,
required=True,
help="location of saving checkpoint, which contains model and optimizer")
parser.add_argument(
'--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument(
'--vscode_debug',
action='store_true',
help="use vscode to debug")
parser.add_argument(
'--load_pretrain_model',
default='',
type=str,
help="location of model's checkpoin")
parser.add_argument(
'--load_optimizer_lr',
default='',
type=str,
help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
parser.add_argument(
'--resume_train',
action='store_true',
help="whether resume training from a early checkpoint")
parser.add_argument(
'--mlm',
default='bert',
type=str,
help="model type, bert or deberta")
parser.add_argument(
'--checkpoint_activations',
action='store_true',
help="whether to use gradient checkpointing")
args = parser.parse_args()
return args

View File

@ -0,0 +1,15 @@
class BertDatasetProviderInterface:
def get_shard(self, index, shuffle=True):
raise NotImplementedError
def release_shard(self, index):
raise NotImplementedError
def prefetch_shard(self, index):
raise NotImplementedError
def get_batch(self, batch_iter):
raise NotImplementedError
def prefetch_batch(self):
raise NotImplementedError

View File

@ -0,0 +1,71 @@
import os
import math
import torch
from tqdm import tqdm
from utils.global_vars import get_timers, get_tensorboard_writer
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
def evaluate(engine, args, logger, global_step):
evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True)
start_shard = 0
engine.eval()
timers = get_timers()
eval_step = 0
eval_loss = 0
cur_loss = 0
world_size = torch.distributed.get_world_size()
with torch.no_grad():
for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))):
timers('eval_shard_time').start()
dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard)
# evaluate_dataset_provider.prefetch_shard(shard + 1)
if torch.distributed.get_rank() == 0:
iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1)
else:
iterator_data = enumerate(dataset_iterator)
for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1):
# batch_data = pretrain_dataset_provider.get_batch(batch_index)
eval_step += 1
input_ids = batch_data[0].cuda()
attention_mask = batch_data[1].cuda()
token_type_ids = batch_data[2].cuda()
mlm_label = batch_data[3].cuda()
# nsp_label = batch_data[5].cuda()
output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
loss = engine.criterion(output.logits, mlm_label)#prediction_scores
evaluate_dataset_provider.prefetch_batch()
eval_loss += loss.float().item()
cur_loss = eval_loss / eval_step
elapsed_time = timers("eval_shard_time").elapsed()
elapsed_time_per_iteration = elapsed_time / eval_step
ppl = math.exp(cur_loss)
if args.wandb and torch.distributed.get_rank() == 0:
tensorboard_log = get_tensorboard_writer()
tensorboard_log.log_eval({
'loss': cur_loss,
'ppl': ppl,
'mins_batch': elapsed_time_per_iteration
}, global_step)
eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}'
logger.info(eval_log_str)
logger.info('-' * 100)
logger.info('')
evaluate_dataset_provider.release_shard()
engine.train()
return cur_loss

View File

@ -0,0 +1,10 @@
GPU001
GPU002
GPU003
GPU004
GPU005
GPU006
GPU007
GPU008
GPU009
GPU010

View File

@ -0,0 +1,17 @@
import torch
__all__ = ['LossForPretraining']
class LossForPretraining(torch.nn.Module):
def __init__(self, vocab_size):
super(LossForPretraining, self).__init__()
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
self.vocab_size = vocab_size
def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None):
masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
# next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
total_loss = masked_lm_loss #+ next_sentence_loss
return total_loss

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,182 @@
import os
import random
import h5py
import logging
import json
import time
from concurrent.futures import ProcessPoolExecutor
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.distributed import DistributedSampler
from bert_dataset_provider import BertDatasetProviderInterface
import colossalai.utils as utils
# Workaround because python functions are not picklable
class WorkerInitObj(object):
def __init__(self, seed):
self.seed = seed
def __call__(self, id):
np.random.seed(seed=self.seed + id)
random.seed(self.seed + id)
def create_pretraining_dataset(input_file, max_predictions_per_seq,
num_workers, train_batch_size, worker_init,
data_sampler):
train_data = pretraining_dataset(
input_file=input_file, max_predictions_per_seq=max_predictions_per_seq)
train_dataloader = DataLoader(train_data,
sampler=data_sampler(train_data),
batch_size=train_batch_size,
num_workers=num_workers,
worker_init_fn=worker_init,
pin_memory=True
)
return train_dataloader, len(train_data)
class pretraining_dataset(Dataset):
def __init__(self, input_file, max_predictions_per_seq):
self.input_file = input_file
self.max_predictions_per_seq = max_predictions_per_seq
f = h5py.File(input_file, "r")
keys = [
'input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions'
]
self.inputs = [np.asarray(f[key][:]) for key in keys]
f.close()
def __len__(self):
'Denotes the total number of samples'
return len(self.inputs[0])
def __getitem__(self, index):
[
input_ids, input_mask, segment_ids, masked_lm_labels
] = [
torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else
torch.from_numpy(np.asarray(input[index].astype(np.int64)))
for indice, input in enumerate(self.inputs)
]
return [
input_ids, input_mask,
segment_ids, masked_lm_labels
]
class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
def __init__(self, args, evaluate=False):
self.num_workers = args.num_workers
self.max_seq_length = args.max_seq_length
self.max_predictions_per_seq = args.max_predictions_per_seq
self.gradient_accumulation_steps = args.gradient_accumulation_steps
if not evaluate:
self.train_micro_batch_size_per_gpu = args.train_micro_batch_size_per_gpu
else:
self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu
self.logger = args.logger
self.global_rank = dist.get_rank()
self.world_size = dist.get_world_size()
# Initialize dataset files
if not evaluate:
self.dataset_files = [
os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) if
os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f
]
else:
self.dataset_files = [
os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) if
os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f
]
self.dataset_files.sort()
# random.shuffle(self.dataset_files)
self.num_files = len(self.dataset_files)
# self.data_sampler = RandomSampler
self.data_sampler = DistributedSampler
self.worker_init = WorkerInitObj(args.seed + args.local_rank)
self.dataset_future = None
self.pool = ProcessPoolExecutor(1)
self.data_file = None
self.shuffle = True
if self.global_rank == 0:
self.logger.info(
f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}"
)
def get_shard(self, index):
start = time.time()
if self.dataset_future is None:
self.data_file = self._get_shard_file(index)
self.train_dataloader, sample_count = create_pretraining_dataset(
input_file=self.data_file,
max_predictions_per_seq=self.max_predictions_per_seq,
num_workers=self.num_workers,
train_batch_size=self.train_micro_batch_size_per_gpu,
worker_init=self.worker_init,
data_sampler=self.data_sampler)
else:
self.train_dataloader, sample_count = self.dataset_future.result(
timeout=None)
self.logger.info(
f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s."
)
return self.train_dataloader, sample_count
def release_shard(self):
del self.train_dataloader
self.pool.shutdown()
def prefetch_shard(self, index):
self.data_file = self._get_shard_file(index)
self.dataset_future = self.pool.submit(
create_pretraining_dataset, self.data_file,
self.max_predictions_per_seq, self.num_workers,
self.train_micro_batch_size_per_gpu, self.worker_init,
self.data_sampler)
def get_batch(self, batch_iter):
return batch_iter
def prefetch_batch(self):
pass
def _get_shard_file(self, shard_index):
file_index = self._get_shard_file_index(shard_index, self.global_rank)
return self.dataset_files[file_index]
def _get_shard_file_index(self, shard_index, global_rank):
# if dist.is_initialized() and self.world_size > self.num_files:
# remainder = self.world_size % self.num_files
# file_index = (shard_index * self.world_size) + global_rank + (
# remainder * shard_index)
# else:
# file_index = shard_index * self.world_size + global_rank
return shard_index % self.num_files
def shuffle_dataset(self, epoch):
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(self.num_files, generator=g).tolist()
new_dataset = [self.dataset_files[i] for i in indices]
self.dataset_files = new_dataset

View File

@ -0,0 +1,112 @@
import transformers
import logging
from colossalai.nn.lr_scheduler import LinearWarmupLR
from transformers import get_linear_schedule_with_warmup
from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from colossalai.nn.optimizer import FusedAdam
from torch.optim import AdamW
from colossalai.core import global_context as gpc
import torch
import os
import sys
sys.path.append(os.getcwd())
from model.deberta_v2 import DebertaV2ForMaskedLM
from model.bert import BertForMaskedLM
import torch.nn as nn
from collections import OrderedDict
__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']
def get_new_state_dict(state_dict, start_index=13):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[start_index:]
new_state_dict[name] = v
return new_state_dict
class LMModel(nn.Module):
def __init__(self, model, config, args):
super().__init__()
self.checkpoint = args.checkpoint_activations
self.config = config
self.model = model
if self.checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, input_ids, token_type_ids=None, attention_mask=None):
# Only return lm_logits
return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def get_model(args, logger):
if args.mlm == 'bert':
config = transformers.BertConfig.from_json_file(args.bert_config)
model = BertForMaskedLM(config)
elif args.mlm == 'deberta_v2':
config = transformers.DebertaV2Config.from_json_file(args.bert_config)
model = DebertaV2ForMaskedLM(config)
else:
raise Exception("Invalid mlm!")
if len(args.load_pretrain_model) > 0:
assert os.path.exists(args.load_pretrain_model)
# load_checkpoint(args.load_pretrain_model, model, strict=False)
m_state_dict = torch.load(args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
# new_state_dict = get_new_state_dict(m_state_dict)
model.load_state_dict(m_state_dict, strict=True) # must insure that every process have identical parameters !!!!!!!
logger.info("load model success")
numel = sum([p.numel() for p in model.parameters()])
if args.checkpoint_activations:
model.gradient_checkpointing_enable()
# model = LMModel(model, config, args)
return config, model, numel
def get_optimizer(model, lr):
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
# configure the weight decay for bert models
optimizer_grouped_parameters = [{
'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': 0.1
}, {
'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
'weight_decay': 0.0
}]
optimizer = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95])
return optimizer
def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
# warmup_steps = int(total_steps * warmup_ratio)
lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch)
# lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
return lr_scheduler
def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
model_path = path + '_pytorch_model.bin'
optimizer_lr_path = path + '.op_lrs'
checkpoint = {}
checkpoint['optimizer'] = optimizer.state_dict()
checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
checkpoint['epoch'] = epoch
checkpoint['shard'] = shard
checkpoint['global_step'] = global_step
model_state = model.state_dict() #each process must run model.state_dict()
if gpc.get_global_rank() == 0:
torch.save(checkpoint, optimizer_lr_path)
torch.save(model_state, model_path)

View File

@ -0,0 +1,40 @@
#!/usr/bin/env sh
root_path=$PWD
PY_FILE_PATH="$root_path/run_pretraining.py"
tensorboard_path="$root_path/tensorboard"
log_path="$root_path/exp_log"
ckpt_path="$root_path/ckpt"
colossal_config="$root_path/../configs/colossalai_ddp.py"
mkdir -p $tensorboard_path
mkdir -p $log_path
mkdir -p $ckpt_path
export PYTHONPATH=$PWD
env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
--include GPU002,GPU003,GPU004,GPU007 \
--nproc_per_node=8 \
$PY_FILE_PATH \
--master_addr GPU007 \
--master_port 20024 \
--lr 2.0e-4 \
--train_micro_batch_size_per_gpu 190 \
--eval_micro_batch_size_per_gpu 20 \
--epoch 15 \
--data_path_prefix /h5 \
--eval_data_path_prefix /eval_h5 \
--tokenizer_path /roberta \
--bert_config /roberta/config.json \
--tensorboard_path $tensorboard_path \
--log_path $log_path \
--ckpt_path $ckpt_path \
--colossal_config $colossal_config \
--log_interval 50 \
--mlm bert \
--wandb \
--checkpoint_activations \

View File

@ -0,0 +1,43 @@
#!/usr/bin/env sh
root_path=$PWD
PY_FILE_PATH="$root_path/run_pretraining.py"
tensorboard_path="$root_path/tensorboard"
log_path="$root_path/exp_log"
ckpt_path="$root_path/ckpt"
colossal_config="$root_path/../configs/colossalai_ddp.py"
mkdir -p $tensorboard_path
mkdir -p $log_path
mkdir -p $ckpt_path
export PYTHONPATH=$PWD
env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
--include GPU002,GPU003,GPU004,GPU007 \
--nproc_per_node=8 \
$PY_FILE_PATH \
--master_addr GPU007 \
--master_port 20024 \
--lr 2.0e-4 \
--train_micro_batch_size_per_gpu 190 \
--eval_micro_batch_size_per_gpu 20 \
--epoch 15 \
--data_path_prefix /h5 \
--eval_data_path_prefix /eval_h5 \
--tokenizer_path /roberta \
--bert_config /roberta/config.json \
--tensorboard_path $tensorboard_path \
--log_path $log_path \
--ckpt_path $ckpt_path \
--colossal_config $colossal_config \
--log_interval 50 \
--mlm bert \
--wandb \
--checkpoint_activations \
--resume_train \
--load_pretrain_model /ckpt/1.pt \
--load_optimizer_lr /ckpt/1.op_lrs \

View File

@ -0,0 +1,226 @@
import colossalai
import math
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
import colossalai.nn as col_nn
from arguments import parse_args
from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt
from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args
from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer
from utils.logger import Logger
from evaluation import evaluate
from loss import LossForPretraining
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from tqdm import tqdm
import os
import time
from functools import partial
from transformers import AutoTokenizer
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import get_current_device
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.tensor import ProcessGroup
from colossalai.nn.optimizer import HybridAdam
def main():
args = parse_args()
launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)
if args.vscode_debug:
colossalai.launch(config={},
rank=args.rank,
world_size=args.world_size,
host=args.host,
port=args.port,
backend=args.backend)
args.local_rank = -1
args.log_interval = 1
else:
colossalai.launch_from_torch(args.colossal_config) #args.colossal_config
args.local_rank = int(os.environ["LOCAL_RANK"])
logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}')
log_args(logger, args)
args.tokenizer = tokenizer
args.logger = logger
set_global_variables(launch_time, args.tensorboard_path)
use_zero = hasattr(gpc.config, 'zero')
world_size = torch.distributed.get_world_size()
# build model, optimizer and criterion
if use_zero:
shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy,
shard_param=True):
config, model, numel = get_model(args, logger)
# model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
else:
config, model, numel = get_model(args, logger)
logger.info("no_zero")
if torch.distributed.get_rank() == 0:
os.mkdir(os.path.join(args.ckpt_path, launch_time))
logger.info(f'Model numel: {numel}')
get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader)
total_steps = steps_per_epoch * args.epoch
# build optimizer and lr_scheduler
start_epoch = 0
start_shard = 0
global_step = 0
if args.resume_train:
assert os.path.exists(args.load_optimizer_lr)
o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu')
o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1
optimizer = get_optimizer(model, lr=args.lr)
optimizer.load_state_dict(o_l_state_dict['optimizer'])
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch']
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}")
# if you want delete the above three code, have to move the model to gpu, because in optimizer.step()
lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler'])
start_epoch = o_l_state_dict['epoch']
start_shard = o_l_state_dict['shard'] + 1
# global_step = o_l_state_dict['global_step'] + 1
logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}')
else:
optimizer = get_optimizer(model, lr=args.lr)
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
# optimizer = gpc.config.optimizer.pop('type')(
# model.parameters(), **gpc.config.optimizer)
# optimizer = ShardedOptimizerV2(model, optimizer, initial_scale=2**5)
criterion = LossForPretraining(config.vocab_size)
# build dataloader
pretrain_dataset_provider = NvidiaBertDatasetProvider(args)
# initialize with colossalai
engine, _, _, lr_scheduelr = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
lr_scheduler=lr_scheduler)
logger.info(get_mem_info(prefix='After init model, '))
best_loss = None
eval_loss = 0
train_loss = 0
timers = get_timers()
timers('interval_time').start()
timers('epoch_time').start()
timers('shard_time').start()
for epoch in range(start_epoch, args.epoch):
for shard in range(start_shard, len(os.listdir(args.data_path_prefix))):
dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
# pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload
if torch.distributed.get_rank() == 0:
iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1)
else:
iterator_data = enumerate(dataset_iterator)
engine.train()
for step, batch_data in iterator_data:
# batch_data = pretrain_dataset_provider.get_batch(batch_index)
input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}")
token_type_ids = batch_data[2].cuda(f"cuda:{torch.cuda.current_device()}")
mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}")
# nsp_label = batch_data[5].cuda()
output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
loss = engine.criterion(output.logits, mlm_label)
pretrain_dataset_provider.prefetch_batch()
engine.backward(loss)
train_loss += loss.float().item()
# if (step + 1) % args.accumulation_step == 0:
engine.step()
lr_scheduelr.step()
engine.zero_grad()
global_step += 1
if global_step % args.log_interval == 0 and global_step != 0 \
and torch.distributed.get_rank() == 0:
elapsed_time = timers('interval_time').elapsed(reset=False)
elapsed_time_per_iteration = elapsed_time / global_step
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size)
cur_loss = train_loss / args.log_interval
current_lr = lr_scheduelr.get_last_lr()[0]
log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}'
logger.info(log_str, print_=False)
if args.wandb:
tensorboard_log = get_tensorboard_writer()
tensorboard_log.log_train({
'lr': current_lr,
'loss': cur_loss,
'ppl': math.exp(cur_loss),
'mins_batch': elapsed_time_per_iteration
}, global_step)
train_loss = 0
logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins')
logger.info('*' * 100)
eval_loss += evaluate(engine, args, logger, global_step)
save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step)
eval_loss /= len(os.listdir(args.data_path_prefix))
logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \
f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
logger.info('-' * 100)
if args.wandb and torch.distributed.get_rank() == 0:
tensorboard_log = get_tensorboard_writer()
tensorboard_log.log_eval({
'all_eval_shard_loss': eval_loss,
}, epoch)
start_shard = 0
eval_loss = 0
pretrain_dataset_provider.release_shard()
logger.info('Congratulation, training has finished!!!')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,46 @@
import time
import wandb
import os
from torch.utils.tensorboard import SummaryWriter
class WandbLog:
@classmethod
def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None):
wandb.init(project=project, notes=notes, name=name, config=config)
@classmethod
def log(cls, result, model=None, gradient=None):
wandb.log(result)
if model:
wandb.watch(model)
if gradient:
wandb.watch(gradient)
class TensorboardLog:
def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None):
if not os.path.exists(location):
os.mkdir(location)
self.writer = SummaryWriter(location, comment=name)
def log_train(self, result, step):
for k, v in result.items():
self.writer.add_scalar(f'{k}/train', v, step)
def log_eval(self, result, step):
for k, v in result.items():
self.writer.add_scalar(f'{k}/eval', v, step)
def log_zeroshot(self, result, step):
for k, v in result.items():
self.writer.add_scalar(f'{k}_acc/eval', v, step)

View File

@ -0,0 +1,99 @@
import functools
import os, shutil
import torch
import psutil
from colossalai.core import global_context as gpc
def logging(s, log_path, print_=True, log_=True):
if print_:
print(s)
if log_:
with open(log_path, 'a+') as f_log:
f_log.write(s + '\n')
def get_logger(log_path, **kwargs):
return functools.partial(logging, log_path=log_path, **kwargs)
def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
if debug:
print('Debug Mode : no experiment dir created')
return functools.partial(logging, log_path=None, log_=False)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
print('Experiment dir : {}'.format(dir_path))
if scripts_to_save is not None:
script_path = os.path.join(dir_path, 'scripts')
if not os.path.exists(script_path):
os.makedirs(script_path)
for script in scripts_to_save:
dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
shutil.copyfile(script, dst_file)
return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
def get_parameters_in_billions(model, world_size=1):
gpus_per_model = world_size
approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
for model_module in model])
return approx_parameters_in_billions * gpus_per_model / (1e9)
def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1):
gpus_per_model = 1
batch_size = args.train_micro_batch_size_per_gpu
samples_per_model = batch_size * args.max_seq_length
model_replica_count = world_size / gpus_per_model
approx_parameters_in_billions = numel
elapsed_time_per_iter = iteration_time / total_iterations
samples_per_second = batch_size / elapsed_time_per_iter
#flops calculator
hidden_size = config.hidden_size
num_layers = config.num_hidden_layers
vocab_size = config.vocab_size
# General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of
# https://arxiv.org/pdf/2104.04473.pdf).
# The factor of 4 is when used with activation check-pointing,
# otherwise it will be 3.
checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12))
return samples_per_second, tflops, approx_parameters_in_billions
def synchronize():
if not torch.distributed.is_available():
return
if not torch.distributed.is_intialized():
return
world_size = torch.distributed.get_world_size()
if world_size == 1:
return
torch.distributed.barrier()
def log_args(logger, args):
logger.info('--------args----------')
message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()])
message += '\n'
message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()])
logger.info(message)
logger.info('--------args----------\n')

View File

@ -0,0 +1,126 @@
import time
import torch
from .WandbLog import TensorboardLog
_GLOBAL_TIMERS = None
_GLOBAL_TENSORBOARD_WRITER = None
def set_global_variables(launch_time, tensorboard_path):
_set_timers()
_set_tensorboard_writer(launch_time, tensorboard_path)
def _set_timers():
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers()
def _set_tensorboard_writer(launch_time, tensorboard_path):
"""Set tensorboard writer."""
global _GLOBAL_TENSORBOARD_WRITER
_ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
'tensorboard writer')
if torch.distributed.get_rank() == 0:
_GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time)
def get_timers():
"""Return timers."""
_ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
return _GLOBAL_TIMERS
def get_tensorboard_writer():
"""Return tensorboard writer. It can be None so no need
to check if it is initialized."""
return _GLOBAL_TENSORBOARD_WRITER
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, '{} is not initialized.'.format(name)
def _ensure_var_is_not_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is None, '{} is already initialized.'.format(name)
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
# assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '-time', value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)

View File

@ -0,0 +1,31 @@
import os
import logging
import torch.distributed as dist
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
class Logger():
def __init__(self, log_path, cuda=False, debug=False):
self.logger = logging.getLogger(__name__)
self.cuda = cuda
self.log_path = log_path
self.debug = debug
def info(self, message, log_=True, print_=True, *args, **kwargs):
if (self.cuda and dist.get_rank() == 0) or not self.cuda:
if print_:
self.logger.info(message, *args, **kwargs)
if log_:
with open(self.log_path, 'a+') as f_log:
f_log.write(message + '\n')
def error(self, message, *args, **kwargs):
self.logger.error(message, *args, **kwargs)