Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-20 04:32:47 +00:00)

[example] reorganize for community examples (#3557)

Parent: 1a809eddaa
Commit: f1b3d60cae
@ -10,9 +10,12 @@

 ## Overview

-This folder provides several examples accelerated by Colossal-AI. The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI. Other folders such as `images` and `language` include a wide range of deep learning tasks and applications.
+This folder provides several examples accelerated by Colossal-AI.
+Folders such as `images` and `language` include a wide range of deep learning tasks and applications.
+The `community` folder aims to create a collaborative platform for developers to contribute exotic features built on top of Colossal-AI.
+The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI.

-You can find applications such as Chatbot, Stable Diffusion and Biomedicine in the [Applications](https://github.com/hpcaitech/ColossalAI/tree/main/applications) directory.
+You can find applications such as Chatbot, AIGC and Biomedicine in the [Applications](https://github.com/hpcaitech/ColossalAI/tree/main/applications) directory.

 ## Folder Structure
@ -52,3 +55,10 @@ Therefore, it is essential for the example contributors to know how to integrate
 2. Configure your testing parameters, such as the number of steps and the batch size, in `test_ci.sh`. Keep these parameters small so that each example takes only a few minutes.
 3. Export your dataset path with the prefix `/data` and make sure you have a copy of the dataset in the `/data/scratch/examples-data` directory on the CI machine. Community contributors can contact us via Slack to request that the dataset be downloaded onto the CI machine.
 4. Implement the logic such as dependency setup and example execution.
+
+## Community Dependency
+We are happy to introduce the following community repositories that are powered by Colossal-AI:
+- [lightning-ColossalAI](https://github.com/Lightning-AI/lightning)
+- [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion)
+- [KoChatGPT](https://github.com/airobotlab/KoChatGPT)
+- [minichatgpt](https://github.com/juncongmoo/minichatgpt)
examples/community/README.md (new file, 28 lines)
@ -0,0 +1,28 @@
# Community Examples

Community-driven Examples is an initiative that allows users to share their own examples with the Colossal-AI community, fostering a sense of community and making it easy for others to access and benefit from shared work. The primary goal of community-driven examples is to have a community-maintained collection of diverse and exotic functionalities built on top of the Colossal-AI package.

If a community example doesn't work as expected, you can [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) and @ the author to report it.

| Example | Description | Code Example | Colab | Author |
|:--------|:------------|:-------------|:------|-------:|
| RoBERTa | Adding RoBERTa for SFT and Prompts model training | [RoBERTa](./roberta) | - | [YY Lin](https://github.com/yynil) (Moore Threads) |
| TransformerEngine FP8 | Adding TransformerEngine with FP8 training | [TransformerEngine FP8](./fp8) | - | [Kirthi Shankar Sivamani](https://github.com/ksivaman) (NVIDIA) |
| ... | ... | ... | ... | ... |

## Looking for Examples

Contributions of the following examples are particularly welcome:

* [Swin-Transformer](https://github.com/microsoft/Swin-Transformer)
* [T5](https://github.com/google-research/text-to-text-transfer-transformer)
* [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
* [ControlNet](https://github.com/lllyasviel/ControlNet)
* [Consistency Models](https://github.com/openai/consistency_models)
* [MAE](https://github.com/facebookresearch/mae)
* [CLIP](https://github.com/openai/CLIP)

You are welcome to [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) to share your insights and needs.

## How to get involved

To join our community-driven initiative, please visit the [Colossal-AI examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples), review the provided information, and explore the codebase.

To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review it and provide feedback. If you are confident enough, you can also submit a PR directly. We look forward to collaborating with you on this exciting project!
@ -3,12 +3,13 @@
 # See LICENSE for license information.

 import argparse
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
+from torchvision import datasets, transforms

 try:
     from transformer_engine import pytorch as te
@ -18,6 +19,7 @@ except (ImportError, ModuleNotFoundError):


 class Net(nn.Module):
+
     def __init__(self, use_te=False):
         super(Net, self).__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
@ -62,12 +64,10 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
         loss.backward()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
-            print(
-                f"Train Epoch: {epoch} "
-                f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
-                f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
-                f"Loss: {loss.item():.6f}"
-            )
+            print(f"Train Epoch: {epoch} "
+                  f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
+                  f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
+                  f"Loss: {loss.item():.6f}")
             if args.dry_run:
                 break

@ -83,6 +83,7 @@ def calibrate(model, device, test_loader):
             with te.fp8_autocast(enabled=False, calibrating=True):
                 output = model(data)


 def test(model, device, test_loader, use_fp8):
     """Testing function."""
     model.eval()
@ -93,21 +94,15 @@ def test(model, device, test_loader, use_fp8):
             data, target = data.to(device), target.to(device)
             with te.fp8_autocast(enabled=use_fp8):
                 output = model(data)
-            test_loss += F.nll_loss(
-                output, target, reduction="sum"
-            ).item()  # sum up batch loss
-            pred = output.argmax(
-                dim=1, keepdim=True
-            )  # get the index of the max log-probability
+            test_loss += F.nll_loss(output, target, reduction="sum").item()    # sum up batch loss
+            pred = output.argmax(dim=1, keepdim=True)    # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()

     test_loss /= len(test_loader.dataset)

-    print(
-        f"\nTest set: Average loss: {test_loss:.4f}, "
-        f"Accuracy: {correct}/{len(test_loader.dataset)} "
-        f"({100. * correct / len(test_loader.dataset):.0f}%)\n"
-    )
+    print(f"\nTest set: Average loss: {test_loss:.4f}, "
+          f"Accuracy: {correct}/{len(test_loader.dataset)} "
+          f"({100. * correct / len(test_loader.dataset):.0f}%)\n")


 def main():
@ -154,9 +149,7 @@ def main():
         default=False,
         help="quickly check a single pass",
     )
-    parser.add_argument(
-        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
-    )
+    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
     parser.add_argument(
         "--log-interval",
         type=int,
@ -170,15 +163,12 @@ def main():
         default=False,
         help="For Saving the current Model",
     )
-    parser.add_argument(
-        "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
-    )
-    parser.add_argument(
-        "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
-    )
-    parser.add_argument(
-        "--use-te", action="store_true", default=False, help="Use Transformer Engine"
-    )
+    parser.add_argument("--use-fp8",
+                        action="store_true",
+                        default=False,
+                        help="Use FP8 for inference and training without recalibration")
+    parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only")
+    parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine")
     args = parser.parse_args()
     use_cuda = torch.cuda.is_available()

@ -205,9 +195,7 @@ def main():
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)

-    transform = transforms.Compose(
-        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
-    )
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
     dataset2 = datasets.MNIST("../data", train=False, transform=transform)
     train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
@ -227,7 +215,7 @@ def main():

     if args.save_model or args.use_fp8_infer:
         torch.save(model.state_dict(), "mnist_cnn.pt")
-        print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
+        print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer))
         weights = torch.load("mnist_cnn.pt")
         model.load_state_dict(weights)
         test(model, device, test_loader, args.use_fp8_infer)
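For readers unfamiliar with TransformerEngine, the pattern this MNIST example relies on is: `te.Linear` is used as a drop-in replacement for `nn.Linear`, and the forward pass is wrapped in `te.fp8_autocast` so that the matrix multiplications run in FP8. The snippet below is a minimal sketch, not part of this commit; the tensor sizes are illustrative, and FP8 execution requires FP8-capable hardware (e.g. H100).

```python
import torch
import transformer_engine.pytorch as te

# Sketch only: te.Linear replaces nn.Linear; sizes are illustrative
# (FP8 GEMMs generally want dimensions that are multiples of 16).
layer = te.Linear(784, 256, bias=True).cuda()
x = torch.randn(64, 784, device="cuda")

with te.fp8_autocast(enabled=True):    # FP8 forward pass
    y = layer(x)
y.float().sum().backward()             # backward runs outside the context manager
```

The `calibrate()` function above uses the complementary mode, `te.fp8_autocast(enabled=False, calibrating=True)`: the forward pass runs in higher precision while the scaling statistics are recorded, so a checkpoint trained without FP8 can later be evaluated with `--use-fp8-infer`.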
@ -1,20 +1,22 @@
|
|||||||
import torch
|
import collections
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from random import choice
|
from random import choice
|
||||||
import random
|
|
||||||
import collections
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
import jieba
|
import jieba
|
||||||
|
import torch
|
||||||
|
|
||||||
jieba.setLogLevel(logging.CRITICAL)
|
jieba.setLogLevel(logging.CRITICAL)
|
||||||
import re
|
import re
|
||||||
import numpy as np
|
|
||||||
import mask
|
import mask
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
PAD = 0
|
PAD = 0
|
||||||
MaskedLMInstance = collections.namedtuple("MaskedLMInstance",
|
MaskedLMInstance = collections.namedtuple("MaskedLMInstance", ["index", "label"])
|
||||||
["index", "label"])
|
|
||||||
|
|
||||||
|
|
||||||
def map_to_numpy(data):
|
def map_to_numpy(data):
|
||||||
@ -22,6 +24,7 @@ def map_to_numpy(data):
|
|||||||
|
|
||||||
|
|
||||||
class PreTrainingDataset():
|
class PreTrainingDataset():
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
max_seq_length,
|
max_seq_length,
|
||||||
@ -43,14 +46,12 @@ class PreTrainingDataset():
|
|||||||
self.mlm_tamper_p = 0.05
|
self.mlm_tamper_p = 0.05
|
||||||
self.mlm_maintain_p = 0.1
|
self.mlm_maintain_p = 0.1
|
||||||
|
|
||||||
|
|
||||||
def tokenize(self, doc):
|
def tokenize(self, doc):
|
||||||
temp = []
|
temp = []
|
||||||
for d in doc:
|
for d in doc:
|
||||||
temp.append(self.tokenizer.tokenize(d))
|
temp.append(self.tokenizer.tokenize(d))
|
||||||
return temp
|
return temp
|
||||||
|
|
||||||
|
|
||||||
def create_training_instance(self, instance):
|
def create_training_instance(self, instance):
|
||||||
is_next = 1
|
is_next = 1
|
||||||
raw_text_list = self.get_new_segment(instance)
|
raw_text_list = self.get_new_segment(instance)
|
||||||
@ -83,8 +84,9 @@ class PreTrainingDataset():
|
|||||||
|
|
||||||
# Get Masked LM predictions
|
# Get Masked LM predictions
|
||||||
if self.backend == 'c++':
|
if self.backend == 'c++':
|
||||||
output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words,
|
output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
|
||||||
self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob)
|
tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq,
|
||||||
|
self.masked_lm_prob)
|
||||||
elif self.backend == 'python':
|
elif self.backend == 'python':
|
||||||
output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)
|
output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)
|
||||||
|
|
||||||
@ -105,14 +107,12 @@ class PreTrainingDataset():
|
|||||||
map_to_numpy([is_next])
|
map_to_numpy([is_next])
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def create_masked_lm_predictions(self, tokens):
|
def create_masked_lm_predictions(self, tokens):
|
||||||
cand_indexes = []
|
cand_indexes = []
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
if token == "[CLS]" or token == "[SEP]":
|
if token == "[CLS]" or token == "[SEP]":
|
||||||
continue
|
continue
|
||||||
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
|
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
|
||||||
token.startswith("##")):
|
|
||||||
cand_indexes[-1].append(i)
|
cand_indexes[-1].append(i)
|
||||||
else:
|
else:
|
||||||
cand_indexes.append([i])
|
cand_indexes.append([i])
|
||||||
@ -122,9 +122,7 @@ class PreTrainingDataset():
|
|||||||
random.shuffle(cand_indexes)
|
random.shuffle(cand_indexes)
|
||||||
output_tokens = list(tokens)
|
output_tokens = list(tokens)
|
||||||
|
|
||||||
num_to_predict = min(
|
num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))
|
||||||
self.max_predictions_per_seq,
|
|
||||||
max(1, int(round(len(tokens) * self.masked_lm_prob))))
|
|
||||||
|
|
||||||
masked_lms = []
|
masked_lms = []
|
||||||
covered_indexes = set()
|
covered_indexes = set()
|
||||||
@ -145,13 +143,10 @@ class PreTrainingDataset():
|
|||||||
masked_token = tokens[index]
|
masked_token = tokens[index]
|
||||||
# 10% replace w/ random word
|
# 10% replace w/ random word
|
||||||
else:
|
else:
|
||||||
masked_token = self.vocab_words[random.randint(
|
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
|
||||||
0,
|
|
||||||
len(self.vocab_words) - 1)]
|
|
||||||
|
|
||||||
output_tokens[index] = masked_token
|
output_tokens[index] = masked_token
|
||||||
masked_lms.append(
|
masked_lms.append(MaskedLMInstance(index=index, label=tokens[index]))
|
||||||
MaskedLMInstance(index=index, label=tokens[index]))
|
|
||||||
|
|
||||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||||
masked_lm_output = [-1] * len(output_tokens)
|
masked_lm_output = [-1] * len(output_tokens)
|
||||||
@ -160,7 +155,6 @@ class PreTrainingDataset():
|
|||||||
|
|
||||||
return (output_tokens, masked_lm_output)
|
return (output_tokens, masked_lm_output)
|
||||||
|
|
||||||
|
|
||||||
def get_new_segment(self, segment):
|
def get_new_segment(self, segment):
|
||||||
"""
|
"""
|
||||||
Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word.
|
Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word.
|
||||||
@ -180,10 +174,10 @@ class PreTrainingDataset():
|
|||||||
for length in range(3, 0, -1):
|
for length in range(3, 0, -1):
|
||||||
if i + length > len(segment):
|
if i + length > len(segment):
|
||||||
continue
|
continue
|
||||||
if ''.join(segment[i: i+length]) in seq_cws_dict:
|
if ''.join(segment[i:i + length]) in seq_cws_dict:
|
||||||
new_segment.append(segment[i])
|
new_segment.append(segment[i])
|
||||||
for l in range(1, length):
|
for l in range(1, length):
|
||||||
new_segment.append('##' + segment[i+l])
|
new_segment.append('##' + segment[i + l])
|
||||||
i += length
|
i += length
|
||||||
has_add = True
|
has_add = True
|
||||||
break
|
break
|
||||||
@ -192,7 +186,6 @@ class PreTrainingDataset():
|
|||||||
i += 1
|
i += 1
|
||||||
return new_segment
|
return new_segment
|
||||||
|
|
||||||
|
|
||||||
def create_whole_masked_lm_predictions(self, tokens):
|
def create_whole_masked_lm_predictions(self, tokens):
|
||||||
"""Creates the predictions for the masked LM objective."""
|
"""Creates the predictions for the masked LM objective."""
|
||||||
|
|
||||||
@ -209,18 +202,16 @@ class PreTrainingDataset():
|
|||||||
# Note that Whole Word Masking does *not* change the training code
|
# Note that Whole Word Masking does *not* change the training code
|
||||||
# at all -- we still predict each WordPiece independently, softmaxed
|
# at all -- we still predict each WordPiece independently, softmaxed
|
||||||
# over the entire vocabulary.
|
# over the entire vocabulary.
|
||||||
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
|
if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
|
||||||
token.startswith("##")):
|
|
||||||
cand_indexes[-1].append(i)
|
cand_indexes[-1].append(i)
|
||||||
else:
|
else:
|
||||||
cand_indexes.append([i])
|
cand_indexes.append([i])
|
||||||
|
|
||||||
random.shuffle(cand_indexes)
|
random.shuffle(cand_indexes)
|
||||||
|
|
||||||
output_tokens = [t[2:] if len(self.whole_rec.findall(t))>0 else t for t in tokens] # 去掉"##"
|
output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##"
|
||||||
|
|
||||||
num_to_predict = min(self.max_predictions_per_seq,
|
num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))
|
||||||
max(1, int(round(len(tokens) * self.masked_lm_prob))))
|
|
||||||
|
|
||||||
masked_lms = []
|
masked_lms = []
|
||||||
covered_indexes = set()
|
covered_indexes = set()
|
||||||
@ -248,14 +239,18 @@ class PreTrainingDataset():
|
|||||||
else:
|
else:
|
||||||
# 10% of the time, keep original
|
# 10% of the time, keep original
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index] # 去掉"##"
|
masked_token = tokens[index][2:] if len(self.whole_rec.findall(
|
||||||
|
tokens[index])) > 0 else tokens[index] # 去掉"##"
|
||||||
# 10% of the time, replace with random word
|
# 10% of the time, replace with random word
|
||||||
else:
|
else:
|
||||||
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
|
masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
|
||||||
|
|
||||||
output_tokens[index] = masked_token
|
output_tokens[index] = masked_token
|
||||||
|
|
||||||
masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index]))
|
masked_lms.append(
|
||||||
|
MaskedLMInstance(
|
||||||
|
index=index,
|
||||||
|
label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index]))
|
||||||
assert len(masked_lms) <= num_to_predict
|
assert len(masked_lms) <= num_to_predict
|
||||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||||
masked_lm_output = [-1] * len(output_tokens)
|
masked_lm_output = [-1] * len(output_tokens)
|
examples/community/roberta/preprocessing/mask.cpp (new file, 190 lines)
@ -0,0 +1,190 @@
|
|||||||
|
#include <math.h>
|
||||||
|
#include <pybind11/numpy.h>
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <chrono>
|
||||||
|
#include <iostream>
|
||||||
|
#include <limits>
|
||||||
|
#include <random>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
const int32_t LONG_SENTENCE_LEN = 512;
|
||||||
|
|
||||||
|
struct MaskedLMInstance {
|
||||||
|
int index;
|
||||||
|
std::string label;
|
||||||
|
MaskedLMInstance(int index, std::string label) {
|
||||||
|
this->index = index;
|
||||||
|
this->label = label;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto get_new_segment(
|
||||||
|
std::vector<std::string> segment, std::vector<std::string> segment_jieba,
|
||||||
|
const std::vector<bool> chinese_vocab) { // const
|
||||||
|
// std::unordered_set<std::string>
|
||||||
|
// &chinese_vocab
|
||||||
|
std::unordered_set<std::string> seq_cws_dict;
|
||||||
|
for (auto word : segment_jieba) {
|
||||||
|
seq_cws_dict.insert(word);
|
||||||
|
}
|
||||||
|
int i = 0;
|
||||||
|
std::vector<std::string> new_segment;
|
||||||
|
int segment_size = segment.size();
|
||||||
|
while (i < segment_size) {
|
||||||
|
if (!chinese_vocab[i]) { // chinese_vocab.find(segment[i]) ==
|
||||||
|
// chinese_vocab.end()
|
||||||
|
new_segment.emplace_back(segment[i]);
|
||||||
|
i += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bool has_add = false;
|
||||||
|
for (int length = 3; length >= 1; length--) {
|
||||||
|
if (i + length > segment_size) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string chinese_word = "";
|
||||||
|
for (int j = i; j < i + length; j++) {
|
||||||
|
chinese_word += segment[j];
|
||||||
|
}
|
||||||
|
if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
|
||||||
|
new_segment.emplace_back(segment[i]);
|
||||||
|
for (int j = i + 1; j < i + length; j++) {
|
||||||
|
new_segment.emplace_back("##" + segment[j]);
|
||||||
|
}
|
||||||
|
i += length;
|
||||||
|
has_add = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!has_add) {
|
||||||
|
new_segment.emplace_back(segment[i]);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new_segment;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool startsWith(const std::string &s, const std::string &sub) {
|
||||||
|
return s.find(sub) == 0 ? true : false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto create_whole_masked_lm_predictions(
|
||||||
|
std::vector<std::string> &tokens,
|
||||||
|
const std::vector<std::string> &original_tokens,
|
||||||
|
const std::vector<std::string> &vocab_words,
|
||||||
|
std::map<std::string, int> &vocab, const int max_predictions_per_seq,
|
||||||
|
const double masked_lm_prob) {
|
||||||
|
// for (auto item : vocab) {
|
||||||
|
// std::cout << "key=" << std::string(py::str(item.first)) << ", "
|
||||||
|
// << "value=" << std::string(py::str(item.second)) <<
|
||||||
|
// std::endl;
|
||||||
|
// }
|
||||||
|
std::vector<std::vector<int> > cand_indexes;
|
||||||
|
std::vector<int> cand_temp;
|
||||||
|
int tokens_size = tokens.size();
|
||||||
|
std::string prefix = "##";
|
||||||
|
bool do_whole_masked = true;
|
||||||
|
|
||||||
|
for (int i = 0; i < tokens_size; i++) {
|
||||||
|
if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (do_whole_masked && (cand_indexes.size() > 0) &&
|
||||||
|
(tokens[i].rfind(prefix, 0) == 0)) {
|
||||||
|
cand_temp.emplace_back(i);
|
||||||
|
} else {
|
||||||
|
if (cand_temp.size() > 0) {
|
||||||
|
cand_indexes.emplace_back(cand_temp);
|
||||||
|
}
|
||||||
|
cand_temp.clear();
|
||||||
|
cand_temp.emplace_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto seed = std::chrono::system_clock::now().time_since_epoch().count();
|
||||||
|
std::shuffle(cand_indexes.begin(), cand_indexes.end(),
|
||||||
|
std::default_random_engine(seed));
|
||||||
|
// for (auto i : cand_indexes) {
|
||||||
|
// for (auto j : i) {
|
||||||
|
// std::cout << tokens[j] << " ";
|
||||||
|
// }
|
||||||
|
// std::cout << std::endl;
|
||||||
|
// }
|
||||||
|
// for (auto i : output_tokens) {
|
||||||
|
// std::cout << i;
|
||||||
|
// }
|
||||||
|
// std::cout << std::endl;
|
||||||
|
|
||||||
|
int num_to_predict = std::min(max_predictions_per_seq,
|
||||||
|
std::max(1, int(tokens_size * masked_lm_prob)));
|
||||||
|
// std::cout << num_to_predict << std::endl;
|
||||||
|
|
||||||
|
std::set<int> covered_indexes;
|
||||||
|
std::vector<int> masked_lm_output(tokens_size, -1);
|
||||||
|
int vocab_words_len = vocab_words.size();
|
||||||
|
std::default_random_engine e(seed);
|
||||||
|
std::uniform_real_distribution<double> u1(0.0, 1.0);
|
||||||
|
std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
|
||||||
|
int mask_cnt = 0;
|
||||||
|
std::vector<std::string> output_tokens;
|
||||||
|
output_tokens = original_tokens;
|
||||||
|
|
||||||
|
for (auto index_set : cand_indexes) {
|
||||||
|
if (mask_cnt > num_to_predict) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
int index_set_size = index_set.size();
|
||||||
|
if (mask_cnt + index_set_size > num_to_predict) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bool is_any_index_covered = false;
|
||||||
|
for (auto index : index_set) {
|
||||||
|
if (covered_indexes.find(index) != covered_indexes.end()) {
|
||||||
|
is_any_index_covered = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (is_any_index_covered) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (auto index : index_set) {
|
||||||
|
covered_indexes.insert(index);
|
||||||
|
std::string masked_token;
|
||||||
|
if (u1(e) < 0.8) {
|
||||||
|
masked_token = "[MASK]";
|
||||||
|
} else {
|
||||||
|
if (u1(e) < 0.5) {
|
||||||
|
masked_token = output_tokens[index];
|
||||||
|
} else {
|
||||||
|
int random_index = u2(e);
|
||||||
|
masked_token = vocab_words[random_index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
|
||||||
|
masked_lm_output[index] = vocab[output_tokens[index]];
|
||||||
|
output_tokens[index] = masked_token;
|
||||||
|
mask_cnt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// for (auto p : masked_lms) {
|
||||||
|
// masked_lm_output[p.index] = vocab[p.label];
|
||||||
|
// }
|
||||||
|
return std::make_tuple(output_tokens, masked_lm_output);
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(mask, m) {
|
||||||
|
m.def("create_whole_masked_lm_predictions",
|
||||||
|
&create_whole_masked_lm_predictions);
|
||||||
|
m.def("get_new_segment", &get_new_segment);
|
||||||
|
}
|
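Once this file is compiled as a pybind11 extension named `mask`, it is imported from Python and called exactly as `get_mask.py` does when the `c++` backend is selected. A hedged usage sketch follows; the token lists and vocabulary are placeholders, not data from the repository:

```python
import mask  # the compiled pybind11 module defined in mask.cpp above

# Placeholder inputs: in get_mask.py these come from the tokenizer and the
# whole-word segmentation step.
tokens = ["[CLS]", "今", "##天", "天", "##气", "好", "[SEP]"]
original_tokens = ["[CLS]", "今", "天", "天", "气", "好", "[SEP]"]
vocab_words = ["今", "天", "气", "好"]
vocab = {w: i for i, w in enumerate(vocab_words)}

output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
    tokens, original_tokens, vocab_words, vocab,
    20,      # max_predictions_per_seq
    0.15,    # masked_lm_prob
)
print(output_tokens)       # original tokens with some positions replaced by [MASK]
print(masked_lm_output)    # vocab ids at masked positions, -1 elsewhere
```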
@ -1,13 +1,14 @@
|
|||||||
|
import argparse
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import List
|
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
import argparse
|
from typing import List
|
||||||
import functools
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
|
def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
|
||||||
sent_list = []
|
sent_list = []
|
||||||
@ -17,7 +18,8 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
|
|||||||
document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document)
|
document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document)
|
||||||
elif flag == "en":
|
elif flag == "en":
|
||||||
document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
|
document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
|
||||||
document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n', document) # Special quotation marks
|
document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n',
|
||||||
|
document) # Special quotation marks
|
||||||
else:
|
else:
|
||||||
document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
|
document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
|
||||||
|
|
||||||
@ -43,9 +45,7 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
|
|||||||
return sent_list
|
return sent_list
|
||||||
|
|
||||||
|
|
||||||
def get_sent(output_path,
|
def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None:
|
||||||
input_path,
|
|
||||||
fin_list=[], host=-1, seq_len=512) -> None:
|
|
||||||
|
|
||||||
workers = 32
|
workers = 32
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ def get_sent(output_path,
|
|||||||
input_path = input_path[:-1]
|
input_path = input_path[:-1]
|
||||||
|
|
||||||
cur_path = os.path.join(output_path, str(host) + '.txt')
|
cur_path = os.path.join(output_path, str(host) + '.txt')
|
||||||
new_split_sentence = functools.partial(split_sentence, limit=seq_len-2)
|
new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)
|
||||||
with open(cur_path, 'w', encoding='utf-8') as f:
|
with open(cur_path, 'w', encoding='utf-8') as f:
|
||||||
for fi, fin_path in enumerate(fin_list):
|
for fi, fin_path in enumerate(fin_list):
|
||||||
if not os.path.exists(os.path.join(input_path, fin_path[0])):
|
if not os.path.exists(os.path.join(input_path, fin_path[0])):
|
||||||
@ -136,11 +136,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for index, shard in enumerate(real_shard):
|
for index, shard in enumerate(real_shard):
|
||||||
get_sent(output_path,
|
get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len)
|
||||||
input_path,
|
|
||||||
fin_list=shard,
|
|
||||||
host=index,
|
|
||||||
seq_len=seq_len)
|
|
||||||
print(f'cost {str(time.time() - start)}')
|
print(f'cost {str(time.time() - start)}')
|
||||||
|
|
||||||
# if you have multiple server, you can use code below or modify code to openmpi
|
# if you have multiple server, you can use code below or modify code to openmpi
|
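For context, `split_sentence` above cuts a raw document into sentences on Chinese and/or English end-of-sentence punctuation and caps each piece at `limit` characters, and `get_sent` fans that work out over 32 worker processes per shard. A small usage sketch (the import assumes the script is importable as a module from the preprocessing folder; the input string is illustrative):

```python
from sentence_split import split_sentence  # assumed module name for this script

doc = "今天天气很好。我们去公园吧!Is that OK? Sure."
for sent in split_sentence(doc, flag="all", limit=510):
    print(sent)
```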
@ -1,19 +1,19 @@
|
|||||||
import time
|
|
||||||
import os
|
|
||||||
import psutil
|
|
||||||
import h5py
|
|
||||||
import socket
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from tqdm import tqdm
|
import os
|
||||||
|
import socket
|
||||||
|
import time
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import psutil
|
||||||
from get_mask import PreTrainingDataset
|
from get_mask import PreTrainingDataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def get_raw_instance(document, max_sequence_length=512):
|
def get_raw_instance(document, max_sequence_length=512):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Get the initial training instances, split the whole segment into multiple parts according to the max_sequence_length, and return as multiple processed instances.
|
Get the initial training instances, split the whole segment into multiple parts according to the max_sequence_length, and return as multiple processed instances.
|
||||||
:param document: document
|
:param document: document
|
||||||
@ -37,7 +37,7 @@ def get_raw_instance(document, max_sequence_length=512):
|
|||||||
if len(curr_seq) > 0:
|
if len(curr_seq) > 0:
|
||||||
result_list.append(curr_seq)
|
result_list.append(curr_seq)
|
||||||
curr_seq = []
|
curr_seq = []
|
||||||
result_list.append(document[sz_idx][ : max_sequence_length_allowed])
|
result_list.append(document[sz_idx][:max_sequence_length_allowed])
|
||||||
sz_idx += 1
|
sz_idx += 1
|
||||||
else:
|
else:
|
||||||
result_list.append(curr_seq)
|
result_list.append(curr_seq)
|
||||||
@ -70,8 +70,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
|
|||||||
# document = line
|
# document = line
|
||||||
# if len(document.split("<sep>")) <= 3:
|
# if len(document.split("<sep>")) <= 3:
|
||||||
# continue
|
# continue
|
||||||
if len(line
|
if len(line) > 0 and line[:2] == "]]": # This is end of document
|
||||||
) > 0 and line[:2] == "]]": # This is end of document
|
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
document = []
|
document = []
|
||||||
elif len(line) >= 2:
|
elif len(line) >= 2:
|
||||||
@ -84,8 +83,8 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
|
|||||||
# print(len(documents))
|
# print(len(documents))
|
||||||
# print(len(documents[0]))
|
# print(len(documents[0]))
|
||||||
# print(documents[0][0:10])
|
# print(documents[0][0:10])
|
||||||
from typing import List
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
from typing import List
|
||||||
|
|
||||||
ans = []
|
ans = []
|
||||||
for docs in tqdm(documents):
|
for docs in tqdm(documents):
|
||||||
@ -124,13 +123,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
|
|||||||
del instances
|
del instances
|
||||||
|
|
||||||
|
|
||||||
def split_numpy_chunk_pool(input_path,
|
def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):
|
||||||
output_path,
|
|
||||||
pretrain_data,
|
|
||||||
worker,
|
|
||||||
dupe_factor,
|
|
||||||
seq_len,
|
|
||||||
file_name):
|
|
||||||
|
|
||||||
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
|
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
|
||||||
print(f'{file_name}.h5 exists')
|
print(f'{file_name}.h5 exists')
|
||||||
@ -144,8 +137,7 @@ def split_numpy_chunk_pool(input_path,
|
|||||||
document = []
|
document = []
|
||||||
for i, line in enumerate(tqdm(fd)):
|
for i, line in enumerate(tqdm(fd)):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if len(line
|
if len(line) > 0 and line[:2] == "]]": # This is end of document
|
||||||
) > 0 and line[:2] == "]]": # This is end of document
|
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
document = []
|
document = []
|
||||||
elif len(line) >= 2:
|
elif len(line) >= 2:
|
||||||
@ -212,11 +204,21 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer')
|
parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer')
|
||||||
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
|
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
|
||||||
parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='number of shards, e.g., 10, 50, or 100')
|
parser.add_argument('--max_predictions_per_seq',
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help='number of shards, e.g., 10, 50, or 100')
|
||||||
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
|
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
|
||||||
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
|
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
|
||||||
parser.add_argument('--backend', type=str, default='python', help='backend of mask token, python, c++, numpy respectively')
|
parser.add_argument('--backend',
|
||||||
parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document')
|
type=str,
|
||||||
|
default='python',
|
||||||
|
help='backend of mask token, python, c++, numpy respectively')
|
||||||
|
parser.add_argument(
|
||||||
|
'--dupe_factor',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='specifies how many times the preprocessor repeats to create the input from the same article/document')
|
||||||
parser.add_argument('--worker', type=int, default=32, help='number of process')
|
parser.add_argument('--worker', type=int, default=32, help='number of process')
|
||||||
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
|
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -227,7 +229,6 @@ if __name__ == '__main__':
|
|||||||
args.backend,
|
args.backend,
|
||||||
max_predictions_per_seq=args.max_predictions_per_seq)
|
max_predictions_per_seq=args.max_predictions_per_seq)
|
||||||
|
|
||||||
|
|
||||||
data_len = len(os.listdir(args.input_path))
|
data_len = len(os.listdir(args.input_path))
|
||||||
|
|
||||||
for i in range(data_len):
|
for i in range(data_len):
|
||||||
@ -235,15 +236,10 @@ if __name__ == '__main__':
|
|||||||
if os.path.exists(input_path):
|
if os.path.exists(input_path):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
print(f'process {input_path}')
|
print(f'process {input_path}')
|
||||||
split_numpy_chunk_pool(input_path,
|
split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor,
|
||||||
args.output_path,
|
args.seq_len, i)
|
||||||
pretrain_data,
|
|
||||||
args.worker,
|
|
||||||
args.dupe_factor,
|
|
||||||
args.seq_len,
|
|
||||||
i)
|
|
||||||
end_ = time.time()
|
end_ = time.time()
|
||||||
print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
|
print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
|
||||||
print(f'has cost {(end_ - start) / 60}')
|
print(f'has cost {(end_ - start) / 60}')
|
||||||
print('-' * 100)
|
print('-' * 100)
|
||||||
print('')
|
print('')
|
||||||
@ -269,5 +265,3 @@ if __name__ == '__main__':
|
|||||||
# print(f'has cost {(end_ - start) / 60}')
|
# print(f'has cost {(end_ - start) / 60}')
|
||||||
# print('-' * 100)
|
# print('-' * 100)
|
||||||
# print('')
|
# print('')
|
||||||
|
|
||||||
|
|
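The preprocessing driver above writes one `.h5` shard per input file via `split_numpy_chunk_pool`. To sanity-check a shard, it can be opened with `h5py`; the sketch below only lists whatever datasets are present, since the exact key names are defined by the writer code and are not shown in this diff:

```python
import h5py

# Inspect one preprocessed shard (the file name is illustrative).
with h5py.File("0.h5", "r") as f:
    for name, dset in f.items():
        print(name, dset.shape, dset.dtype)
```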
@ -21,4 +21,3 @@ bash run_pretrain_resume.sh
 * `--resume_train`: whether to resume training
 * `--load_pretrain_model`: absolute path which contains the model checkpoint
 * `--load_optimizer_lr`: absolute path which contains the optimizer checkpoint
-
examples/community/roberta/pretraining/arguments.py (new file, 87 lines)
@ -0,0 +1,87 @@
from numpy import require

import colossalai

__all__ = ['parse_args']


def parse_args():
    parser = colossalai.get_default_parser()

    parser.add_argument(
        "--distplan",
        type=str,
        default='CAI_Gemini',
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action='store_true',
        help=
        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )

    parser.add_argument('--lr', type=float, required=True, help='initial learning rate')
    parser.add_argument('--epoch', type=int, required=True, help='number of epochs')
    parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus")
    parser.add_argument('--eval_data_path_prefix',
                        type=str,
                        required=True,
                        help='location of the evaluation data corpus')
    parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer')
    parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length')
    parser.add_argument('--refresh_bucket_size',
                        type=int,
                        default=1,
                        help="This param makes sure that a certain task is repeated for this time steps to \
                        optimise on the back propagation speed with APEX's DistributedDataParallel")
    parser.add_argument("--max_predictions_per_seq",
                        "--max_pred",
                        default=80,
                        type=int,
                        help="The maximum number of masked tokens in a sequence to be predicted.")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation_steps")
    parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size")
    parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size")
    parser.add_argument("--num_workers", default=8, type=int, help="")
    parser.add_argument("--async_worker", action='store_true', help="")
    parser.add_argument("--bert_config", required=True, type=str, help="location of config.json")
    parser.add_argument("--wandb", action='store_true', help="use wandb to watch model")
    parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name")
    parser.add_argument("--log_interval", default=100, type=int, help="report interval")
    parser.add_argument("--log_path", type=str, required=True, help="log file which records train step")
    parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file")
    parser.add_argument("--colossal_config",
                        type=str,
                        required=True,
                        help="colossal config, which contains zero config and so on")
    parser.add_argument("--ckpt_path",
                        type=str,
                        required=True,
                        help="location of saving checkpoint, which contains model and optimizer")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug")
    parser.add_argument('--load_pretrain_model', default='', type=str, help="location of the model's checkpoint")
    parser.add_argument(
        '--load_optimizer_lr',
        default='',
        type=str,
        help="location of checkpoint, which contains optimizer, learning rate, epoch, shard and global_step")
    parser.add_argument('--resume_train', action='store_true', help="whether to resume training from an early checkpoint")
    parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta")
    parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing")

    args = parser.parse_args()
    return args
@ -1,4 +1,5 @@
 class BertDatasetProviderInterface:
+
     def get_shard(self, index, shuffle=True):
         raise NotImplementedError

@ -1,9 +1,11 @@
-import os
 import math
+import os
+
 import torch
-from tqdm import tqdm
-from utils.global_vars import get_timers, get_tensorboard_writer
 from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
+from tqdm import tqdm
+from utils.global_vars import get_tensorboard_writer, get_timers

+
 def evaluate(model, args, logger, global_step, criterion):
     evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True)
@ -25,7 +27,10 @@ def evaluate(model, args, logger, global_step, criterion):
         dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard)
         # evaluate_dataset_provider.prefetch_shard(shard + 1)
         if torch.distributed.get_rank() == 0:
-            iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1)
+            iterator_data = tqdm(enumerate(dataset_iterator),
+                                 total=(total_length // args.eval_micro_batch_size_per_gpu // world_size),
+                                 colour='MAGENTA',
+                                 smoothing=1)
         else:
             iterator_data = enumerate(dataset_iterator)

@ -41,7 +46,7 @@ def evaluate(model, args, logger, global_step, criterion):

            output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

-            loss = criterion(output.logits, mlm_label)#prediction_scores
+            loss = criterion(output.logits, mlm_label)    #prediction_scores
            evaluate_dataset_provider.prefetch_batch()

            eval_loss += loss.float().item()
@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch BERT model."""
|
"""PyTorch BERT model."""
|
||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@ -27,7 +26,6 @@ import torch.utils.checkpoint
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
@ -41,6 +39,7 @@ from transformers.modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
from transformers.models.bert.configuration_bert import BertConfig
|
||||||
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@ -50,8 +49,6 @@ from transformers.utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from transformers.models.bert.configuration_bert import BertConfig
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
@ -62,8 +59,7 @@ _TOKENIZER_FOR_DOC = "BertTokenizer"
|
|||||||
# TokenClassification docstring
|
# TokenClassification docstring
|
||||||
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
|
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
|
||||||
_TOKEN_CLASS_EXPECTED_OUTPUT = (
|
_TOKEN_CLASS_EXPECTED_OUTPUT = (
|
||||||
"['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
|
"['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ")
|
||||||
)
|
|
||||||
_TOKEN_CLASS_EXPECTED_LOSS = 0.01
|
_TOKEN_CLASS_EXPECTED_LOSS = 0.01
|
||||||
|
|
||||||
# QuestionAnswering docstring
|
# QuestionAnswering docstring
|
||||||
@ -78,7 +74,6 @@ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-pol
|
|||||||
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
|
||||||
_SEQ_CLASS_EXPECTED_LOSS = 0.01
|
_SEQ_CLASS_EXPECTED_LOSS = 0.01
|
||||||
|
|
||||||
|
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"bert-base-uncased",
|
"bert-base-uncased",
|
||||||
"bert-large-uncased",
|
"bert-large-uncased",
|
||||||
@ -114,10 +109,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||||
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
"https://www.tensorflow.org/install/ for installation instructions."
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||||
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
||||||
@ -135,10 +128,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
name = name.split("/")
|
name = name.split("/")
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
# which are not required for using pretrained model
|
# which are not required for using pretrained model
|
||||||
if any(
|
if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||||
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
for n in name):
|
||||||
for n in name
|
|
||||||
):
|
|
||||||
logger.info(f"Skipping {'/'.join(name)}")
|
logger.info(f"Skipping {'/'.join(name)}")
|
||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
@ -218,7 +209,7 @@ class BertEmbeddings(nn.Module):
|
|||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length]
|
||||||
|
|
||||||
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
||||||
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
||||||
@ -245,13 +236,12 @@ class BertEmbeddings(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertSelfAttention(nn.Module):
|
class BertSelfAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config, position_embedding_type=None):
|
def __init__(self, config, position_embedding_type=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||||
raise ValueError(
|
raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
f"heads ({config.num_attention_heads})")
|
||||||
f"heads ({config.num_attention_heads})"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
@ -262,9 +252,7 @@ class BertSelfAttention(nn.Module):
|
|||||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
|
||||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
self.position_embedding_type = position_embedding_type or getattr(
|
self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
|
||||||
config, "position_embedding_type", "absolute"
|
|
||||||
)
|
|
||||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
@ -372,6 +360,7 @@ class BertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertSelfOutput(nn.Module):
|
class BertSelfOutput(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
@ -386,6 +375,7 @@ class BertSelfOutput(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertAttention(nn.Module):
|
class BertAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config, position_embedding_type=None):
|
def __init__(self, config, position_embedding_type=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
|
self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
|
||||||
@ -395,9 +385,8 @@ class BertAttention(nn.Module):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
heads, index = find_pruneable_heads_and_indices(
|
heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads,
|
||||||
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
self.self.attention_head_size, self.pruned_heads)
|
||||||
)
|
|
||||||
|
|
||||||
# Prune linear layers
|
# Prune linear layers
|
||||||
self.self.query = prune_linear_layer(self.self.query, index)
|
self.self.query = prune_linear_layer(self.self.query, index)
|
||||||
@ -435,6 +424,7 @@ class BertAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertIntermediate(nn.Module):
|
class BertIntermediate(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
@ -450,6 +440,7 @@ class BertIntermediate(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertOutput(nn.Module):
|
class BertOutput(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
@ -464,6 +455,7 @@ class BertOutput(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertLayer(nn.Module):
|
class BertLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
@ -511,8 +503,7 @@ class BertLayer(nn.Module):
|
|||||||
if not hasattr(self, "crossattention"):
|
if not hasattr(self, "crossattention"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||||
" by setting `config.add_cross_attention=True`"
|
" by setting `config.add_cross_attention=True`")
|
||||||
)
|
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
@ -532,9 +523,8 @@ class BertLayer(nn.Module):
|
|||||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
present_key_value = present_key_value + cross_attn_present_key_value
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward,
|
||||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
self.seq_len_dim, attention_output)
|
||||||
)
|
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
# if decoder, return the attn key/values as the last output
|
# if decoder, return the attn key/values as the last output
|
||||||
@ -550,6 +540,7 @@ class BertLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertEncoder(nn.Module):
|
class BertEncoder(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -585,11 +576,11 @@ class BertEncoder(nn.Module):
|
|||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||||
)
|
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
|
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return module(*inputs, past_key_value, output_attentions)
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
@ -626,17 +617,13 @@ class BertEncoder(nn.Module):
|
|||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(v for v in [
|
||||||
v
|
|
||||||
for v in [
|
|
||||||
hidden_states,
|
hidden_states,
|
||||||
next_decoder_cache,
|
next_decoder_cache,
|
||||||
all_hidden_states,
|
all_hidden_states,
|
||||||
all_self_attentions,
|
all_self_attentions,
|
||||||
all_cross_attentions,
|
all_cross_attentions,
|
||||||
]
|
] if v is not None)
|
||||||
if v is not None
|
|
||||||
)
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=next_decoder_cache,
|
past_key_values=next_decoder_cache,
|
||||||
@ -647,6 +634,7 @@ class BertEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertPooler(nn.Module):
|
class BertPooler(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
@ -662,6 +650,7 @@ class BertPooler(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertPredictionHeadTransform(nn.Module):
|
class BertPredictionHeadTransform(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
@ -679,6 +668,7 @@ class BertPredictionHeadTransform(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertLMPredictionHead(nn.Module):
|
class BertLMPredictionHead(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.transform = BertPredictionHeadTransform(config)
|
self.transform = BertPredictionHeadTransform(config)
|
||||||
@ -699,6 +689,7 @@ class BertLMPredictionHead(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertOnlyMLMHead(nn.Module):
|
class BertOnlyMLMHead(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = BertLMPredictionHead(config)
|
self.predictions = BertLMPredictionHead(config)
|
||||||
@ -709,6 +700,7 @@ class BertOnlyMLMHead(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertOnlyNSPHead(nn.Module):
|
class BertOnlyNSPHead(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
||||||
@ -719,6 +711,7 @@ class BertOnlyNSPHead(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertPreTrainingHeads(nn.Module):
|
class BertPreTrainingHeads(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = BertLMPredictionHead(config)
|
self.predictions = BertLMPredictionHead(config)
|
||||||
@ -950,9 +943,8 @@ class BertModel(BertPreTrainedModel):
|
|||||||
`past_key_values`).
|
`past_key_values`).
|
||||||
"""
|
"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (output_hidden_states
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if self.config.is_decoder:
|
if self.config.is_decoder:
|
||||||
@ -1051,6 +1043,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class BertForPreTraining(BertPreTrainedModel):
|
class BertForPreTraining(BertPreTrainedModel):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@ -1151,9 +1144,8 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""",
|
||||||
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
|
BERT_START_DOCSTRING)
|
||||||
)
|
|
||||||
class BertLMHeadModel(BertPreTrainedModel):
|
class BertLMHeadModel(BertPreTrainedModel):
|
||||||
|
|
||||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
@ -1298,10 +1290,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
if config.is_decoder:
|
if config.is_decoder:
|
||||||
logger.warning(
|
logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
|
||||||
"If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
|
"bi-directional self-attention.")
|
||||||
"bi-directional self-attention."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.bert = BertModel(config, add_pooling_layer=False)
|
self.bert = BertModel(config, add_pooling_layer=False)
|
||||||
self.cls = BertOnlyMLMHead(config)
|
self.cls = BertOnlyMLMHead(config)
|
||||||
@ -1390,9 +1380,10 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
raise ValueError("The PAD token should be defined for generation")
|
raise ValueError("The PAD token should be defined for generation")
|
||||||
|
|
||||||
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
||||||
dummy_token = torch.full(
|
dummy_token = torch.full((effective_batch_size, 1),
|
||||||
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
self.config.pad_token_id,
|
||||||
)
|
dtype=torch.long,
|
||||||
|
device=input_ids.device)
|
||||||
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
@ -1403,6 +1394,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@ -1508,15 +1500,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
|||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class BertForSequenceClassification(BertPreTrainedModel):
|
class BertForSequenceClassification(BertPreTrainedModel):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
classifier_dropout = (
|
classifier_dropout = (config.classifier_dropout
|
||||||
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
if config.classifier_dropout is not None else config.hidden_dropout_prob)
|
||||||
)
|
|
||||||
self.dropout = nn.Dropout(classifier_dropout)
|
self.dropout = nn.Dropout(classifier_dropout)
|
||||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
@ -1612,13 +1604,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class BertForMultipleChoice(BertPreTrainedModel):
|
class BertForMultipleChoice(BertPreTrainedModel):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
classifier_dropout = (
|
classifier_dropout = (config.classifier_dropout
|
||||||
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
if config.classifier_dropout is not None else config.hidden_dropout_prob)
|
||||||
)
|
|
||||||
self.dropout = nn.Dropout(classifier_dropout)
|
self.dropout = nn.Dropout(classifier_dropout)
|
||||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||||
|
|
||||||
@ -1658,11 +1650,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||||
inputs_embeds = (
|
inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||||
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
if inputs_embeds is not None else None)
|
||||||
if inputs_embeds is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
input_ids,
|
input_ids,
|
||||||
@ -1715,9 +1704,8 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
self.bert = BertModel(config, add_pooling_layer=False)
|
self.bert = BertModel(config, add_pooling_layer=False)
|
||||||
classifier_dropout = (
|
classifier_dropout = (config.classifier_dropout
|
||||||
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
if config.classifier_dropout is not None else config.hidden_dropout_prob)
|
||||||
)
|
|
||||||
self.dropout = nn.Dropout(classifier_dropout)
|
self.dropout = nn.Dropout(classifier_dropout)
|
||||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
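Since the hunks above are mechanical re-wrapping, the same result can in principle be reproduced with a formatter rather than by hand. The sketch below uses yapf's Python API; the style string is an assumption for illustration, as the repository's actual yapf configuration is not shown in this excerpt.

```python
# Rough sketch: re-wrap a snippet with yapf, approximating the argument-aligned
# style seen in this diff. Style settings here are assumed, not taken from the repo.
from yapf.yapflib.yapf_api import FormatCode

SRC = (
    "layer_output = apply_chunking_to_forward(\n"
    "    self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output\n"
    ")\n"
)

formatted, changed = FormatCode(SRC, style_config="{based_on_style: pep8, column_limit: 120}")
print(changed)    # True when yapf rewrote the snippet
print(formatted)  # the snippet after re-wrapping under the assumed style
```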
[Condensed diff: the example's local copy of the DeBERTa-v2 modeling code (imported later in this diff as `model.deberta_v2`).]
The import block is regrouped: `from transformers import FillMaskPipeline, T5ForConditionalGeneration, T5Tokenizer` moves next to the other transformers imports, `from transformers.utils import (add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging)` becomes a parenthesized multi-line import, and `from transformers.pytorch_utils import softmax_backward_data` is re-ordered.
The remaining hunks are formatting only, across ContextPooler, XSoftmax, DropoutContext, StableDropout, DebertaV2SelfOutput, DebertaV2Attention, DebertaV2Intermediate, DebertaV2Output, DebertaV2Layer, ConvLayer, DebertaV2Encoder, DisentangledSelfAttention, DebertaV2Model, DebertaV2PredictionHeadTransform, DebertaV2LMPredictionHead, DebertaV2OnlyMLMHead, DebertaV2ForSequenceClassification, DebertaV2ForTokenClassification, DebertaV2ForQuestionAnswering and DebertaV2ForMultipleChoice; the "Copied from transformers.models.deberta..." comments are kept. Long calls such as masked_fill(...) in XSoftmax.symbolic, nn.Conv1d(...) in ConvLayer, build_relative_position(...) in DebertaV2Encoder.get_rel_pos and in disentangled_attention_bias, the transpose_for_scores(...).repeat(...) chains and torch.gather(..., index=...) expressions in DisentangledSelfAttention, the raise ValueError message on the hidden-size check, the flat_inputs_embeds conditional in DebertaV2ForMultipleChoice, and the BaseModelOutput / SequenceClassifierOutput / TokenClassifierOutput constructors are re-wrapped to the repository's argument-aligned (yapf) style. No behavioral changes.
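For readers skimming the DeBERTa hunks: the XSoftmax lines above (fill the masked positions with the dtype minimum, apply softmax, then zero the masked positions) implement a masked softmax. A minimal PyTorch equivalent of that forward behaviour is sketched below; it deliberately ignores XSoftmax's custom backward and ONNX export path, and the tensor shapes in the usage lines are made up for illustration.

```python
import torch

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Illustrative equivalent of XSoftmax.forward: positions where mask == 0 are
    excluded from the softmax and returned as exact zeros."""
    rmask = ~mask.to(torch.bool)                                      # True where attention is NOT allowed
    scores = scores.masked_fill(rmask, torch.finfo(scores.dtype).min)
    probs = torch.softmax(scores, dim=dim)
    return probs.masked_fill(rmask, 0.0)

# Tiny usage example with assumed shapes (batch=1, heads=2, query=4, key=4).
scores = torch.randn(1, 2, 4, 4)
mask = torch.tensor([1, 1, 1, 0]).view(1, 1, 1, 4).expand(1, 2, 4, 4)
print(masked_softmax(scores, mask).sum(-1))  # each row sums to 1 over the unmasked positions
```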
[Condensed diff: nvidia_bert_dataset_provider.py, the HDF5 shard loader used by the pretraining example.]
Imports are re-sorted into standard-library (json, logging, os, random, time, concurrent.futures.ProcessPoolExecutor), third-party (h5py, numpy, torch, torch.distributed, DataLoader / Dataset / DistributedSampler / RandomSampler) and local (bert_dataset_provider, colossalai.utils) groups.
The rest is formatting only: the create_pretraining_dataset signature, the pretraining_dataset(...) and DataLoader(...) calls, the keys list ('input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions') in pretraining_dataset.__init__, the tensor-conversion comprehension and return list in __getitem__, the dataset_files list comprehensions in NvidiaBertDatasetProvider, the initialization logger.info call, self.dataset_future.result(timeout=None), and the self.pool.submit(create_pretraining_dataset, ...) call in prefetch_shard are re-wrapped to the repository's style, and a trailing blank line is dropped at end of file. No behavioral changes.
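The dataset-provider hunks above only re-wrap calls, but they do show the full create_pretraining_dataset signature and the field order returned by pretraining_dataset. A hedged usage sketch is given below; it relies on the helpers defined in that same file, and the shard path and hyper-parameters are placeholders, not values from the example's configs.

```python
# Hypothetical usage of the helpers shown in the diff above (create_pretraining_dataset
# and WorkerInitObj come from the dataset-provider module itself).
from torch.utils.data import RandomSampler

worker_init = WorkerInitObj(seed=42)                       # seeds each DataLoader worker
train_loader, num_samples = create_pretraining_dataset(
    input_file="/data/pretrain/shard_000.h5",              # hypothetical HDF5 shard path
    max_predictions_per_seq=20,
    num_workers=4,
    train_batch_size=32,
    worker_init=worker_init,
    data_sampler=RandomSampler,                            # passed as a class; called on the dataset
)

# Each batch follows the field order returned by pretraining_dataset.__getitem__.
for input_ids, input_mask, segment_ids, masked_lm_labels in train_loader:
    break                                                  # one batch, just to show the unpacking
```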
@ -1,23 +1,32 @@
|
|||||||
import transformers
|
|
||||||
import logging
|
import logging
|
||||||
from colossalai.nn.lr_scheduler import LinearWarmupLR
|
|
||||||
from transformers import get_linear_schedule_with_warmup
|
|
||||||
from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig
|
|
||||||
from transformers import GPT2Config, GPT2LMHeadModel
|
|
||||||
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
|
||||||
from colossalai.nn.optimizer import FusedAdam, HybridAdam
|
|
||||||
from torch.optim import AdamW
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
sys.path.append(os.getcwd())
|
|
||||||
from model.deberta_v2 import DebertaV2ForMaskedLM
|
|
||||||
from model.bert import BertForMaskedLM
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForMaskedLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BertForPreTraining,
|
||||||
|
GPT2Config,
|
||||||
|
GPT2LMHeadModel,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.nn.lr_scheduler import LinearWarmupLR
|
||||||
|
from colossalai.nn.optimizer import FusedAdam, HybridAdam
|
||||||
|
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from model.bert import BertForMaskedLM
|
||||||
|
from model.deberta_v2 import DebertaV2ForMaskedLM
|
||||||
|
|
||||||
__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']
|
__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']
|
||||||
|
|
||||||
|
|
||||||
@ -30,6 +39,7 @@ def get_new_state_dict(state_dict, start_index=13):
|
|||||||
|
|
||||||
|
|
||||||
class LMModel(nn.Module):
|
class LMModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self, model, config, args):
|
def __init__(self, model, config, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -58,9 +68,11 @@ def get_model(args, logger):
     if len(args.load_pretrain_model) > 0:
         assert os.path.exists(args.load_pretrain_model)
         # load_checkpoint(args.load_pretrain_model, model, strict=False)
-        m_state_dict = torch.load(args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
+        m_state_dict = torch.load(args.load_pretrain_model,
+                                  map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
         # new_state_dict = get_new_state_dict(m_state_dict)
-        model.load_state_dict(m_state_dict, strict=True)    # must insure that every process have identical parameters !!!!!!!
+        model.load_state_dict(m_state_dict,
+                              strict=True)    # must insure that every process have identical parameters !!!!!!!
         logger.info("load model success")
 
     numel = sum([p.numel() for p in model.parameters()])
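The `strict=True` load and the trailing comment make the same point: every rank must hold an identical copy of the weights before the ZeRO/Gemini wrappers shard them. As an illustrative sketch only (not code from this repo; `model` is a placeholder name), one way to enforce that after loading is to broadcast rank 0's parameters to all other ranks:

import torch.distributed as dist

def sync_parameters_from_rank0(model):
    # Hypothetical helper: overwrite every rank's parameters with rank 0's copy
    # so all processes start the run from identical weights.
    for p in model.parameters():
        dist.broadcast(p.data, src=0)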
@@ -89,7 +101,10 @@ def get_optimizer(model, lr):
 
 def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
     # warmup_steps = int(total_steps * warmup_ratio)
-    lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch)
+    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
+                                                   num_warmup_steps=warmup_steps,
+                                                   num_training_steps=total_steps,
+                                                   last_epoch=last_epoch)
     # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
     return lr_scheduler
 
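`get_linear_schedule_with_warmup` ramps the learning rate linearly from 0 to the optimizer's base value over `warmup_steps` and then decays it linearly back to 0 at `total_steps`. A small sketch of the multiplier it applies, with assumed step counts shown purely for illustration:

def linear_warmup_then_decay(step, warmup_steps, total_steps):
    # Multiplier applied on top of the base learning rate.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

# Example with warmup_steps=2000, total_steps=10000:
#   step 1000 -> 0.5x base LR, step 2000 -> 1.0x, step 6000 -> 0.5x, step 10000 -> 0.0x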
@@ -107,6 +122,3 @@ def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
     if gpc.get_global_rank() == 0:
         torch.save(checkpoint, optimizer_lr_path)
         torch.save(model_state, model_path)
@@ -35,4 +35,3 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
        --mlm bert \
        --wandb \
        --checkpoint_activations \

@@ -38,4 +38,3 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
        --resume_train \
        --load_pretrain_model /ckpt/1.pt \
        --load_optimizer_lr /ckpt/1.op_lrs \
@@ -4,21 +4,6 @@ import time
 from functools import partial
 
 import torch
-from tqdm import tqdm
-import os
-import time
-from functools import partial
-from transformers import AutoTokenizer
-
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.utils import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
-
 from arguments import parse_args
 from evaluation import evaluate
 from loss import LossForPretraining
@@ -30,6 +15,15 @@ from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
 from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
 from utils.logger import Logger
 
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.utils import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ZeroOptimizer
+
 
 def main():
 
@@ -156,14 +150,15 @@ def main():
         start_epoch = o_l_state_dict['epoch']
         start_shard = o_l_state_dict['shard'] + 1
         # global_step = o_l_state_dict['global_step'] + 1
-        logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}')
+        logger.info(
+            f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
+        )
 
     criterion = LossForPretraining(config.vocab_size)
 
     # build dataloader
     pretrain_dataset_provider = NvidiaBertDatasetProvider(args)
 
 
     logger.info(get_mem_info(prefix='After init model, '))
 
     best_loss = None
@@ -242,8 +237,9 @@ def main():
                 logger.info('*' * 100)
 
             eval_loss += evaluate(model, args, logger, global_step, criterion)
-            save_ckpt(model, optimizer, lr_scheduler, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step)
+            save_ckpt(model, optimizer, lr_scheduler,
+                      os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
+                      shard, global_step)
 
             eval_loss /= len(os.listdir(args.data_path_prefix))
             logger.info(
@@ -1,8 +1,10 @@
-import time
-import wandb
 import os
+import time
+
+import wandb
 from torch.utils.tensorboard import SummaryWriter
 
+
 class WandbLog:
 
     @classmethod
@@ -38,9 +40,3 @@ class TensorboardLog:
     def log_zeroshot(self, result, step):
         for k, v in result.items():
             self.writer.add_scalar(f'{k}_acc/eval', v, step)
-
-
-
-
-
-
@@ -1,9 +1,13 @@
 import functools
-import os, shutil
-import torch
+import os
+import shutil
+
 import psutil
+import torch
+
 from colossalai.core import global_context as gpc
 
+
 def logging(s, log_path, print_=True, log_=True):
     if print_:
         print(s)
@@ -11,9 +15,11 @@ def logging(s, log_path, print_=True, log_=True):
         with open(log_path, 'a+') as f_log:
             f_log.write(s + '\n')
 
+
 def get_logger(log_path, **kwargs):
     return functools.partial(logging, log_path=log_path, **kwargs)
 
+
 def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
     if debug:
         print('Debug Mode : no experiment dir created')
@@ -33,6 +39,7 @@ def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
 
     return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
 
+
 def get_cpu_mem():
     return psutil.Process().memory_info().rss / 1024**2
 
@@ -52,11 +59,15 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
 def get_parameters_in_billions(model, world_size=1):
     gpus_per_model = world_size
 
-    approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
-                                        for model_module in model])
+    approx_parameters_in_billions = sum([
+        sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement()
+             for p in model_module.parameters()])
+        for model_module in model
+    ])
 
     return approx_parameters_in_billions * gpus_per_model / (1e9)
 
+
 def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1):
     gpus_per_model = 1
     batch_size = args.train_micro_batch_size_per_gpu
@@ -76,10 +87,13 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1):
     # The factor of 4 is when used with activation check-pointing,
     # otherwise it will be 3.
     checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
-    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
+    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers *
+                           (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) +
+                                                (vocab_size / (16. * num_layers * hidden_size)))
     tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12))
     return samples_per_second, tflops, approx_parameters_in_billions
 
+
 def synchronize():
     if not torch.distributed.is_available():
         return
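The reflowed expression above is the usual transformer FLOPs estimate. Writing C for the activation-checkpointing factor (4 with checkpointing, otherwise 3), B for the micro batch size, s for max_seq_length, L for num_layers, h for hidden_size and V for vocab_size, it reads, in LaTeX form:

F_{\mathrm{iter}} = 24\,C\,B\,s\,L\,h^{2}\left(1 + \frac{s}{6h} + \frac{V}{16\,L\,h}\right),
\qquad \mathrm{TFLOPS} = \frac{F_{\mathrm{iter}}}{t_{\mathrm{iter}} \times 10^{12}}

where t_iter is the measured time per iteration (the `elapsed_time_per_iter` variable in the code).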
@@ -90,6 +104,7 @@ def synchronize():
         return
     torch.distributed.barrier()
 
+
 def log_args(logger, args):
     logger.info('--------args----------')
     message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()])
@@ -1,5 +1,7 @@
 import time
+
 import torch
+
 from .WandbLog import TensorboardLog
 
 _GLOBAL_TIMERS = None
@@ -10,30 +12,34 @@ def set_global_variables(launch_time, tensorboard_path):
     _set_timers()
     _set_tensorboard_writer(launch_time, tensorboard_path)
 
+
 def _set_timers():
     """Initialize timers."""
     global _GLOBAL_TIMERS
     _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
     _GLOBAL_TIMERS = Timers()
 
+
 def _set_tensorboard_writer(launch_time, tensorboard_path):
     """Set tensorboard writer."""
     global _GLOBAL_TENSORBOARD_WRITER
-    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
-                                   'tensorboard writer')
+    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer')
     if torch.distributed.get_rank() == 0:
         _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time)
 
+
 def get_timers():
     """Return timers."""
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS
 
+
 def get_tensorboard_writer():
     """Return tensorboard writer. It can be None so no need
     to check if it is initialized."""
     return _GLOBAL_TENSORBOARD_WRITER
 
+
 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
     assert var is not None, '{} is not initialized.'.format(name)
@@ -115,12 +121,10 @@ class Timers:
         assert normalizer > 0.0
         string = 'time (ms)'
         for name in names:
-            elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0 / normalizer
+            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
             string += ' | {}: {:.2f}'.format(name, elapsed_time)
         if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == (
-                    torch.distributed.get_world_size() - 1):
+            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                 print(string, flush=True)
         else:
             print(string, flush=True)
@@ -1,22 +1,22 @@
-import os
 import logging
+import os
+
 import torch.distributed as dist
 
-logging.basicConfig(
-    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',
                     level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
 class Logger():
 
     def __init__(self, log_path, cuda=False, debug=False):
         self.logger = logging.getLogger(__name__)
         self.cuda = cuda
         self.log_path = log_path
         self.debug = debug
 
     def info(self, message, log_=True, print_=True, *args, **kwargs):
         if (self.cuda and dist.get_rank() == 0) or not self.cuda:
             if print_:
@@ -26,6 +26,5 @@ class Logger():
             with open(self.log_path, 'a+') as f_log:
                 f_log.write(message + '\n')
 
-
     def error(self, message, *args, **kwargs):
         self.logger.error(message, *args, **kwargs)
@@ -1,184 +0,0 @@
-#include <algorithm>
-#include <iostream>
-#include <limits>
-#include <math.h>
-#include <stdexcept>
-#include <pybind11/pybind11.h>
-#include <pybind11/numpy.h>
-#include <random>
-#include <vector>
-#include <string>
-#include <pybind11/stl.h>
-#include <chrono>
-#include <tuple>
-#include <unordered_set>
-#include <unordered_map>
-
-namespace py = pybind11;
-
-const int32_t LONG_SENTENCE_LEN = 512;
-
-struct MaskedLMInstance {
-  int index;
-  std::string label;
-  MaskedLMInstance(int index, std::string label) {
-    this->index = index;
-    this->label = label;
-  }
-};
-
-auto get_new_segment(std::vector<std::string> segment, std::vector<std::string> segment_jieba, const std::vector<bool> chinese_vocab) {  // const std::unordered_set<std::string> &chinese_vocab
-  std::unordered_set<std::string> seq_cws_dict;
-  for (auto word : segment_jieba) {
-    seq_cws_dict.insert(word);
-  }
-  int i = 0;
-  std::vector<std::string> new_segment;
-  int segment_size = segment.size();
-  while (i < segment_size) {
-    if (!chinese_vocab[i]) {  // chinese_vocab.find(segment[i]) == chinese_vocab.end()
-      new_segment.emplace_back(segment[i]);
-      i += 1;
-      continue;
-    }
-    bool has_add = false;
-    for (int length = 3; length >= 1; length--) {
-      if (i + length > segment_size) {
-        continue;
-      }
-      std::string chinese_word = "";
-      for (int j = i; j < i + length; j++) {
-        chinese_word += segment[j];
-      }
-      if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
-        new_segment.emplace_back(segment[i]);
-        for (int j = i + 1; j < i + length; j++) {
-          new_segment.emplace_back("##" + segment[j]);
-        }
-        i += length;
-        has_add = true;
-        break;
-      }
-    }
-    if (!has_add) {
-      new_segment.emplace_back(segment[i]);
-      i += 1;
-    }
-  }
-
-  return new_segment;
-}
-
-bool startsWith(const std::string& s, const std::string& sub) {
-  return s.find(sub) == 0 ? true : false;
-}
-
-auto create_whole_masked_lm_predictions(std::vector<std::string> &tokens,
-                                        const std::vector<std::string> &original_tokens,
-                                        const std::vector<std::string> &vocab_words,
-                                        std::map<std::string, int> &vocab,
-                                        const int max_predictions_per_seq,
-                                        const double masked_lm_prob) {
-  // for (auto item : vocab) {
-  //   std::cout << "key=" << std::string(py::str(item.first)) << ", "
-  //             << "value=" << std::string(py::str(item.second)) << std::endl;
-  // }
-  std::vector<std::vector<int> > cand_indexes;
-  std::vector<int> cand_temp;
-  int tokens_size = tokens.size();
-  std::string prefix = "##";
-  bool do_whole_masked = true;
-
-  for (int i = 0; i < tokens_size; i++) {
-    if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
-      continue;
-    }
-    if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) {
-      cand_temp.emplace_back(i);
-    }
-    else {
-      if (cand_temp.size() > 0) {
-        cand_indexes.emplace_back(cand_temp);
-      }
-      cand_temp.clear();
-      cand_temp.emplace_back(i);
-    }
-  }
-  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
-  std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed));
-  // for (auto i : cand_indexes) {
-  //   for (auto j : i) {
-  //     std::cout << tokens[j] << " ";
-  //   }
-  //   std::cout << std::endl;
-  // }
-  // for (auto i : output_tokens) {
-  //   std::cout << i;
-  // }
-  // std::cout << std::endl;
-
-  int num_to_predict = std::min(max_predictions_per_seq,
-                                std::max(1, int(tokens_size * masked_lm_prob)));
-  // std::cout << num_to_predict << std::endl;
-
-  std::set<int> covered_indexes;
-  std::vector<int> masked_lm_output(tokens_size, -1);
-  int vocab_words_len = vocab_words.size();
-  std::default_random_engine e(seed);
-  std::uniform_real_distribution<double> u1(0.0, 1.0);
-  std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
-  int mask_cnt = 0;
-  std::vector<std::string> output_tokens;
-  output_tokens = original_tokens;
-
-  for (auto index_set : cand_indexes) {
-    if (mask_cnt > num_to_predict) {
-      break;
-    }
-    int index_set_size = index_set.size();
-    if (mask_cnt + index_set_size > num_to_predict) {
-      continue;
-    }
-    bool is_any_index_covered = false;
-    for (auto index : index_set) {
-      if (covered_indexes.find(index) != covered_indexes.end()) {
-        is_any_index_covered = true;
-        break;
-      }
-    }
-    if (is_any_index_covered) {
-      continue;
-    }
-    for (auto index : index_set) {
-
-      covered_indexes.insert(index);
-      std::string masked_token;
-      if (u1(e) < 0.8) {
-        masked_token = "[MASK]";
-      }
-      else {
-        if (u1(e) < 0.5) {
-          masked_token = output_tokens[index];
-        }
-        else {
-          int random_index = u2(e);
-          masked_token = vocab_words[random_index];
-        }
-      }
-      // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
-      masked_lm_output[index] = vocab[output_tokens[index]];
-      output_tokens[index] = masked_token;
-      mask_cnt++;
-    }
-  }
-
-  // for (auto p : masked_lms) {
-  //   masked_lm_output[p.index] = vocab[p.label];
-  // }
-  return std::make_tuple(output_tokens, masked_lm_output);
-}
-
-PYBIND11_MODULE(mask, m) {
-  m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions);
-  m.def("get_new_segment", &get_new_segment);
-}
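The file removed above implemented whole-word masking for the masked-language-model data pipeline in C++ and exposed it to Python through pybind11. A condensed Python sketch of the same scheme, hypothetical and for illustration only (the function and argument names here, such as `whole_word_mask`, `vocab` as a token-to-id dict and `vocab_words` as a token list, are stand-ins): "##" word pieces are grouped with their head token, whole groups are masked together up to the prediction budget, and each masked position receives [MASK] 80% of the time, the original token 10% of the time, and a random vocabulary token 10% of the time.

import random

def whole_word_mask(tokens, vocab, vocab_words, max_predictions_per_seq=80, masked_lm_prob=0.15):
    # Group indices so a head word piece and its "##" continuations are masked as one unit.
    cand_indexes = []
    for i, tok in enumerate(tokens):
        if tok in ("[CLS]", "[SEP]"):
            continue
        if cand_indexes and tok.startswith("##"):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])

    random.shuffle(cand_indexes)
    num_to_predict = min(max_predictions_per_seq, max(1, int(len(tokens) * masked_lm_prob)))

    output_tokens = list(tokens)
    labels = [-1] * len(tokens)    # -1 marks positions that are not predicted
    masked = 0
    for group in cand_indexes:
        if masked + len(group) > num_to_predict:
            continue
        for i in group:
            labels[i] = vocab[tokens[i]]
            r = random.random()
            if r < 0.8:
                output_tokens[i] = "[MASK]"                       # 80%: replace with [MASK]
            elif r < 0.9:
                pass                                              # 10%: keep the original token
            else:
                output_tokens[i] = random.choice(vocab_words)     # 10%: random vocabulary token
            masked += 1
    return output_tokens, labels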
@@ -1,176 +0,0 @@
-import colossalai
-from numpy import require
-
-__all__ = ['parse_args']
-
-
-def parse_args():
-    parser = colossalai.get_default_parser()
-
-    parser.add_argument(
-        "--distplan",
-        type=str,
-        default='CAI_Gemini',
-        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
-    )
-    parser.add_argument(
-        "--tp_degree",
-        type=int,
-        default=1,
-        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
-    )
-    parser.add_argument(
-        "--placement",
-        type=str,
-        default='cpu',
-        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
-    )
-    parser.add_argument(
-        "--shardinit",
-        action='store_true',
-        help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
-    )
-
-    parser.add_argument(
-        '--lr',
-        type=float,
-        required=True,
-        help='initial learning rate')
-    parser.add_argument(
-        '--epoch',
-        type=int,
-        required=True,
-        help='number of epoch')
-    parser.add_argument(
-        '--data_path_prefix',
-        type=str,
-        required=True,
-        help="location of the train data corpus")
-    parser.add_argument(
-        '--eval_data_path_prefix',
-        type=str,
-        required=True,
-        help='location of the evaluation data corpus')
-    parser.add_argument(
-        '--tokenizer_path',
-        type=str,
-        required=True,
-        help='location of the tokenizer')
-    parser.add_argument(
-        '--max_seq_length',
-        type=int,
-        default=512,
-        help='sequence length')
-    parser.add_argument(
-        '--refresh_bucket_size',
-        type=int,
-        default=1,
-        help=
-        "This param makes sure that a certain task is repeated for this time steps to \
-        optimise on the back propogation speed with APEX's DistributedDataParallel")
-    parser.add_argument(
-        "--max_predictions_per_seq",
-        "--max_pred",
-        default=80,
-        type=int,
-        help=
-        "The maximum number of masked tokens in a sequence to be predicted.")
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        default=1,
-        type=int,
-        help="accumulation_steps")
-    parser.add_argument(
-        "--train_micro_batch_size_per_gpu",
-        default=2,
-        type=int,
-        required=True,
-        help="train batch size")
-    parser.add_argument(
-        "--eval_micro_batch_size_per_gpu",
-        default=2,
-        type=int,
-        required=True,
-        help="eval batch size")
-    parser.add_argument(
-        "--num_workers",
-        default=8,
-        type=int,
-        help="")
-    parser.add_argument(
-        "--async_worker",
-        action='store_true',
-        help="")
-    parser.add_argument(
-        "--bert_config",
-        required=True,
-        type=str,
-        help="location of config.json")
-    parser.add_argument(
-        "--wandb",
-        action='store_true',
-        help="use wandb to watch model")
-    parser.add_argument(
-        "--wandb_project_name",
-        default='roberta',
-        help="wandb project name")
-    parser.add_argument(
-        "--log_interval",
-        default=100,
-        type=int,
-        help="report interval")
-    parser.add_argument(
-        "--log_path",
-        type=str,
-        required=True,
-        help="log file which records train step")
-    parser.add_argument(
-        "--tensorboard_path",
-        type=str,
-        required=True,
-        help="location of tensorboard file")
-    parser.add_argument(
-        "--colossal_config",
-        type=str,
-        required=True,
-        help="colossal config, which contains zero config and so on")
-    parser.add_argument(
-        "--ckpt_path",
-        type=str,
-        required=True,
-        help="location of saving checkpoint, which contains model and optimizer")
-    parser.add_argument(
-        '--seed',
-        type=int,
-        default=42,
-        help="random seed for initialization")
-    parser.add_argument(
-        '--vscode_debug',
-        action='store_true',
-        help="use vscode to debug")
-    parser.add_argument(
-        '--load_pretrain_model',
-        default='',
-        type=str,
-        help="location of model's checkpoin")
-    parser.add_argument(
-        '--load_optimizer_lr',
-        default='',
-        type=str,
-        help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
-    parser.add_argument(
-        '--resume_train',
-        action='store_true',
-        help="whether resume training from a early checkpoint")
-    parser.add_argument(
-        '--mlm',
-        default='bert',
-        type=str,
-        help="model type, bert or deberta")
-    parser.add_argument(
-        '--checkpoint_activations',
-        action='store_true',
-        help="whether to use gradient checkpointing")
-
-    args = parser.parse_args()
-    return args