[example] reorganize for community examples (#3557)

This commit is contained in:
binmakeswell 2023-04-14 16:27:48 +08:00 committed by GitHub
parent 1a809eddaa
commit f1b3d60cae
31 changed files with 785 additions and 844 deletions

View File

@ -10,9 +10,12 @@
## Overview
This folder provides several examples accelerated by Colossal-AI.
Folders such as `images` and `language` include a wide range of deep learning tasks and applications.
The `community` folder aims to create a collaborative platform for developers to contribute exotic features built on top of Colossal-AI.
The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI.
You can find applications such as Chatbot, AIGC and Biomedicine in the [Applications](https://github.com/hpcaitech/ColossalAI/tree/main/applications) directory.
## Folder Structure
@ -52,3 +55,10 @@ Therefore, it is essential for the example contributors to know how to integrate
2. Configure your testing parameters in `test_ci.sh`, such as the number of steps and the batch size. Keep these parameters small so that each example only takes several minutes.
3. Export your dataset path with the prefix `/data` and make sure you have a copy of the dataset in the `/data/scratch/examples-data` directory on the CI machine. Community contributors can contact us via Slack to request that the dataset be downloaded to the CI machine.
4. Implement the logic such as dependency setup and example execution.
## Community Dependencies
We are happy to introduce the following community repositories that are powered by Colossal-AI:
- [lightning-ColossalAI](https://github.com/Lightning-AI/lightning)
- [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion)
- [KoChatGPT](https://github.com/airobotlab/KoChatGPT)
- [minichatgpt](https://github.com/juncongmoo/minichatgpt)

View File

@ -0,0 +1,28 @@
# Community Examples
Community-driven Examples is an initiative that allows users to share their own examples with the Colossal-AI community, fostering a sense of community and making it easy for others to access and benefit from shared work. The primary goal of community-driven examples is to have a community-maintained collection of diverse and exotic functionalities built on top of the Colossal-AI package.
If a community example doesn't work as expected, you can [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) and @ the author to report it.
| Example | Description | Code Example | Colab |Author |
|:------------------|:---------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------|:-----------------------------------------|-----------------------------------------------------:|
| RoBERTa | Adding RoBERTa for SFT and Prompts model training | [RoBERTa](./roberta) | - | [YY Lin](https://github.com/yynil) (Moore Threads) |
| TransformerEngine FP8 | Adding TransformerEngine with FP8 training | [TransformerEngine FP8](./fp8) | - | [Kirthi Shankar Sivamani](https://github.com/ksivaman) (NVIDIA) |
|...|...|...|...|...|
## Looking for Examples
* [Swin-Transformer](https://github.com/microsoft/Swin-Transformer)
* [T-5](https://github.com/google-research/text-to-text-transfer-transformer)
* [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
* [ControlNet](https://github.com/lllyasviel/ControlNet)
* [Consistency Models](https://github.com/openai/consistency_models)
* [MAE](https://github.com/facebookresearch/mae)
* [CLIP](https://github.com/openai/CLIP)
You are welcome to [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) to share your insights and needs.
## How to get involved
To join our community-driven initiative, please visit the [Colossal-AI examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples), review the provided information, and explore the codebase.
To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review it and provide feedback. If you are confident, you can also submit a PR directly. We look forward to collaborating with you on this exciting project!

View File

@ -3,12 +3,13 @@
# See LICENSE for license information.
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
try:
from transformer_engine import pytorch as te
@ -18,6 +19,7 @@ except (ImportError, ModuleNotFoundError):
class Net(nn.Module):
def __init__(self, use_te=False):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
@ -62,12 +64,10 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print(f"Train Epoch: {epoch} "
f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}")
if args.dry_run:
break
@ -83,6 +83,7 @@ def calibrate(model, device, test_loader):
with te.fp8_autocast(enabled=False, calibrating=True):
output = model(data)
def test(model, device, test_loader, use_fp8):
"""Testing function."""
model.eval()
@ -93,21 +94,15 @@ def test(model, device, test_loader, use_fp8):
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(f"\nTest set: Average loss: {test_loss:.4f}, "
f"Accuracy: {correct}/{len(test_loader.dataset)} "
f"({100. * correct / len(test_loader.dataset):.0f}%)\n")
def main():
@ -154,9 +149,7 @@ def main():
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
@ -170,15 +163,12 @@ def main():
default=False,
help="For Saving the current Model",
)
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only")
parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine")
args = parser.parse_args()
use_cuda = torch.cuda.is_available()
@ -205,9 +195,7 @@ def main():
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
@ -227,7 +215,7 @@ def main():
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer)
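
For readers new to Transformer Engine, here is a minimal sketch of the FP8 pattern the example above relies on: run the forward pass of TE modules inside `te.fp8_autocast`, then back-propagate as usual. It assumes the `transformer_engine` package and an FP8-capable GPU (e.g. H100) are available; the layer sizes and optimizer are placeholders, not taken from the example.

```python
import torch
from transformer_engine import pytorch as te

# te.Linear is a drop-in replacement for nn.Linear whose GEMMs can run in FP8
# when the forward pass is executed inside an fp8_autocast region.
model = te.Linear(768, 768).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

data = torch.randn(16, 768, device="cuda")
target = torch.randn(16, 768, device="cuda")

with te.fp8_autocast(enabled=True):   # forward pass in FP8
    output = model(data)
loss = torch.nn.functional.mse_loss(output, target)
loss.backward()                       # backward reuses the recorded FP8 scaling factors
optimizer.step()
```

Keeping the loss computation outside the `fp8_autocast` region, as above, mirrors how the MNIST example wraps only the model call.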

View File

@ -1,20 +1,22 @@
import collections
import logging
import os
import random
import time
from enum import IntEnum
from random import choice
import jieba
import torch
jieba.setLogLevel(logging.CRITICAL)
import re
import mask
import numpy as np
PAD = 0
MaskedLMInstance = collections.namedtuple("MaskedLMInstance", ["index", "label"])
def map_to_numpy(data):
@ -22,6 +24,7 @@ def map_to_numpy(data):
class PreTrainingDataset():
def __init__(self,
tokenizer,
max_seq_length,
@ -43,14 +46,12 @@ class PreTrainingDataset():
self.mlm_tamper_p = 0.05
self.mlm_maintain_p = 0.1
def tokenize(self, doc):
temp = []
for d in doc:
temp.append(self.tokenizer.tokenize(d))
return temp
def create_training_instance(self, instance):
is_next = 1
raw_text_list = self.get_new_segment(instance)
@ -83,8 +84,9 @@ class PreTrainingDataset():
# Get Masked LM predictions
if self.backend == 'c++':
output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq,
self.masked_lm_prob)
elif self.backend == 'python':
output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)
@ -105,14 +107,12 @@ class PreTrainingDataset():
map_to_numpy([is_next])
])
def create_masked_lm_predictions(self, tokens):
cand_indexes = []
for i, token in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
@ -122,9 +122,7 @@ class PreTrainingDataset():
random.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))
masked_lms = []
covered_indexes = set()
@ -145,13 +143,10 @@ class PreTrainingDataset():
masked_token = tokens[index]
# 10% replace w/ random word
else:
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLMInstance(index=index, label=tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
@ -160,7 +155,6 @@ class PreTrainingDataset():
return (output_tokens, masked_lm_output)
def get_new_segment(self, segment):
"""
Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word.
@ -180,10 +174,10 @@ class PreTrainingDataset():
for length in range(3, 0, -1):
if i + length > len(segment):
continue
if ''.join(segment[i:i + length]) in seq_cws_dict:
new_segment.append(segment[i])
for l in range(1, length):
new_segment.append('##' + segment[i + l])
i += length
has_add = True
break
@ -192,7 +186,6 @@ class PreTrainingDataset():
i += 1
return new_segment
def create_whole_masked_lm_predictions(self, tokens):
"""Creates the predictions for the masked LM objective."""
@ -209,18 +202,16 @@ class PreTrainingDataset():
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
random.shuffle(cand_indexes)
output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens]  # strip the "##" prefix
num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))
masked_lms = []
covered_indexes = set()
@ -248,14 +239,18 @@ class PreTrainingDataset():
else:
# 10% of the time, keep original
if random.random() < 0.5:
masked_token = tokens[index][2:] if len(self.whole_rec.findall(
tokens[index])) > 0 else tokens[index]  # strip the "##" prefix
# 10% of the time, replace with random word
else:
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(
MaskedLMInstance(
index=index,
label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_output = [-1] * len(output_tokens)
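
The Python and C++ backends above implement the same whole-word masking idea: WordPiece tokens that start with `##` are grouped with the preceding token, and a whole group is either masked or left intact. The following is a stripped-down sketch of just the grouping-and-masking step; the standalone function name is hypothetical and the 80/10/10 replacement policy of the real implementation is omitted for brevity.

```python
import random

def whole_word_mask(tokens, mask_prob=0.15, mask_token="[MASK]"):
    """Toy whole-word masking: '##' continuations are grouped with their head token."""
    groups = []
    for i, tok in enumerate(tokens):
        if tok in ("[CLS]", "[SEP]"):
            continue
        if groups and tok.startswith("##"):
            groups[-1].append(i)     # continuation piece joins the previous word
        else:
            groups.append([i])       # start of a new word

    random.shuffle(groups)
    num_to_predict = max(1, int(round(len(tokens) * mask_prob)))

    output = list(tokens)
    labels = [-1] * len(tokens)      # -1 marks unmasked positions, as in the dataset above
    masked = 0
    for group in groups:
        if masked + len(group) > num_to_predict:
            continue
        for i in group:              # mask every piece of the chosen word
            labels[i] = tokens[i]
            output[i] = mask_token
        masked += len(group)
    return output, labels

tokens = ["[CLS]", "play", "##ing", "foot", "##ball", "[SEP]"]
print(whole_word_mask(tokens, mask_prob=0.5))
```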

View File

@ -0,0 +1,190 @@
#include <math.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
#include <chrono>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace py = pybind11;
const int32_t LONG_SENTENCE_LEN = 512;
struct MaskedLMInstance {
int index;
std::string label;
MaskedLMInstance(int index, std::string label) {
this->index = index;
this->label = label;
}
};
auto get_new_segment(
std::vector<std::string> segment, std::vector<std::string> segment_jieba,
const std::vector<bool> chinese_vocab) { // const
// std::unordered_set<std::string>
// &chinese_vocab
std::unordered_set<std::string> seq_cws_dict;
for (auto word : segment_jieba) {
seq_cws_dict.insert(word);
}
int i = 0;
std::vector<std::string> new_segment;
int segment_size = segment.size();
while (i < segment_size) {
if (!chinese_vocab[i]) { // chinese_vocab.find(segment[i]) ==
// chinese_vocab.end()
new_segment.emplace_back(segment[i]);
i += 1;
continue;
}
bool has_add = false;
for (int length = 3; length >= 1; length--) {
if (i + length > segment_size) {
continue;
}
std::string chinese_word = "";
for (int j = i; j < i + length; j++) {
chinese_word += segment[j];
}
if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
new_segment.emplace_back(segment[i]);
for (int j = i + 1; j < i + length; j++) {
new_segment.emplace_back("##" + segment[j]);
}
i += length;
has_add = true;
break;
}
}
if (!has_add) {
new_segment.emplace_back(segment[i]);
i += 1;
}
}
return new_segment;
}
bool startsWith(const std::string &s, const std::string &sub) {
return s.find(sub) == 0 ? true : false;
}
auto create_whole_masked_lm_predictions(
std::vector<std::string> &tokens,
const std::vector<std::string> &original_tokens,
const std::vector<std::string> &vocab_words,
std::map<std::string, int> &vocab, const int max_predictions_per_seq,
const double masked_lm_prob) {
// for (auto item : vocab) {
// std::cout << "key=" << std::string(py::str(item.first)) << ", "
// << "value=" << std::string(py::str(item.second)) <<
// std::endl;
// }
std::vector<std::vector<int> > cand_indexes;
std::vector<int> cand_temp;
int tokens_size = tokens.size();
std::string prefix = "##";
bool do_whole_masked = true;
for (int i = 0; i < tokens_size; i++) {
if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
continue;
}
if (do_whole_masked && (cand_indexes.size() > 0) &&
(tokens[i].rfind(prefix, 0) == 0)) {
cand_temp.emplace_back(i);
} else {
if (cand_temp.size() > 0) {
cand_indexes.emplace_back(cand_temp);
}
cand_temp.clear();
cand_temp.emplace_back(i);
}
}
auto seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(cand_indexes.begin(), cand_indexes.end(),
std::default_random_engine(seed));
// for (auto i : cand_indexes) {
// for (auto j : i) {
// std::cout << tokens[j] << " ";
// }
// std::cout << std::endl;
// }
// for (auto i : output_tokens) {
// std::cout << i;
// }
// std::cout << std::endl;
int num_to_predict = std::min(max_predictions_per_seq,
std::max(1, int(tokens_size * masked_lm_prob)));
// std::cout << num_to_predict << std::endl;
std::set<int> covered_indexes;
std::vector<int> masked_lm_output(tokens_size, -1);
int vocab_words_len = vocab_words.size();
std::default_random_engine e(seed);
std::uniform_real_distribution<double> u1(0.0, 1.0);
std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
int mask_cnt = 0;
std::vector<std::string> output_tokens;
output_tokens = original_tokens;
for (auto index_set : cand_indexes) {
if (mask_cnt > num_to_predict) {
break;
}
int index_set_size = index_set.size();
if (mask_cnt + index_set_size > num_to_predict) {
continue;
}
bool is_any_index_covered = false;
for (auto index : index_set) {
if (covered_indexes.find(index) != covered_indexes.end()) {
is_any_index_covered = true;
break;
}
}
if (is_any_index_covered) {
continue;
}
for (auto index : index_set) {
covered_indexes.insert(index);
std::string masked_token;
if (u1(e) < 0.8) {
masked_token = "[MASK]";
} else {
if (u1(e) < 0.5) {
masked_token = output_tokens[index];
} else {
int random_index = u2(e);
masked_token = vocab_words[random_index];
}
}
// masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
masked_lm_output[index] = vocab[output_tokens[index]];
output_tokens[index] = masked_token;
mask_cnt++;
}
}
// for (auto p : masked_lms) {
// masked_lm_output[p.index] = vocab[p.label];
// }
return std::make_tuple(output_tokens, masked_lm_output);
}
PYBIND11_MODULE(mask, m) {
m.def("create_whole_masked_lm_predictions",
&create_whole_masked_lm_predictions);
m.def("get_new_segment", &get_new_segment);
}
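
The Python preprocessing code above imports this extension as `import mask`. For anyone rebuilding it, a minimal `setup.py` sketch using pybind11's packaged helpers could look like the following; the source file name `mask.cpp` and the compiler flag are assumptions, not taken from the repository.

```python
# setup.py (illustrative sketch, assuming the C++ source above is saved as mask.cpp)
from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import setup

ext_modules = [
    Pybind11Extension(
        "mask",                      # module name used by `import mask`
        ["mask.cpp"],
        extra_compile_args=["-O3"],  # optimization flag; adjust per compiler
    ),
]

setup(name="mask", ext_modules=ext_modules, cmdclass={"build_ext": build_ext})
```

Running `python setup.py build_ext --inplace` would drop the shared object next to the Python scripts so that `import mask` resolves.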

View File

@ -1,13 +1,14 @@
import argparse
import functools
import json
import multiprocessing
import os
import re
import time
from typing import List
from tqdm import tqdm
def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
sent_list = []
@ -17,7 +18,8 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document)
elif flag == "en":
document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n',
document) # Special quotation marks
else:
document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
@ -43,9 +45,7 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
return sent_list
def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None:
workers = 32
@ -53,7 +53,7 @@ def get_sent(output_path,
input_path = input_path[:-1]
cur_path = os.path.join(output_path, str(host) + '.txt')
new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)
with open(cur_path, 'w', encoding='utf-8') as f:
for fi, fin_path in enumerate(fin_list):
if not os.path.exists(os.path.join(input_path, fin_path[0])):
@ -136,11 +136,7 @@ if __name__ == '__main__':
start = time.time()
for index, shard in enumerate(real_shard):
get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len)
print(f'cost {str(time.time() - start)}')
# if you have multiple servers, you can use the code below or adapt it to OpenMPI
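
A quick sanity check of `split_sentence` (hypothetical snippet; the module name `sentence_split` and the input strings are made up) shows how the `flag` argument switches between Chinese and English punctuation handling:

```python
# Hypothetical usage; the module name "sentence_split" is an assumption.
from sentence_split import split_sentence

zh_doc = "今天天气很好。我们去公园散步!"
en_doc = "It works. Does it split? Yes!"

print(split_sentence(zh_doc, flag="zh"))   # one Chinese sentence per list element
print(split_sentence(en_doc, flag="en"))   # one English sentence per list element
```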

View File

@ -1,19 +1,19 @@
import argparse
import multiprocessing
import os
import socket
import time
from random import shuffle
import h5py
import numpy as np
import psutil
from get_mask import PreTrainingDataset
from tqdm import tqdm
from transformers import AutoTokenizer
def get_raw_instance(document, max_sequence_length=512):
"""
Get the initial training instances, split the whole segment into multiple parts according to the max_sequence_length, and return as multiple processed instances.
:param document: document
@ -37,7 +37,7 @@ def get_raw_instance(document, max_sequence_length=512):
if len(curr_seq) > 0:
result_list.append(curr_seq)
curr_seq = []
result_list.append(document[sz_idx][:max_sequence_length_allowed])
sz_idx += 1
else:
result_list.append(curr_seq)
@ -70,8 +70,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
# document = line
# if len(document.split("<sep>")) <= 3:
# continue
if len(line) > 0 and line[:2] == "]]": # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
@ -84,8 +83,8 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
# print(len(documents))
# print(len(documents[0]))
# print(documents[0][0:10])
import multiprocessing
from typing import List
ans = []
for docs in tqdm(documents):
@ -124,13 +123,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
del instances
def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
print(f'{file_name}.h5 exists')
@ -144,8 +137,7 @@ def split_numpy_chunk_pool(input_path,
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
if len(line) > 0 and line[:2] == "]]": # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
@ -212,11 +204,21 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--max_predictions_per_seq',
type=int,
default=80,
help='number of shards, e.g., 10, 50, or 100')
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
parser.add_argument('--backend',
type=str,
default='python',
help='backend of mask token, python, c++, numpy respectively')
parser.add_argument(
'--dupe_factor',
type=int,
default=1,
help='specifies how many times the preprocessor repeats to create the input from the same article/document')
parser.add_argument('--worker', type=int, default=32, help='number of process')
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
args = parser.parse_args()
@ -227,7 +229,6 @@ if __name__ == '__main__':
args.backend,
max_predictions_per_seq=args.max_predictions_per_seq)
data_len = len(os.listdir(args.input_path))
for i in range(data_len):
@ -235,15 +236,10 @@ if __name__ == '__main__':
if os.path.exists(input_path):
start = time.time()
print(f'process {input_path}')
split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor,
args.seq_len, i)
end_ = time.time()
print(u'memory%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
print(f'has cost {(end_ - start) / 60}')
print('-' * 100)
print('')
@ -269,5 +265,3 @@ if __name__ == '__main__':
# print(f'has cost {(end_ - start) / 60}')
# print('-' * 100)
# print('')

View File

@ -21,4 +21,3 @@ bash run_pretrain_resume.sh
* `--resume_train`: whether to resume training
* `--load_pretrain_model`: absolute path which contains model checkpoint
* `--load_optimizer_lr`: absolute path which contains optimizer checkpoint

View File

@ -0,0 +1,87 @@
from numpy import require
import colossalai
__all__ = ['parse_args']
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--distplan",
type=str,
default='CAI_Gemini',
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
action='store_true',
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
parser.add_argument('--lr', type=float, required=True, help='initial learning rate')
parser.add_argument('--epoch', type=int, required=True, help='number of epoch')
parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus")
parser.add_argument('--eval_data_path_prefix',
type=str,
required=True,
help='location of the evaluation data corpus')
parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer')
parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length')
parser.add_argument('--refresh_bucket_size',
type=int,
default=1,
help="This param makes sure that a certain task is repeated for this time steps to \
optimise on the back propogation speed with APEX's DistributedDataParallel")
parser.add_argument("--max_predictions_per_seq",
"--max_pred",
default=80,
type=int,
help="The maximum number of masked tokens in a sequence to be predicted.")
parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation_steps")
parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size")
parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size")
parser.add_argument("--num_workers", default=8, type=int, help="")
parser.add_argument("--async_worker", action='store_true', help="")
parser.add_argument("--bert_config", required=True, type=str, help="location of config.json")
parser.add_argument("--wandb", action='store_true', help="use wandb to watch model")
parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name")
parser.add_argument("--log_interval", default=100, type=int, help="report interval")
parser.add_argument("--log_path", type=str, required=True, help="log file which records train step")
parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file")
parser.add_argument("--colossal_config",
type=str,
required=True,
help="colossal config, which contains zero config and so on")
parser.add_argument("--ckpt_path",
type=str,
required=True,
help="location of saving checkpoint, which contains model and optimizer")
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug")
parser.add_argument('--load_pretrain_model', default='', type=str, help="location of model's checkpoint")
parser.add_argument(
'--load_optimizer_lr',
default='',
type=str,
help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
parser.add_argument('--resume_train', action='store_true', help="whether to resume training from an earlier checkpoint")
parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta")
parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing")
args = parser.parse_args()
return args
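
To make the long argument list easier to digest, here is a hypothetical minimal invocation of `parse_args`; the module name `arguments` and every path below are placeholders, and `colossalai` must be installed since the parser is built from `colossalai.get_default_parser()` (which may add its own flags such as `--host`/`--port`).

```python
import sys
from arguments import parse_args  # assumed module name for the file above

# Only the required flags are listed; all values are illustrative placeholders.
sys.argv = [
    "run_pretraining.py",
    "--lr", "1e-4",
    "--epoch", "1",
    "--data_path_prefix", "/data/train",
    "--eval_data_path_prefix", "/data/eval",
    "--tokenizer_path", "/data/tokenizer",
    "--train_micro_batch_size_per_gpu", "4",
    "--eval_micro_batch_size_per_gpu", "4",
    "--bert_config", "/data/bert/config.json",
    "--log_path", "./train.log",
    "--tensorboard_path", "./tensorboard",
    "--colossal_config", "./colossal_config.py",
    "--ckpt_path", "./checkpoints",
]
args = parse_args()
print(args.distplan, args.lr, args.max_predictions_per_seq)
```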

View File

@ -1,4 +1,5 @@
class BertDatasetProviderInterface:
def get_shard(self, index, shuffle=True):
raise NotImplementedError

View File

@ -1,9 +1,11 @@
import math
import os
import torch
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from tqdm import tqdm
from utils.global_vars import get_tensorboard_writer, get_timers
def evaluate(model, args, logger, global_step, criterion):
evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True)
@ -25,7 +27,10 @@ def evaluate(model, args, logger, global_step, criterion):
dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard)
# evaluate_dataset_provider.prefetch_shard(shard + 1)
if torch.distributed.get_rank() == 0:
iterator_data = tqdm(enumerate(dataset_iterator),
total=(total_length // args.eval_micro_batch_size_per_gpu // world_size),
colour='MAGENTA',
smoothing=1)
else:
iterator_data = enumerate(dataset_iterator)
@ -41,7 +46,7 @@ def evaluate(model, args, logger, global_step, criterion):
output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
loss = criterion(output.logits, mlm_label) # prediction_scores
evaluate_dataset_provider.prefetch_batch()
eval_loss += loss.float().item()
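
The `criterion(output.logits, mlm_label)` call above pairs model logits with labels in which unmasked positions are marked `-1` (see the dataset code earlier). A criterion consistent with that convention is a cross-entropy loss that ignores `-1`; the sketch below is an assumption about how it could be constructed, with a placeholder vocabulary size.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=-1)    # -1 = position was not masked

vocab_size = 21128                                   # placeholder vocab size (assumption)
logits = torch.randn(2, 8, vocab_size)               # [batch, seq_len, vocab]
mlm_label = torch.full((2, 8), -1, dtype=torch.long)
mlm_label[0, 3] = 100                                # one masked position keeps its original token id

loss = criterion(logits.view(-1, vocab_size), mlm_label.view(-1))
print(loss.item())
```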

View File

@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch BERT model."""
import math
import os
import warnings
@ -27,7 +26,6 @@ import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@ -41,6 +39,7 @@ from transformers.modeling_outputs import (
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bert.configuration_bert import BertConfig
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import (
ModelOutput,
@ -50,8 +49,6 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
logger = logging.get_logger(__name__)
@ -62,8 +59,7 @@ _TOKENIZER_FOR_DOC = "BertTokenizer"
# TokenClassification docstring
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
_TOKEN_CLASS_EXPECTED_OUTPUT = (
"['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ")
_TOKEN_CLASS_EXPECTED_LOSS = 0.01
# QuestionAnswering docstring
@ -78,7 +74,6 @@ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-pol
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bert-base-uncased",
"bert-large-uncased",
@ -114,10 +109,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
import numpy as np
import tensorflow as tf
except ImportError:
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
@ -135,10 +128,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name):
logger.info(f"Skipping {'/'.join(name)}")
continue
pointer = model
@ -218,7 +209,7 @@ class BertEmbeddings(nn.Module):
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length]
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@ -245,13 +236,12 @@ class BertEmbeddings(nn.Module):
class BertSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})")
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@ -262,9 +252,7 @@ class BertSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@ -372,6 +360,7 @@ class BertSelfAttention(nn.Module):
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@ -386,6 +375,7 @@ class BertSelfOutput(nn.Module):
class BertAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
@ -395,9 +385,8 @@ class BertAttention(nn.Module):
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads,
self.self.attention_head_size, self.pruned_heads)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
@ -435,6 +424,7 @@ class BertAttention(nn.Module):
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@ -450,6 +440,7 @@ class BertIntermediate(nn.Module):
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@ -464,6 +455,7 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@ -511,8 +503,7 @@ class BertLayer(nn.Module):
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`")
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
@ -532,9 +523,8 @@ class BertLayer(nn.Module):
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward,
self.seq_len_dim, attention_output)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
@ -550,6 +540,7 @@ class BertLayer(nn.Module):
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
@ -585,11 +576,11 @@ class BertEncoder(nn.Module):
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
@ -626,17 +617,13 @@ class BertEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
@ -647,6 +634,7 @@ class BertEncoder(nn.Module):
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@ -662,6 +650,7 @@ class BertPooler(nn.Module):
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@ -679,6 +668,7 @@ class BertPredictionHeadTransform(nn.Module):
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
@ -699,6 +689,7 @@ class BertLMPredictionHead(nn.Module):
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
@ -709,6 +700,7 @@ class BertOnlyMLMHead(nn.Module):
class BertOnlyNSPHead(nn.Module):
def __init__(self, config):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
@ -719,6 +711,7 @@ class BertOnlyNSPHead(nn.Module):
class BertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
@ -950,9 +943,8 @@ class BertModel(BertPreTrainedModel):
`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
@ -1051,6 +1043,7 @@ class BertModel(BertPreTrainedModel):
BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@ -1151,9 +1144,8 @@ class BertForPreTraining(BertPreTrainedModel):
)
@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""",
BERT_START_DOCSTRING)
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
@ -1298,10 +1290,8 @@ class BertForMaskedLM(BertPreTrainedModel):
super().__init__(config)
if config.is_decoder:
logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention.")
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
@ -1390,9 +1380,10 @@ class BertForMaskedLM(BertPreTrainedModel):
raise ValueError("The PAD token should be defined for generation")
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
dummy_token = torch.full((effective_batch_size, 1),
self.config.pad_token_id,
dtype=torch.long,
device=input_ids.device)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@ -1403,6 +1394,7 @@ class BertForMaskedLM(BertPreTrainedModel):
BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@ -1508,15 +1500,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.bert = BertModel(config)
classifier_dropout = (config.classifier_dropout
if config.classifier_dropout is not None else config.hidden_dropout_prob)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
@ -1612,13 +1604,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
BERT_START_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config)
classifier_dropout = (config.classifier_dropout
if config.classifier_dropout is not None else config.hidden_dropout_prob)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, 1)
@ -1658,11 +1650,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = ( inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) if inputs_embeds is not None else None)
if inputs_embeds is not None
else None
)
outputs = self.bert( outputs = self.bert(
input_ids, input_ids,
@ -1715,9 +1704,8 @@ class BertForTokenClassification(BertPreTrainedModel):
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.bert = BertModel(config, add_pooling_layer=False) self.bert = BertModel(config, add_pooling_layer=False)
classifier_dropout = ( classifier_dropout = (config.classifier_dropout
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob if config.classifier_dropout is not None else config.hidden_dropout_prob)
)
self.dropout = nn.Dropout(classifier_dropout) self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)


@@ -23,7 +23,7 @@ import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from transformers import FillMaskPipeline, T5ForConditionalGeneration, T5Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutput,
@@ -34,10 +34,14 @@ from transformers.modeling_outputs import (
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config
from transformers.pytorch_utils import softmax_backward_data
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)

logger = logging.get_logger(__name__)
@@ -55,6 +59,7 @@ DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.deberta.modeling_deberta.ContextPooler
class ContextPooler(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
@@ -133,15 +138,15 @@ class XSoftmax(torch.autograd.Function):
            g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
        )
        output = masked_fill(g, self, r_mask,
                             g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)))
        output = softmax(g, output, dim)
        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
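The XSoftmax helper above is DeBERTa's masked softmax in its autograd/ONNX-export form. For readers following the change, a minimal eager-mode sketch of the same idea is shown below; it is illustrative only and not code from this commit.

import torch

def masked_softmax_sketch(scores: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Positions where mask == 0 are excluded: fill them with the dtype minimum,
    # softmax over the requested dim, then zero the excluded positions again.
    rmask = ~mask.bool()
    scores = scores.masked_fill(rmask, torch.finfo(scores.dtype).min)
    probs = torch.softmax(scores, dim=dim)
    return probs.masked_fill(rmask, 0.0)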
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
class DropoutContext(object):

    def __init__(self):
        self.dropout = 0
        self.mask = None
@@ -244,6 +249,7 @@ class StableDropout(nn.Module):
# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
class DebertaV2SelfOutput(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -259,6 +265,7 @@ class DebertaV2SelfOutput(nn.Module):
# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
class DebertaV2Attention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.self = DisentangledSelfAttention(config)
@@ -296,6 +303,7 @@ class DebertaV2Attention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
class DebertaV2Intermediate(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -312,6 +320,7 @@ class DebertaV2Intermediate(nn.Module):
# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
class DebertaV2Output(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -328,6 +337,7 @@ class DebertaV2Output(nn.Module):
# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
class DebertaV2Layer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.attention = DebertaV2Attention(config)
@@ -362,14 +372,17 @@ class DebertaV2Layer(nn.Module):
class ConvLayer(nn.Module):

    def __init__(self, config):
        super().__init__()
        kernel_size = getattr(config, "conv_kernel_size", 3)
        groups = getattr(config, "conv_groups", 1)
        self.conv_act = getattr(config, "conv_act", "tanh")
        self.conv = nn.Conv1d(config.hidden_size,
                              config.hidden_size,
                              kernel_size,
                              padding=(kernel_size - 1) // 2,
                              groups=groups)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config
@@ -452,9 +465,10 @@ class DebertaV2Encoder(nn.Module):
    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        if self.relative_attention and relative_pos is None:
            q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
            relative_pos = build_relative_position(q,
                                                   hidden_states.size(-2),
                                                   bucket_size=self.position_buckets,
                                                   max_position=self.max_relative_positions)
        return relative_pos
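get_rel_pos relies on build_relative_position to produce a (query_len, key_len) matrix of relative offsets that is later log-bucketed. The toy sketch below assumes the usual DeBERTa convention rel[i, j] = i - j; it is illustrative and not part of this diff.

import torch

def toy_relative_position(query_len: int, key_len: int) -> torch.Tensor:
    q_ids = torch.arange(query_len).unsqueeze(1)    # shape (query_len, 1)
    k_ids = torch.arange(key_len).unsqueeze(0)      # shape (1, key_len)
    return q_ids - k_ids                            # shape (query_len, key_len)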
    def forward(
@@ -491,6 +505,7 @@ class DebertaV2Encoder(nn.Module):
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)
@@ -535,9 +550,9 @@ class DebertaV2Encoder(nn.Module):
        if not return_dict:
            return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(last_hidden_state=output_states,
                               hidden_states=all_hidden_states,
                               attentions=all_attentions)


def make_log_bucket_position(relative_pos, bucket_size, max_position):
@@ -610,10 +625,8 @@ class DisentangledSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                             f"heads ({config.num_attention_heads})")
        self.num_attention_heads = config.num_attention_heads
        _attention_head_size = config.hidden_size // config.num_attention_heads
        self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
@@ -706,28 +719,22 @@ class DisentangledSelfAttention(nn.Module):
        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
        if self.relative_attention:
            rel_embeddings = self.pos_dropout(rel_embeddings)
            rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings,
                                                       scale_factor)

        if rel_att is not None:
            attention_scores = attention_scores + rel_att
        attention_scores = attention_scores
        attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2),
                                                 attention_scores.size(-1))

        # bsz x height x length x dimension
        attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)),
                                  value_layer)
        context_layer = (context_layer.view(-1, self.num_attention_heads, context_layer.size(-2),
                                            context_layer.size(-1)).permute(0, 2, 1, 3).contiguous())
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.view(new_context_layer_shape)
        if output_attentions:
@@ -738,9 +745,10 @@ class DisentangledSelfAttention(nn.Module):
    def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
        if relative_pos is None:
            q = query_layer.size(-2)
            relative_pos = build_relative_position(q,
                                                   key_layer.size(-2),
                                                   bucket_size=self.position_buckets,
                                                   max_position=self.max_relative_positions)
        if relative_pos.dim() == 2:
            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
        elif relative_pos.dim() == 3:
@@ -758,25 +766,22 @@ class DisentangledSelfAttention(nn.Module):
        # rel_embeddings = rel_embeddings.unsqueeze(0)
        # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
        if self.share_att_key:
            pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings),
                                                        self.num_attention_heads).repeat(
                                                            query_layer.size(0) // self.num_attention_heads, 1, 1)
            pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
                query_layer.size(0) // self.num_attention_heads, 1, 1)
        else:
            if "c2p" in self.pos_att_type:
                pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings),
                                                          self.num_attention_heads).repeat(
                                                              query_layer.size(0) // self.num_attention_heads, 1,
                                                              1)    # .split(self.all_head_size, dim=-1)
            if "p2c" in self.pos_att_type:
                pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings),
                                                            self.num_attention_heads).repeat(
                                                                query_layer.size(0) // self.num_attention_heads, 1,
                                                                1)    # .split(self.all_head_size, dim=-1)

        score = 0
        # content->position
@@ -787,7 +792,9 @@ class DisentangledSelfAttention(nn.Module):
            c2p_att = torch.gather(
                c2p_att,
                dim=-1,
                index=c2p_pos.squeeze(0).expand([query_layer.size(0),
                                                 query_layer.size(1),
                                                 relative_pos.size(-1)]),
            )
            score += c2p_att / scale
@@ -810,7 +817,9 @@ class DisentangledSelfAttention(nn.Module):
            p2c_att = torch.gather(
                p2c_att,
                dim=-1,
                index=p2c_pos.squeeze(0).expand([query_layer.size(0),
                                                 key_layer.size(-2),
                                                 key_layer.size(-2)]),
            ).transpose(-1, -2)
            score += p2c_att / scale
@@ -990,6 +999,7 @@ DEBERTA_INPUTS_DOCSTRING = r"""
)
# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
class DebertaV2Model(DebertaV2PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
@@ -1032,9 +1042,8 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
@@ -1091,7 +1100,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
        sequence_output = encoded_layers[-1]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2):]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
@@ -1182,6 +1191,7 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
class DebertaV2PredictionHeadTransform(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -1200,6 +1210,7 @@ class DebertaV2PredictionHeadTransform(nn.Module):
# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
class DebertaV2LMPredictionHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.transform = DebertaV2PredictionHeadTransform(config)
@@ -1221,6 +1232,7 @@ class DebertaV2LMPredictionHead(nn.Module):
# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
class DebertaV2OnlyMLMHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.predictions = DebertaV2LMPredictionHead(config)
@@ -1239,6 +1251,7 @@ class DebertaV2OnlyMLMHead(nn.Module):
)
# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2
class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
@@ -1318,9 +1331,8 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
                label_index = (labels >= 0).nonzero()
                labels = labels.long()
                if label_index.size(0) > 0:
                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0),
                                                                                logits.size(1)))
                    labels = torch.gather(labels, 0, label_index.view(-1))
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
@@ -1345,9 +1357,10 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss,
                                        logits=logits,
                                        hidden_states=outputs.hidden_states,
                                        attentions=outputs.attentions)


@add_start_docstrings(
@@ -1422,9 +1435,10 @@ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(loss=loss,
                                     logits=logits,
                                     hidden_states=outputs.hidden_states,
                                     attentions=outputs.attentions)


@add_start_docstrings(
@@ -1536,6 +1550,7 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
    DEBERTA_START_DOCSTRING,
)
class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
@@ -1591,11 +1606,8 @@ class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
                              if inputs_embeds is not None else None)

        outputs = self.deberta(
            flat_input_ids,


@@ -1,24 +1,25 @@
import json
import logging
import os
import random
import time
from concurrent.futures import ProcessPoolExecutor

import h5py
import numpy as np
import torch
import torch.distributed as dist
from bert_dataset_provider import BertDatasetProviderInterface
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler

import colossalai.utils as utils


# Workaround because python functions are not picklable
class WorkerInitObj(object):

    def __init__(self, seed):
        self.seed = seed
@@ -27,29 +28,25 @@ class WorkerInitObj(object):
        random.seed(self.seed + id)


def create_pretraining_dataset(input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init,
                               data_sampler):
    train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq)
    train_dataloader = DataLoader(train_data,
                                  sampler=data_sampler(train_data),
                                  batch_size=train_batch_size,
                                  num_workers=num_workers,
                                  worker_init_fn=worker_init,
                                  pin_memory=True)
    return train_dataloader, len(train_data)
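create_pretraining_dataset turns a single HDF5 shard into a DataLoader and reports the shard size. A hedged usage sketch follows; the shard path, seed and batch size are placeholders, not values taken from this repo.

from torch.utils.data.sampler import RandomSampler

worker_init = WorkerInitObj(seed=42)                          # placeholder seed
dataloader, num_samples = create_pretraining_dataset(
    input_file="/data/pretrain/shard_000.h5",                 # placeholder shard path
    max_predictions_per_seq=80,
    num_workers=4,
    train_batch_size=16,
    worker_init=worker_init,
    data_sampler=RandomSampler)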
class pretraining_dataset(Dataset):

    def __init__(self, input_file, max_predictions_per_seq):
        self.input_file = input_file
        self.max_predictions_per_seq = max_predictions_per_seq
        f = h5py.File(input_file, "r")
        keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions']
        self.inputs = [np.asarray(f[key][:]) for key in keys]
        f.close()
@@ -59,21 +56,16 @@ class pretraining_dataset(Dataset):
    def __getitem__(self, index):
        [input_ids, input_mask, segment_ids, masked_lm_labels] = [
            torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy(
                np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs)
        ]

        return [input_ids, input_mask, segment_ids, masked_lm_labels]


class NvidiaBertDatasetProvider(BertDatasetProviderInterface):

    def __init__(self, args, evaluate=False):
        self.num_workers = args.num_workers
        self.max_seq_length = args.max_seq_length
@@ -92,13 +84,15 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
        # Initialize dataset files
        if not evaluate:
            self.dataset_files = [
                os.path.join(args.data_path_prefix, f)
                for f in os.listdir(args.data_path_prefix)
                if os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f
            ]
        else:
            self.dataset_files = [
                os.path.join(args.eval_data_path_prefix, f)
                for f in os.listdir(args.eval_data_path_prefix)
                if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f
            ]

        self.dataset_files.sort()
@@ -114,9 +108,7 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
        self.shuffle = True

        if self.global_rank == 0:
            self.logger.info(f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}")

    def get_shard(self, index):
        start = time.time()
@@ -130,8 +122,7 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
                                                                            worker_init=self.worker_init,
                                                                            data_sampler=self.data_sampler)
        else:
            self.train_dataloader, sample_count = self.dataset_future.result(timeout=None)

        self.logger.info(
            f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s."
@@ -145,10 +136,8 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
    def prefetch_shard(self, index):
        self.data_file = self._get_shard_file(index)
        self.dataset_future = self.pool.submit(create_pretraining_dataset, self.data_file, self.max_predictions_per_seq,
                                               self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init,
                                               self.data_sampler)

    def get_batch(self, batch_iter):
@@ -179,4 +168,3 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
            indices = torch.randperm(self.num_files, generator=g).tolist()
            new_dataset = [self.dataset_files[i] for i in indices]
            self.dataset_files = new_dataset


@@ -1,23 +1,32 @@
import logging
import os
import sys

import torch
import transformers
from torch.optim import AdamW
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertForPreTraining,
    GPT2Config,
    GPT2LMHeadModel,
    RobertaConfig,
    RobertaForMaskedLM,
    get_linear_schedule_with_warmup,
)

from colossalai.core import global_context as gpc
from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.nn.optimizer import FusedAdam, HybridAdam

sys.path.append(os.getcwd())
from collections import OrderedDict

import torch.nn as nn
from model.bert import BertForMaskedLM
from model.deberta_v2 import DebertaV2ForMaskedLM

__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']
@@ -30,6 +39,7 @@ def get_new_state_dict(state_dict, start_index=13):
class LMModel(nn.Module):

    def __init__(self, model, config, args):
        super().__init__()
@@ -58,9 +68,11 @@ def get_model(args, logger):
    if len(args.load_pretrain_model) > 0:
        assert os.path.exists(args.load_pretrain_model)
        # load_checkpoint(args.load_pretrain_model, model, strict=False)
        m_state_dict = torch.load(args.load_pretrain_model,
                                  map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
        # new_state_dict = get_new_state_dict(m_state_dict)
        model.load_state_dict(m_state_dict,
                              strict=True)    # must insure that every process have identical parameters !!!!!!!
        logger.info("load model success")

    numel = sum([p.numel() for p in model.parameters()])
@@ -89,7 +101,10 @@ def get_optimizer(model, lr):
def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
    # warmup_steps = int(total_steps * warmup_ratio)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=warmup_steps,
                                                   num_training_steps=total_steps,
                                                   last_epoch=last_epoch)
    # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
    return lr_scheduler
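get_lr_scheduler wraps transformers' linear warmup-then-linear-decay schedule. A small illustrative sketch of how it is stepped; the optimizer and step counts below are placeholders, not values from this example.

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = get_lr_scheduler(optimizer, total_steps=10000, warmup_steps=100)

for step in range(300):
    optimizer.step()
    scheduler.step()    # lr rises linearly for 100 steps, then decays linearly towards 0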
@@ -107,6 +122,3 @@ def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
    if gpc.get_global_rank() == 0:
        torch.save(checkpoint, optimizer_lr_path)
        torch.save(model_state, model_path)


@@ -35,4 +35,3 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
        --mlm bert \
        --wandb \
        --checkpoint_activations \


@@ -38,4 +38,3 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
        --resume_train \
        --load_pretrain_model /ckpt/1.pt \
        --load_optimizer_lr /ckpt/1.op_lrs \


@@ -4,21 +4,6 @@ import time
from functools import partial

import torch
from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
@@ -30,6 +15,15 @@ from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calcul
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer


def main():
@@ -156,14 +150,15 @@ def main():
        start_epoch = o_l_state_dict['epoch']
        start_shard = o_l_state_dict['shard'] + 1
        # global_step = o_l_state_dict['global_step'] + 1
        logger.info(
            f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
        )

    criterion = LossForPretraining(config.vocab_size)

    # build dataloader
    pretrain_dataset_provider = NvidiaBertDatasetProvider(args)

    logger.info(get_mem_info(prefix='After init model, '))
    best_loss = None
@@ -242,8 +237,9 @@ def main():
                logger.info('*' * 100)
                eval_loss += evaluate(model, args, logger, global_step, criterion)
                save_ckpt(model, optimizer, lr_scheduler,
                          os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
                          shard, global_step)

            eval_loss /= len(os.listdir(args.data_path_prefix))
            logger.info(


@@ -1,8 +1,10 @@
import os
import time

import wandb
from torch.utils.tensorboard import SummaryWriter


class WandbLog:

    @classmethod
@@ -38,9 +40,3 @@ class TensorboardLog:
    def log_zeroshot(self, result, step):
        for k, v in result.items():
            self.writer.add_scalar(f'{k}_acc/eval', v, step)


@@ -1,9 +1,13 @@
import functools
import os
import shutil

import psutil
import torch

from colossalai.core import global_context as gpc


def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
@@ -11,9 +15,11 @@ def logging(s, log_path, print_=True, log_=True):
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')


def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)


def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    if debug:
        print('Debug Mode : no experiment dir created')
@@ -33,6 +39,7 @@ def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2
@@ -52,11 +59,15 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
def get_parameters_in_billions(model, world_size=1):
    gpus_per_model = world_size

    approx_parameters_in_billions = sum([
        sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement()
             for p in model_module.parameters()])
        for model_module in model
    ])

    return approx_parameters_in_billions * gpus_per_model / (1e9)


def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1):
    gpus_per_model = 1
    batch_size = args.train_micro_batch_size_per_gpu
@@ -76,10 +87,13 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations,
    # The factor of 4 is when used with activation check-pointing,
    # otherwise it will be 3.
    checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers *
                           (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) +
                                                (vocab_size / (16. * num_layers * hidden_size)))
    tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12))
    return samples_per_second, tflops, approx_parameters_in_billions
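The flops_per_iteration expression above is the usual 24·c·B·s·L·h² transformer estimate, with correction terms for attention (s / 6h) and the vocabulary projection (V / 16Lh). The worked example below uses illustrative numbers only, not values from this repo's configs.

checkpoint_activations_factor = 4          # activation checkpointing enabled
batch_size, seq_len = 16, 512
num_layers, hidden_size, vocab_size = 24, 1024, 32768

flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers *
                       (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) +
                                            (vocab_size / (16. * num_layers * hidden_size)))
elapsed_time_per_iter = 1.2                # seconds per step, illustrative
tflops = flops_per_iteration / (elapsed_time_per_iter * 10**12)
print(f"~{tflops:.1f} TFLOPS per GPU")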
def synchronize():
    if not torch.distributed.is_available():
        return
@@ -90,6 +104,7 @@ def synchronize():
        return
    torch.distributed.barrier()


def log_args(logger, args):
    logger.info('--------args----------')
    message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()])


@@ -1,5 +1,7 @@
import time

import torch

from .WandbLog import TensorboardLog

_GLOBAL_TIMERS = None
@@ -10,30 +12,34 @@ def set_global_variables(launch_time, tensorboard_path):
    _set_timers()
    _set_tensorboard_writer(launch_time, tensorboard_path)


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    _GLOBAL_TIMERS = Timers()


def _set_tensorboard_writer(launch_time, tensorboard_path):
    """Set tensorboard writer."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer')
    if torch.distributed.get_rank() == 0:
        _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time)


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
    return _GLOBAL_TIMERS


def get_tensorboard_writer():
    """Return tensorboard writer. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_TENSORBOARD_WRITER


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, '{} is not initialized.'.format(name)
@@ -115,12 +121,10 @@ class Timers:
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)


@@ -1,22 +1,22 @@
import logging
import os

import torch.distributed as dist

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


class Logger():

    def __init__(self, log_path, cuda=False, debug=False):
        self.logger = logging.getLogger(__name__)
        self.cuda = cuda
        self.log_path = log_path
        self.debug = debug

    def info(self, message, log_=True, print_=True, *args, **kwargs):
        if (self.cuda and dist.get_rank() == 0) or not self.cuda:
            if print_:
@@ -26,6 +26,5 @@ class Logger():
                with open(self.log_path, 'a+') as f_log:
                    f_log.write(message + '\n')

    def error(self, message, *args, **kwargs):
        self.logger.error(message, *args, **kwargs)


@@ -1,184 +0,0 @@
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
#include <vector>
#include <string>
#include <pybind11/stl.h>
#include <chrono>
#include <tuple>
#include <unordered_set>
#include <unordered_map>
namespace py = pybind11;
const int32_t LONG_SENTENCE_LEN = 512;
struct MaskedLMInstance {
int index;
std::string label;
MaskedLMInstance(int index, std::string label) {
this->index = index;
this->label = label;
}
};
auto get_new_segment(std::vector<std::string> segment, std::vector<std::string> segment_jieba, const std::vector<bool> chinese_vocab) { // const std::unordered_set<std::string> &chinese_vocab
std::unordered_set<std::string> seq_cws_dict;
for (auto word : segment_jieba) {
seq_cws_dict.insert(word);
}
int i = 0;
std::vector<std::string> new_segment;
int segment_size = segment.size();
while (i < segment_size) {
if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end()
new_segment.emplace_back(segment[i]);
i += 1;
continue;
}
bool has_add = false;
for (int length = 3; length >= 1; length--) {
if (i + length > segment_size) {
continue;
}
std::string chinese_word = "";
for (int j = i; j < i + length; j++) {
chinese_word += segment[j];
}
if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
new_segment.emplace_back(segment[i]);
for (int j = i + 1; j < i + length; j++) {
new_segment.emplace_back("##" + segment[j]);
}
i += length;
has_add = true;
break;
}
}
if (!has_add) {
new_segment.emplace_back(segment[i]);
i += 1;
}
}
return new_segment;
}
bool startsWith(const std::string& s, const std::string& sub) {
return s.find(sub) == 0 ? true : false;
}
auto create_whole_masked_lm_predictions(std::vector<std::string> &tokens,
const std::vector<std::string> &original_tokens,
const std::vector<std::string> &vocab_words,
std::map<std::string, int> &vocab,
const int max_predictions_per_seq,
const double masked_lm_prob) {
// for (auto item : vocab) {
// std::cout << "key=" << std::string(py::str(item.first)) << ", "
// << "value=" << std::string(py::str(item.second)) << std::endl;
// }
std::vector<std::vector<int> > cand_indexes;
std::vector<int> cand_temp;
int tokens_size = tokens.size();
std::string prefix = "##";
bool do_whole_masked = true;
for (int i = 0; i < tokens_size; i++) {
if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
continue;
}
if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) {
cand_temp.emplace_back(i);
}
else {
if (cand_temp.size() > 0) {
cand_indexes.emplace_back(cand_temp);
}
cand_temp.clear();
cand_temp.emplace_back(i);
}
}
auto seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed));
// for (auto i : cand_indexes) {
// for (auto j : i) {
// std::cout << tokens[j] << " ";
// }
// std::cout << std::endl;
// }
// for (auto i : output_tokens) {
// std::cout << i;
// }
// std::cout << std::endl;
int num_to_predict = std::min(max_predictions_per_seq,
std::max(1, int(tokens_size * masked_lm_prob)));
// std::cout << num_to_predict << std::endl;
std::set<int> covered_indexes;
std::vector<int> masked_lm_output(tokens_size, -1);
int vocab_words_len = vocab_words.size();
std::default_random_engine e(seed);
std::uniform_real_distribution<double> u1(0.0, 1.0);
std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
int mask_cnt = 0;
std::vector<std::string> output_tokens;
output_tokens = original_tokens;
for (auto index_set : cand_indexes) {
if (mask_cnt > num_to_predict) {
break;
}
int index_set_size = index_set.size();
if (mask_cnt + index_set_size > num_to_predict) {
continue;
}
bool is_any_index_covered = false;
for (auto index : index_set) {
if (covered_indexes.find(index) != covered_indexes.end()) {
is_any_index_covered = true;
break;
}
}
if (is_any_index_covered) {
continue;
}
for (auto index : index_set) {
covered_indexes.insert(index);
std::string masked_token;
if (u1(e) < 0.8) {
masked_token = "[MASK]";
}
else {
if (u1(e) < 0.5) {
masked_token = output_tokens[index];
}
else {
int random_index = u2(e);
masked_token = vocab_words[random_index];
}
}
// masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
masked_lm_output[index] = vocab[output_tokens[index]];
output_tokens[index] = masked_token;
mask_cnt++;
}
}
return std::make_tuple(output_tokens, masked_lm_output);
}
PYBIND11_MODULE(mask, m) {
m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions);
m.def("get_new_segment", &get_new_segment);
}
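For reference, here is a minimal sketch of how the compiled `mask` extension might be driven from Python, assuming it is built as a pybind11 extension with the STL type casters enabled; the tokens, vocabulary and hyper-parameter values below are illustrative placeholders rather than the example's real data.

# Hypothetical usage of the `mask` extension defined above; all data are placeholders.
import mask

tokens = ["[CLS]", "un", "##related", "words", "[SEP]"]
original_tokens = list(tokens)
vocab_words = ["[PAD]", "[MASK]", "un", "##related", "words", "random"]
vocab = {w: i for i, w in enumerate(vocab_words)}

output_tokens, masked_lm_labels = mask.create_whole_masked_lm_predictions(
    tokens, original_tokens, vocab_words, vocab,
    20,    # max_predictions_per_seq
    0.15,  # masked_lm_prob
)
# masked_lm_labels[i] == -1 means position i was not masked; otherwise it holds
# the original token id that the model should predict at that position.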

View File

@ -1,176 +0,0 @@
import colossalai

__all__ = ['parse_args']
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--distplan",
type=str,
default='CAI_Gemini',
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
action='store_true',
help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
parser.add_argument(
'--lr',
type=float,
required=True,
help='initial learning rate')
parser.add_argument(
'--epoch',
type=int,
required=True,
help='number of epochs')
parser.add_argument(
'--data_path_prefix',
type=str,
required=True,
help="location of the train data corpus")
parser.add_argument(
'--eval_data_path_prefix',
type=str,
required=True,
help='location of the evaluation data corpus')
parser.add_argument(
'--tokenizer_path',
type=str,
required=True,
help='location of the tokenizer')
parser.add_argument(
'--max_seq_length',
type=int,
default=512,
help='sequence length')
parser.add_argument(
'--refresh_bucket_size',
type=int,
default=1,
help=
"This param makes sure that a certain task is repeated for this many steps to \
optimise the back propagation speed with APEX's DistributedDataParallel")
parser.add_argument(
"--max_predictions_per_seq",
"--max_pred",
default=80,
type=int,
help=
"The maximum number of masked tokens in a sequence to be predicted.")
parser.add_argument(
"--gradient_accumulation_steps",
default=1,
type=int,
help="accumulation_steps")
parser.add_argument(
"--train_micro_batch_size_per_gpu",
default=2,
type=int,
required=True,
help="train batch size")
parser.add_argument(
"--eval_micro_batch_size_per_gpu",
default=2,
type=int,
required=True,
help="eval batch size")
parser.add_argument(
"--num_workers",
default=8,
type=int,
help="")
parser.add_argument(
"--async_worker",
action='store_true',
help="")
parser.add_argument(
"--bert_config",
required=True,
type=str,
help="location of config.json")
parser.add_argument(
"--wandb",
action='store_true',
help="use wandb to watch model")
parser.add_argument(
"--wandb_project_name",
default='roberta',
help="wandb project name")
parser.add_argument(
"--log_interval",
default=100,
type=int,
help="report interval")
parser.add_argument(
"--log_path",
type=str,
required=True,
help="log file which records train step")
parser.add_argument(
"--tensorboard_path",
type=str,
required=True,
help="location of tensorboard file")
parser.add_argument(
"--colossal_config",
type=str,
required=True,
help="colossal config, which contains zero config and so on")
parser.add_argument(
"--ckpt_path",
type=str,
required=True,
help="location of saving checkpoint, which contains model and optimizer")
parser.add_argument(
'--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument(
'--vscode_debug',
action='store_true',
help="use vscode to debug")
parser.add_argument(
'--load_pretrain_model',
default='',
type=str,
help="location of model's checkpoin")
parser.add_argument(
'--load_optimizer_lr',
default='',
type=str,
help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
parser.add_argument(
'--resume_train',
action='store_true',
help="whether resume training from a early checkpoint")
parser.add_argument(
'--mlm',
default='bert',
type=str,
help="model type, bert or deberta")
parser.add_argument(
'--checkpoint_activations',
action='store_true',
help="whether to use gradient checkpointing")
args = parser.parse_args()
return args
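For completeness, a hypothetical invocation of this parser is sketched below; it sets only the `required=True` flags defined above, uses placeholder paths and values throughout, and assumes that `colossalai.get_default_parser()` adds no additional required flags.

# Hypothetical invocation of parse_args(); every path and value is a placeholder.
import sys

sys.argv = [
    "pretrain.py",
    "--lr", "1e-4",
    "--epoch", "1",
    "--data_path_prefix", "/path/to/train_data",
    "--eval_data_path_prefix", "/path/to/eval_data",
    "--tokenizer_path", "/path/to/tokenizer",
    "--train_micro_batch_size_per_gpu", "2",
    "--eval_micro_batch_size_per_gpu", "2",
    "--bert_config", "/path/to/config.json",
    "--log_path", "/path/to/train.log",
    "--tensorboard_path", "/path/to/tensorboard",
    "--colossal_config", "/path/to/colossal_config.py",
    "--ckpt_path", "/path/to/checkpoints",
]
args = parse_args()
print(args.distplan, args.max_seq_length)  # unset options fall back to the defaults: 'CAI_Gemini', 512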