[example] reorganize for community examples (#3557)
commit f1b3d60cae (parent 1a809eddaa)
@@ -10,9 +10,12 @@

 ## Overview

-This folder provides several examples accelerated by Colossal-AI. The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI. Other folders such as `images` and `language` include a wide range of deep learning tasks and applications.
+This folder provides several examples accelerated by Colossal-AI.
+
+Folders such as `images` and `language` include a wide range of deep learning tasks and applications.
+
+The `community` folder aims to create a collaborative platform for developers to contribute exotic features built on top of Colossal-AI.
+
+The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI.

-You can find applications such as Chatbot, Stable Diffusion and Biomedicine in the [Applications](https://github.com/hpcaitech/ColossalAI/tree/main/applications) directory.
+You can find applications such as Chatbot, AIGC and Biomedicine in the [Applications](https://github.com/hpcaitech/ColossalAI/tree/main/applications) directory.

 ## Folder Structure

@@ -52,3 +55,10 @@ Therefore, it is essential for the example contributors to know how to integrate
 2. Configure your testing parameters, such as the number of steps and the batch size, in `test_ci.sh`. Keep these parameters small so that each example only takes several minutes.
 3. Export your dataset path with the prefix `/data` and make sure you have a copy of the dataset in the `/data/scratch/examples-data` directory on the CI machine. Community contributors can contact us via Slack to request that a dataset be downloaded onto the CI machine.
 4. Implement the logic such as dependency setup and example execution.
+
+## Community Dependency
+
+We are happy to introduce the following community repositories that are powered by Colossal-AI:
+
+- [lightning-ColossalAI](https://github.com/Lightning-AI/lightning)
+- [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion)
+- [KoChatGPT](https://github.com/airobotlab/KoChatGPT)
+- [minichatgpt](https://github.com/juncongmoo/minichatgpt)
examples/community/README.md (new file, 28 lines)

# Community Examples

Community-driven Examples is an initiative that allows users to share their own examples with the Colossal-AI community, fostering a sense of community and making it easy for others to access and benefit from shared work. The primary goal of community-driven examples is to have a community-maintained collection of diverse and exotic functionalities built on top of the Colossal-AI package.

If a community example doesn't work as expected, you can [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) and @ the author to report it.

| Example                | Description                                        | Code Example                   | Colab | Author                                                           |
|:-----------------------|:---------------------------------------------------|:-------------------------------|:------|-----------------------------------------------------------------:|
| RoBERTa                | Adding RoBERTa for SFT and Prompts model training  | [RoBERTa](./roberta)           | -     | [YY Lin](https://github.com/yynil) (Moore Threads)               |
| TransformerEngine FP8  | Adding TransformerEngine with FP8 training         | [TransformerEngine FP8](./fp8) | -     | [Kirthi Shankar Sivamani](https://github.com/ksivaman) (NVIDIA)  |
| ...                    | ...                                                | ...                            | ...   | ...                                                              |

## Looking for Examples

* [Swin-Transformer](https://github.com/microsoft/Swin-Transformer)
* [T-5](https://github.com/google-research/text-to-text-transfer-transformer)
* [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
* [ControlNet](https://github.com/lllyasviel/ControlNet)
* [Consistency Models](https://github.com/openai/consistency_models)
* [MAE](https://github.com/facebookresearch/mae)
* [CLIP](https://github.com/openai/CLIP)

Welcome to [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) to share your insights and needs.

## How to get involved

To join our community-driven initiative, please visit the [Colossal-AI examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples), review the provided information, and explore the codebase.

To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review it and provide feedback. If you are confident enough, you can also submit a PR directly. We look forward to collaborating with you on this exciting project!
@@ -1,13 +1,13 @@
 # Basic MNIST Example with optional FP8 from TransformerEngine

 [TransformerEngine](https://github.com/NVIDIA/TransformerEngine) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower memory utilization in both training and inference.

 Thanks to NVIDIA for contributing this tutorial.

 ```bash
 python main.py
 python main.py --use-te   # Linear layers from TransformerEngine
 python main.py --use-fp8  # FP8 + TransformerEngine for Linear layers
 ```

 > We are working to integrate it with Colossal-AI and will finish it soon.
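The snippet below is a minimal sketch of what `--use-te` / `--use-fp8` boil down to; it is not part of this commit. It assumes `transformer_engine` is installed and an FP8-capable GPU is available, and it uses layer sizes that are multiples of 16, which FP8 GEMMs generally require.

```python
# Minimal sketch (not from this repo): TransformerEngine Linear layers under fp8_autocast.
import torch
from transformer_engine import pytorch as te

model = torch.nn.Sequential(
    te.Linear(784, 256),   # drop-in replacement for torch.nn.Linear
    torch.nn.ReLU(),
    te.Linear(256, 16),
).cuda()

x = torch.randn(32, 784, device="cuda")
with te.fp8_autocast(enabled=True):    # enabled=False falls back to the usual precision
    out = model(x)
print(out.shape)
```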
@@ -3,12 +3,13 @@
 # See LICENSE for license information.

 import argparse

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
+
+from torchvision import datasets, transforms

 try:
     from transformer_engine import pytorch as te

@@ -18,6 +19,7 @@ except (ImportError, ModuleNotFoundError):


 class Net(nn.Module):

     def __init__(self, use_te=False):
         super(Net, self).__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)

@@ -62,12 +64,10 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
         loss.backward()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
-            print(
-                f"Train Epoch: {epoch} "
-                f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
-                f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
-                f"Loss: {loss.item():.6f}"
-            )
+            print(f"Train Epoch: {epoch} "
+                  f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
+                  f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
+                  f"Loss: {loss.item():.6f}")
             if args.dry_run:
                 break

@@ -83,6 +83,7 @@ def calibrate(model, device, test_loader):
             with te.fp8_autocast(enabled=False, calibrating=True):
                 output = model(data)


 def test(model, device, test_loader, use_fp8):
     """Testing function."""
     model.eval()

@@ -93,21 +94,15 @@ def test(model, device, test_loader, use_fp8):
             data, target = data.to(device), target.to(device)
             with te.fp8_autocast(enabled=use_fp8):
                 output = model(data)
-            test_loss += F.nll_loss(
-                output, target, reduction="sum"
-            ).item()  # sum up batch loss
-            pred = output.argmax(
-                dim=1, keepdim=True
-            )  # get the index of the max log-probability
+            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
+            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()

     test_loss /= len(test_loader.dataset)

-    print(
-        f"\nTest set: Average loss: {test_loss:.4f}, "
-        f"Accuracy: {correct}/{len(test_loader.dataset)} "
-        f"({100. * correct / len(test_loader.dataset):.0f}%)\n"
-    )
+    print(f"\nTest set: Average loss: {test_loss:.4f}, "
+          f"Accuracy: {correct}/{len(test_loader.dataset)} "
+          f"({100. * correct / len(test_loader.dataset):.0f}%)\n")


 def main():

@@ -154,9 +149,7 @@ def main():
         default=False,
         help="quickly check a single pass",
     )
-    parser.add_argument(
-        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
-    )
+    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
     parser.add_argument(
         "--log-interval",
         type=int,

@@ -170,15 +163,12 @@ def main():
         default=False,
         help="For Saving the current Model",
     )
-    parser.add_argument(
-        "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
-    )
-    parser.add_argument(
-        "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
-    )
-    parser.add_argument(
-        "--use-te", action="store_true", default=False, help="Use Transformer Engine"
-    )
+    parser.add_argument("--use-fp8",
+                        action="store_true",
+                        default=False,
+                        help="Use FP8 for inference and training without recalibration")
+    parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only")
+    parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine")
     args = parser.parse_args()
     use_cuda = torch.cuda.is_available()

@@ -205,9 +195,7 @@ def main():
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)

-    transform = transforms.Compose(
-        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
-    )
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
     dataset2 = datasets.MNIST("../data", train=False, transform=transform)
     train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)

@@ -227,7 +215,7 @@ def main():

     if args.save_model or args.use_fp8_infer:
         torch.save(model.state_dict(), "mnist_cnn.pt")
-        print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
+        print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer))
         weights = torch.load("mnist_cnn.pt")
         model.load_state_dict(weights)
         test(model, device, test_loader, args.use_fp8_infer)
@@ -11,7 +11,7 @@ ssh-keygen
 ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination
 ```

 - In all hosts, edit /etc/hosts to record all hosts' names and IPs. An example is shown below.

 ```bash
 192.168.2.1 GPU001

@@ -29,7 +29,7 @@ ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination
 service ssh restart
 ```

 ## 1. Corpus Preprocessing
 ```bash
 cd preprocessing
 ```
@@ -21,7 +21,7 @@ This folder is used to preprocess chinese corpus with Whole Word Masked. You can
 <span id='Split Sentence'/>

 ### 2.1. Split Sentence & Split data into multiple shards:
 Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentences at punctuation such as `。!`. **Secondly, split the data into multiple shards based on the server hardware (CPU, CPU memory, hard disk) and the corpus size.** Each shard contains a part of the corpus, and the model needs to train on all the shards as one epoch.
 In this example, a 200G corpus is split into 100 shards, and each shard is about 2G. The shard size is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training that reads the shards (n-way data parallelism requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining require fighting with hardware, not just the GPU.**

 ```python

@@ -49,7 +49,7 @@ python sentence_split.py --input_path /orginal_corpus --output_path /shard --sha
 ]
 ```

 <summary><b>Output txt:</b></summary>

 ```
 我今天去打篮球。

@@ -76,7 +76,7 @@ make

 * `--input_path`: location of all shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
 * `--output_path`: location of all h5 files with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ...
 * `--tokenizer_path`: tokenizer path containing the huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenizer.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main)
 * `--backend`: python or c++; **specifying c++ gives faster preprocessing**
 * `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document
 * `--worker`: number of processes

@@ -91,7 +91,7 @@ make
 下周请假。
 ```

 <summary><b>Output h5+numpy:</b></summary>

 ```
 'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..],

@@ -102,4 +102,4 @@ make
 ...]
 'masked_lm_positions': [[label1,-1,-1,label2,-1...],
 ...]
 ```
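As a quick orientation, here is a small sketch (not part of this commit) of how one of the generated shards can be inspected with `h5py`; the path `/h5/0.h5` is only an assumed example location.

```python
# Sketch (assumed example, not repo code): inspect one preprocessed h5 shard.
import h5py

with h5py.File("/h5/0.h5", "r") as hf:                  # assumed output path
    input_ids = hf["input_ids"][:]                      # token ids, padded with 0
    input_mask = hf["input_mask"][:]
    segment_ids = hf["segment_ids"][:]
    masked_lm_positions = hf["masked_lm_positions"][:]  # -1 where nothing is masked
print(input_ids.shape, masked_lm_positions.shape)
```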
@@ -1,20 +1,22 @@
-import torch
+import collections
+import logging
 import os
+import random
+import time
 from enum import IntEnum
 from random import choice
-import random
-import collections
-import time
-import logging
+
 import jieba
+import torch

 jieba.setLogLevel(logging.CRITICAL)
 import re
-import numpy as np
+
 import mask
+import numpy as np

 PAD = 0
-MaskedLMInstance = collections.namedtuple("MaskedLMInstance",
-                                          ["index", "label"])
+MaskedLMInstance = collections.namedtuple("MaskedLMInstance", ["index", "label"])


 def map_to_numpy(data):

@@ -22,6 +24,7 @@ def map_to_numpy(data):


 class PreTrainingDataset():

     def __init__(self,
                  tokenizer,
                  max_seq_length,

@@ -43,17 +46,15 @@ class PreTrainingDataset():
         self.mlm_tamper_p = 0.05
         self.mlm_maintain_p = 0.1

     def tokenize(self, doc):
         temp = []
         for d in doc:
             temp.append(self.tokenizer.tokenize(d))
         return temp

     def create_training_instance(self, instance):
         is_next = 1
         raw_text_list = self.get_new_segment(instance)
         tokens_a = raw_text_list
         assert len(tokens_a) == len(instance)
         # tokens_a, tokens_b, is_next = instance.get_values()

@@ -83,8 +84,9 @@ class PreTrainingDataset():

         # Get Masked LM predictions
         if self.backend == 'c++':
-            output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words,
-                                                        self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob)
+            output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
+                tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq,
+                self.masked_lm_prob)
         elif self.backend == 'python':
             output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)

@@ -102,29 +104,25 @@ class PreTrainingDataset():
             map_to_numpy(input_mask),
             map_to_numpy(segment_ids),
             map_to_numpy(masked_lm_output),
             map_to_numpy([is_next])
         ])

     def create_masked_lm_predictions(self, tokens):
         cand_indexes = []
         for i, token in enumerate(tokens):
             if token == "[CLS]" or token == "[SEP]":
                 continue
-            if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
-                    token.startswith("##")):
+            if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
                 cand_indexes[-1].append(i)
             else:
                 cand_indexes.append([i])

             # cand_indexes.append(i)

         random.shuffle(cand_indexes)
         output_tokens = list(tokens)

-        num_to_predict = min(
-            self.max_predictions_per_seq,
-            max(1, int(round(len(tokens) * self.masked_lm_prob))))
+        num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))

         masked_lms = []
         covered_indexes = set()

@@ -145,13 +143,10 @@ class PreTrainingDataset():
                     masked_token = tokens[index]
                 # 10% replace w/ random word
                 else:
-                    masked_token = self.vocab_words[random.randint(
-                        0,
-                        len(self.vocab_words) - 1)]
+                    masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]

             output_tokens[index] = masked_token
-            masked_lms.append(
-                MaskedLMInstance(index=index, label=tokens[index]))
+            masked_lms.append(MaskedLMInstance(index=index, label=tokens[index]))

         masked_lms = sorted(masked_lms, key=lambda x: x.index)
         masked_lm_output = [-1] * len(output_tokens)

@@ -160,7 +155,6 @@ class PreTrainingDataset():

         return (output_tokens, masked_lm_output)

     def get_new_segment(self, segment):
         """
         Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word.

@@ -171,7 +165,7 @@ class PreTrainingDataset():
         new_segment = []
         i = 0
         while i < len(segment):
             if len(self.rec.findall(segment[i])) == 0:
                 new_segment.append(segment[i])
                 i += 1
                 continue

@@ -180,10 +174,10 @@ class PreTrainingDataset():
             for length in range(3, 0, -1):
                 if i + length > len(segment):
                     continue
-                if ''.join(segment[i: i+length]) in seq_cws_dict:
+                if ''.join(segment[i:i + length]) in seq_cws_dict:
                     new_segment.append(segment[i])
                     for l in range(1, length):
-                        new_segment.append('##' + segment[i+l])
+                        new_segment.append('##' + segment[i + l])
                     i += length
                     has_add = True
                     break

@@ -192,7 +186,6 @@ class PreTrainingDataset():
                 i += 1
         return new_segment

     def create_whole_masked_lm_predictions(self, tokens):
         """Creates the predictions for the masked LM objective."""

@@ -209,18 +202,16 @@ class PreTrainingDataset():
             # Note that Whole Word Masking does *not* change the training code
             # at all -- we still predict each WordPiece independently, softmaxed
             # over the entire vocabulary.
-            if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
-                    token.startswith("##")):
+            if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")):
                 cand_indexes[-1].append(i)
             else:
                 cand_indexes.append([i])

         random.shuffle(cand_indexes)

-        output_tokens = [t[2:] if len(self.whole_rec.findall(t))>0 else t for t in tokens]  # strip "##"
+        output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens]  # strip "##"

-        num_to_predict = min(self.max_predictions_per_seq,
-                             max(1, int(round(len(tokens) * self.masked_lm_prob))))
+        num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))

         masked_lms = []
         covered_indexes = set()

@@ -248,14 +239,18 @@ class PreTrainingDataset():
                 else:
                     # 10% of the time, keep original
                     if random.random() < 0.5:
-                        masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index]  # strip "##"
+                        masked_token = tokens[index][2:] if len(self.whole_rec.findall(
+                            tokens[index])) > 0 else tokens[index]  # strip "##"
                     # 10% of the time, replace with random word
                     else:
                         masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]

                 output_tokens[index] = masked_token

-                masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index]))
+                masked_lms.append(
+                    MaskedLMInstance(
+                        index=index,
+                        label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index]))
         assert len(masked_lms) <= num_to_predict
         masked_lms = sorted(masked_lms, key=lambda x: x.index)
         masked_lm_output = [-1] * len(output_tokens)
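To make the whole-word marking used above concrete, here is a tiny illustration written for this write-up (an assumption based on the docstring and on the C++ port below, not code from the repo): when jieba segments the characters into 我 / 今天 / 去 / 打篮球, every character after the first one of a multi-character word gets a `##` prefix, so the masking step can treat the word as one unit.

```python
# Illustration (assumption, not repo code) of the "##" whole-word marking.
chars = ["我", "今", "天", "去", "打", "篮", "球"]   # character-level tokens
words = {"我", "今天", "去", "打篮球"}               # hypothetical jieba segmentation

marked, i = [], 0
while i < len(chars):
    for length in (3, 2, 1):                         # try the longest word first
        if "".join(chars[i:i + length]) in words:
            marked.append(chars[i])
            marked.extend("##" + c for c in chars[i + 1:i + length])
            i += length
            break
    else:
        marked.append(chars[i])
        i += 1

print(marked)   # ['我', '今', '##天', '去', '打', '##篮', '##球']
```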
examples/community/roberta/preprocessing/mask.cpp (new file, 190 lines)

#include <math.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <algorithm>
#include <chrono>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace py = pybind11;

const int32_t LONG_SENTENCE_LEN = 512;

struct MaskedLMInstance {
  int index;
  std::string label;
  MaskedLMInstance(int index, std::string label) {
    this->index = index;
    this->label = label;
  }
};

auto get_new_segment(
    std::vector<std::string> segment, std::vector<std::string> segment_jieba,
    const std::vector<bool> chinese_vocab) {  // const
                                              // std::unordered_set<std::string>
                                              // &chinese_vocab
  std::unordered_set<std::string> seq_cws_dict;
  for (auto word : segment_jieba) {
    seq_cws_dict.insert(word);
  }
  int i = 0;
  std::vector<std::string> new_segment;
  int segment_size = segment.size();
  while (i < segment_size) {
    if (!chinese_vocab[i]) {  // chinese_vocab.find(segment[i]) ==
                              // chinese_vocab.end()
      new_segment.emplace_back(segment[i]);
      i += 1;
      continue;
    }
    bool has_add = false;
    for (int length = 3; length >= 1; length--) {
      if (i + length > segment_size) {
        continue;
      }
      std::string chinese_word = "";
      for (int j = i; j < i + length; j++) {
        chinese_word += segment[j];
      }
      if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
        new_segment.emplace_back(segment[i]);
        for (int j = i + 1; j < i + length; j++) {
          new_segment.emplace_back("##" + segment[j]);
        }
        i += length;
        has_add = true;
        break;
      }
    }
    if (!has_add) {
      new_segment.emplace_back(segment[i]);
      i += 1;
    }
  }

  return new_segment;
}

bool startsWith(const std::string &s, const std::string &sub) {
  return s.find(sub) == 0 ? true : false;
}

auto create_whole_masked_lm_predictions(
    std::vector<std::string> &tokens,
    const std::vector<std::string> &original_tokens,
    const std::vector<std::string> &vocab_words,
    std::map<std::string, int> &vocab, const int max_predictions_per_seq,
    const double masked_lm_prob) {
  // for (auto item : vocab) {
  //   std::cout << "key=" << std::string(py::str(item.first)) << ", "
  //             << "value=" << std::string(py::str(item.second)) << std::endl;
  // }
  std::vector<std::vector<int> > cand_indexes;
  std::vector<int> cand_temp;
  int tokens_size = tokens.size();
  std::string prefix = "##";
  bool do_whole_masked = true;

  for (int i = 0; i < tokens_size; i++) {
    if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
      continue;
    }
    if (do_whole_masked && (cand_indexes.size() > 0) &&
        (tokens[i].rfind(prefix, 0) == 0)) {
      cand_temp.emplace_back(i);
    } else {
      if (cand_temp.size() > 0) {
        cand_indexes.emplace_back(cand_temp);
      }
      cand_temp.clear();
      cand_temp.emplace_back(i);
    }
  }
  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
  std::shuffle(cand_indexes.begin(), cand_indexes.end(),
               std::default_random_engine(seed));
  // for (auto i : cand_indexes) {
  //   for (auto j : i) {
  //     std::cout << tokens[j] << " ";
  //   }
  //   std::cout << std::endl;
  // }
  // for (auto i : output_tokens) {
  //   std::cout << i;
  // }
  // std::cout << std::endl;

  int num_to_predict = std::min(max_predictions_per_seq,
                                std::max(1, int(tokens_size * masked_lm_prob)));
  // std::cout << num_to_predict << std::endl;

  std::set<int> covered_indexes;
  std::vector<int> masked_lm_output(tokens_size, -1);
  int vocab_words_len = vocab_words.size();
  std::default_random_engine e(seed);
  std::uniform_real_distribution<double> u1(0.0, 1.0);
  std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
  int mask_cnt = 0;
  std::vector<std::string> output_tokens;
  output_tokens = original_tokens;

  for (auto index_set : cand_indexes) {
    if (mask_cnt > num_to_predict) {
      break;
    }
    int index_set_size = index_set.size();
    if (mask_cnt + index_set_size > num_to_predict) {
      continue;
    }
    bool is_any_index_covered = false;
    for (auto index : index_set) {
      if (covered_indexes.find(index) != covered_indexes.end()) {
        is_any_index_covered = true;
        break;
      }
    }
    if (is_any_index_covered) {
      continue;
    }
    for (auto index : index_set) {
      covered_indexes.insert(index);
      std::string masked_token;
      if (u1(e) < 0.8) {
        masked_token = "[MASK]";
      } else {
        if (u1(e) < 0.5) {
          masked_token = output_tokens[index];
        } else {
          int random_index = u2(e);
          masked_token = vocab_words[random_index];
        }
      }
      // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
      masked_lm_output[index] = vocab[output_tokens[index]];
      output_tokens[index] = masked_token;
      mask_cnt++;
    }
  }

  // for (auto p : masked_lms) {
  //   masked_lm_output[p.index] = vocab[p.label];
  // }
  return std::make_tuple(output_tokens, masked_lm_output);
}

PYBIND11_MODULE(mask, m) {
  m.def("create_whole_masked_lm_predictions",
        &create_whole_masked_lm_predictions);
  m.def("get_new_segment", &get_new_segment);
}
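For reference, this is a sketch (not part of the commit) of how the compiled extension is consumed from Python, mirroring the `c++` backend branch of `PreTrainingDataset` above. The tiny vocabulary is made up for the example, and the `mask` module is assumed to have been built with pybind11 (see the `make` step in the preprocessing README).

```python
# Sketch (assumed example, not repo code): calling the pybind11 module built from mask.cpp.
import mask

tokens = ["[CLS]", "我", "今", "##天", "去", "[SEP]"]          # whole-word marked tokens
original_tokens = ["[CLS]", "我", "今", "天", "去", "[SEP]"]   # same tokens without "##"
vocab = {t: i for i, t in enumerate(["[PAD]", "[MASK]", "[CLS]", "[SEP]", "我", "今", "天", "去"])}
vocab_words = list(vocab.keys())

output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
    tokens, original_tokens, vocab_words, vocab,
    20,     # max_predictions_per_seq
    0.15,   # masked_lm_prob
)
print(output_tokens)      # tokens with some whole words replaced by [MASK]
print(masked_lm_output)   # original vocab id at masked positions, -1 elsewhere
```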
@@ -1,28 +1,30 @@
+import argparse
+import functools
+import json
 import multiprocessing
 import os
 import re
-from tqdm import tqdm
-from typing import List
-import json
 import time
-import argparse
-import functools
+from typing import List
+
+from tqdm import tqdm


 def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
     sent_list = []
     try:
         if flag == "zh":
             document = re.sub('(?P<quotation_mark>([。?!…](?![”’"\'])))', r'\g<quotation_mark>\n', document)
             document = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', document)
         elif flag == "en":
             document = re.sub('(?P<quotation_mark>([.?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)
-            document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n', document)  # Special quotation marks
+            document = re.sub('(?P<quotation_mark>([?!.]["\']))', r'\g<quotation_mark>\n',
+                              document)  # Special quotation marks
         else:
             document = re.sub('(?P<quotation_mark>([。?!….?!](?![”’"\'])))', r'\g<quotation_mark>\n', document)

             document = re.sub('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))', r'\g<quotation_mark>\n',
                               document)  # Special quotation marks

         sent_list_ori = document.splitlines()
         for sent in sent_list_ori:

@@ -43,17 +45,15 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s
     return sent_list


-def get_sent(output_path,
-             input_path,
-             fin_list=[], host=-1, seq_len=512) -> None:
+def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None:

     workers = 32

     if input_path[-1] == '/':
         input_path = input_path[:-1]

     cur_path = os.path.join(output_path, str(host) + '.txt')
-    new_split_sentence = functools.partial(split_sentence, limit=seq_len-2)
+    new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)
     with open(cur_path, 'w', encoding='utf-8') as f:
         for fi, fin_path in enumerate(fin_list):
             if not os.path.exists(os.path.join(input_path, fin_path[0])):

@@ -62,7 +62,7 @@ def get_sent(output_path,
                 continue

             print("Processing ", fin_path[0], " ", fi)

             with open(os.path.join(input_path, fin_path[0]), 'r') as fin:
                 f_data = [l['content'] for l in json.load(fin)]

@@ -99,17 +99,17 @@ def getFileSize(filepath, shard):
             real_shard.append(temp)
             accu_size = 0
             temp = []

     if len(temp) > 0:
         real_shard.append(temp)

     return real_shard


 def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
     import socket
     host = int(socket.gethostname().split(server_name)[-1])

     fin_list = real_shard[server_num * base + host - 1]
     print(fin_list)
     print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')

@@ -126,28 +126,24 @@ if __name__ == '__main__':
     parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
     args = parser.parse_args()

     server_num = args.server_num
     seq_len = args.seq_len
     shard = args.shard
     input_path = args.input_path
     output_path = args.output_path

     real_shard = getFileSize(input_path, shard)

     start = time.time()
     for index, shard in enumerate(real_shard):
-        get_sent(output_path,
-                 input_path,
-                 fin_list=shard,
-                 host=index,
-                 seq_len=seq_len)
+        get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len)
     print(f'cost {str(time.time() - start)}')

     # if you have multiple servers, you can use the code below or modify it to use openmpi

     # for i in range(len(real_shard) // server_num + 1):
     #     fin_list, host = get_start_end(real_shard, i)

     #     start = time.time()
     #     get_sent(output_path,
     #              input_path,
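A minimal usage sketch of `split_sentence` (not part of the commit); it assumes the script above is importable as `sentence_split` and reuses the sample sentences from the README.

```python
# Sketch (assumed example): sentence splitting as done before sharding.
from sentence_split import split_sentence   # assumes this script is on PYTHONPATH

doc = "我今天去打篮球。下周请假。"
for sent in split_sentence(doc, flag="zh", limit=510):
    print(sent)
# Each sentence is printed on its own line, none longer than `limit` characters.
```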
@ -1,19 +1,19 @@
|
|||||||
import time
|
|
||||||
import os
|
|
||||||
import psutil
|
|
||||||
import h5py
|
|
||||||
import socket
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from tqdm import tqdm
|
import os
|
||||||
|
import socket
|
||||||
|
import time
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import psutil
|
||||||
from get_mask import PreTrainingDataset
|
from get_mask import PreTrainingDataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def get_raw_instance(document, max_sequence_length=512):
|
def get_raw_instance(document, max_sequence_length=512):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Get the initial training instances, split the whole segment into multiple parts according to the max_sequence_length, and return as multiple processed instances.
|
Get the initial training instances, split the whole segment into multiple parts according to the max_sequence_length, and return as multiple processed instances.
|
||||||
:param document: document
|
:param document: document
|
||||||
@ -26,24 +26,24 @@ def get_raw_instance(document, max_sequence_length=512):
|
|||||||
sizes = [len(seq) for seq in document]
|
sizes = [len(seq) for seq in document]
|
||||||
|
|
||||||
result_list = []
|
result_list = []
|
||||||
curr_seq = []
|
curr_seq = []
|
||||||
sz_idx = 0
|
sz_idx = 0
|
||||||
while sz_idx < len(sizes):
|
while sz_idx < len(sizes):
|
||||||
|
|
||||||
if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0:
|
if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0:
|
||||||
curr_seq += document[sz_idx]
|
curr_seq += document[sz_idx]
|
||||||
sz_idx += 1
|
sz_idx += 1
|
||||||
elif sizes[sz_idx] >= max_sequence_length_allowed:
|
elif sizes[sz_idx] >= max_sequence_length_allowed:
|
||||||
if len(curr_seq) > 0:
|
if len(curr_seq) > 0:
|
||||||
result_list.append(curr_seq)
|
result_list.append(curr_seq)
|
||||||
curr_seq = []
|
curr_seq = []
|
||||||
result_list.append(document[sz_idx][ : max_sequence_length_allowed])
|
result_list.append(document[sz_idx][:max_sequence_length_allowed])
|
||||||
sz_idx += 1
|
sz_idx += 1
|
||||||
else:
|
else:
|
||||||
result_list.append(curr_seq)
|
result_list.append(curr_seq)
|
||||||
curr_seq = []
|
curr_seq = []
|
||||||
|
|
||||||
if len(curr_seq) > max_sequence_length_allowed / 2: # /2
|
if len(curr_seq) > max_sequence_length_allowed / 2: # /2
|
||||||
result_list.append(curr_seq)
|
result_list.append(curr_seq)
|
||||||
|
|
||||||
# num_instance=int(len(big_list)/max_sequence_length_allowed)+1
|
# num_instance=int(len(big_list)/max_sequence_length_allowed)+1
|
||||||
@ -70,8 +70,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
|
|||||||
# document = line
|
# document = line
|
||||||
# if len(document.split("<sep>")) <= 3:
|
# if len(document.split("<sep>")) <= 3:
|
||||||
# continue
|
# continue
|
||||||
if len(line
|
if len(line) > 0 and line[:2] == "]]": # This is end of document
|
||||||
) > 0 and line[:2] == "]]": # This is end of document
|
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
document = []
|
document = []
|
||||||
elif len(line) >= 2:
|
elif len(line) >= 2:
|
||||||
@ -84,8 +83,8 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
|
|||||||
# print(len(documents))
|
# print(len(documents))
|
||||||
# print(len(documents[0]))
|
# print(len(documents[0]))
|
||||||
# print(documents[0][0:10])
|
# print(documents[0][0:10])
|
||||||
from typing import List
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
from typing import List
|
||||||
|
|
||||||
ans = []
|
ans = []
|
||||||
for docs in tqdm(documents):
|
for docs in tqdm(documents):
|
||||||
@ -98,7 +97,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
|
|||||||
raw_ins = get_raw_instance(a)
|
raw_ins = get_raw_instance(a)
|
||||||
instances.extend(raw_ins)
|
instances.extend(raw_ins)
|
||||||
del ans
|
del ans
|
||||||
|
|
||||||
print('len instance', len(instances))
|
print('len instance', len(instances))
|
||||||
|
|
||||||
sen_num = len(instances)
|
sen_num = len(instances)
|
||||||
@ -116,21 +115,15 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host):
|
|||||||
masked_lm_output[index] = mask_dict[3]
|
masked_lm_output[index] = mask_dict[3]
|
||||||
|
|
||||||
with h5py.File(f'/output/{host}.h5', 'w') as hf:
|
with h5py.File(f'/output/{host}.h5', 'w') as hf:
|
||||||
hf.create_dataset("input_ids", data=input_ids)
|
hf.create_dataset("input_ids", data=input_ids)
|
||||||
hf.create_dataset("input_mask", data=input_ids)
|
hf.create_dataset("input_mask", data=input_ids)
|
||||||
hf.create_dataset("segment_ids", data=segment_ids)
|
hf.create_dataset("segment_ids", data=segment_ids)
|
||||||
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
|
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
|
||||||
|
|
||||||
del instances
|
del instances
|
||||||
|
|
||||||
|
|
||||||
def split_numpy_chunk_pool(input_path,
|
def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):
|
||||||
output_path,
|
|
||||||
pretrain_data,
|
|
||||||
worker,
|
|
||||||
dupe_factor,
|
|
||||||
seq_len,
|
|
||||||
file_name):
|
|
||||||
|
|
||||||
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
|
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
|
||||||
print(f'{file_name}.h5 exists')
|
print(f'{file_name}.h5 exists')
|
||||||
@ -144,8 +137,7 @@ def split_numpy_chunk_pool(input_path,
|
|||||||
document = []
|
document = []
|
||||||
for i, line in enumerate(tqdm(fd)):
|
for i, line in enumerate(tqdm(fd)):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if len(line
|
if len(line) > 0 and line[:2] == "]]": # This is end of document
|
||||||
) > 0 and line[:2] == "]]": # This is end of document
|
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
document = []
|
document = []
|
||||||
elif len(line) >= 2:
|
elif len(line) >= 2:
|
||||||
@ -153,7 +145,7 @@ def split_numpy_chunk_pool(input_path,
|
|||||||
if len(document) > 0:
|
if len(document) > 0:
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
print(f'read_file cost {time.time() - s}, length is {len(documents)}')
|
print(f'read_file cost {time.time() - s}, length is {len(documents)}')
|
||||||
|
|
||||||
ans = []
|
ans = []
|
||||||
s = time.time()
|
s = time.time()
|
||||||
pool = multiprocessing.Pool(worker)
|
pool = multiprocessing.Pool(worker)
|
||||||
@ -169,7 +161,7 @@ def split_numpy_chunk_pool(input_path,
|
|||||||
raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
|
raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
|
||||||
instances.extend(raw_ins)
|
instances.extend(raw_ins)
|
||||||
del ans
|
del ans
|
||||||
|
|
||||||
print('len instance', len(instances))
|
print('len instance', len(instances))
|
||||||
|
|
||||||
new_instances = []
|
new_instances = []
|
||||||
@ -199,10 +191,10 @@ def split_numpy_chunk_pool(input_path,
|
|||||||
print((time.time() - s) / 60)
|
print((time.time() - s) / 60)
|
||||||
|
|
||||||
with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
|
with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
|
||||||
hf.create_dataset("input_ids", data=input_ids)
|
hf.create_dataset("input_ids", data=input_ids)
|
||||||
hf.create_dataset("input_mask", data=input_mask)
|
hf.create_dataset("input_mask", data=input_mask)
|
||||||
hf.create_dataset("segment_ids", data=segment_ids)
|
hf.create_dataset("segment_ids", data=segment_ids)
|
||||||
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
|
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
|
||||||
|
|
||||||
del instances
|
del instances
|
||||||
|
|
||||||
@ -212,22 +204,31 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer')
|
parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer')
|
||||||
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
|
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
|
||||||
parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='number of shards, e.g., 10, 50, or 100')
|
parser.add_argument('--max_predictions_per_seq',
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help='number of shards, e.g., 10, 50, or 100')
|
||||||
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
|
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
|
||||||
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
|
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
|
||||||
parser.add_argument('--backend', type=str, default='python', help='backend of mask token, python, c++, numpy respectively')
|
parser.add_argument('--backend',
|
||||||
parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document')
|
type=str,
|
||||||
|
default='python',
|
||||||
|
help='backend of mask token, python, c++, numpy respectively')
|
||||||
|
parser.add_argument(
|
||||||
|
'--dupe_factor',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='specifies how many times the preprocessor repeats to create the input from the same article/document')
|
||||||
parser.add_argument('--worker', type=int, default=32, help='number of process')
|
parser.add_argument('--worker', type=int, default=32, help='number of process')
|
||||||
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
|
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
|
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
|
||||||
pretrain_data = PreTrainingDataset(tokenizer,
|
pretrain_data = PreTrainingDataset(tokenizer,
|
||||||
args.seq_len,
|
args.seq_len,
|
||||||
args.backend,
|
args.backend,
|
||||||
max_predictions_per_seq=args.max_predictions_per_seq)
|
max_predictions_per_seq=args.max_predictions_per_seq)
|
||||||
|
|
||||||
|
|
||||||
data_len = len(os.listdir(args.input_path))
|
data_len = len(os.listdir(args.input_path))
|
||||||
|
|
||||||
for i in range(data_len):
|
for i in range(data_len):
|
||||||
@ -235,15 +236,10 @@ if __name__ == '__main__':
|
|||||||
if os.path.exists(input_path):
|
if os.path.exists(input_path):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
print(f'process {input_path}')
|
print(f'process {input_path}')
|
||||||
split_numpy_chunk_pool(input_path,
|
split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor,
|
||||||
args.output_path,
|
args.seq_len, i)
|
||||||
pretrain_data,
|
|
||||||
args.worker,
|
|
||||||
args.dupe_factor,
|
|
||||||
args.seq_len,
|
|
||||||
i)
|
|
||||||
end_ = time.time()
|
end_ = time.time()
|
||||||
print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
|
print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
|
||||||
print(f'has cost {(end_ - start) / 60}')
|
print(f'has cost {(end_ - start) / 60}')
|
||||||
print('-' * 100)
|
print('-' * 100)
|
||||||
print('')
|
print('')
|
||||||
@@ -257,9 +253,9 @@ if __name__ == '__main__':
     # if os.path.exists(input_path):
     #     start = time.time()
     #     print(f'I am server {host}, process {input_path}')
     #     split_numpy_chunk_pool(input_path,
     #                             args.output_path,
     #                             pretrain_data,
     #                             args.worker,
     #                             args.dupe_factor,
     #                             args.seq_len,
@@ -269,5 +265,3 @@ if __name__ == '__main__':
     #     print(f'has cost {(end_ - start) / 60}')
     #     print('-' * 100)
     #     print('')

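A note for readers of the preprocessing hunks above: `--dupe_factor` controls how many masked copies of each article/document the preprocessor emits. The sketch below only illustrates that semantics; the helper names and the masking probability are hypothetical and not part of the repository's masking backend.

```python
# Illustrative sketch of dupe_factor: each duplication round re-masks the same
# document, so N rounds yield N differently-masked training samples.
import random


def mask_document(tokens, mask_prob=0.15, mask_token='<mask>'):
    # Hypothetical stand-in for the repo's python/c++/numpy masking backends.
    return [mask_token if random.random() < mask_prob else t for t in tokens]


def repeat_masking(tokens, dupe_factor=1):
    return [mask_document(tokens) for _ in range(dupe_factor)]
```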
@@ -19,6 +19,5 @@ bash run_pretrain.sh
 bash run_pretrain_resume.sh
 ```
 * `--resume_train`: whether to resume training
 * `--load_pretrain_model`: absolute path which contains model checkpoint
 * `--load_optimizer_lr`: absolute path which contains optimizer checkpoint
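The three resume flags listed above imply a model checkpoint plus a separate optimizer/LR state file. The snippet below is only a hedged illustration of that split; the actual checkpoint keys are whatever the training script saved, and the names used here follow the help text in `arguments.py` but are otherwise assumptions.

```python
# Hedged illustration of resuming from --load_pretrain_model / --load_optimizer_lr.
import torch


def resume_state(model, optimizer, model_ckpt_path, optim_ckpt_path):
    # Model weights and optimizer/schedule state are stored in separate files.
    model.load_state_dict(torch.load(model_ckpt_path, map_location='cpu'))
    extra = torch.load(optim_ckpt_path, map_location='cpu')
    optimizer.load_state_dict(extra['optimizer'])
    return extra.get('epoch', 0), extra.get('global_step', 0)
```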
examples/community/roberta/pretraining/arguments.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+from numpy import require
+
+import colossalai
+
+__all__ = ['parse_args']
+
+
+def parse_args():
+    parser = colossalai.get_default_parser()
+
+    parser.add_argument(
+        "--distplan",
+        type=str,
+        default='CAI_Gemini',
+        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
+    )
+    parser.add_argument(
+        "--tp_degree",
+        type=int,
+        default=1,
+        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
+    )
+    parser.add_argument(
+        "--placement",
+        type=str,
+        default='cpu',
+        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
+    )
+    parser.add_argument(
+        "--shardinit",
+        action='store_true',
+        help=
+        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
+    )
+
+    parser.add_argument('--lr', type=float, required=True, help='initial learning rate')
+    parser.add_argument('--epoch', type=int, required=True, help='number of epoch')
+    parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus")
+    parser.add_argument('--eval_data_path_prefix',
+                        type=str,
+                        required=True,
+                        help='location of the evaluation data corpus')
+    parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer')
+    parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length')
+    parser.add_argument('--refresh_bucket_size',
+                        type=int,
+                        default=1,
+                        help="This param makes sure that a certain task is repeated for this time steps to \
+                        optimise on the back propogation speed with APEX's DistributedDataParallel")
+    parser.add_argument("--max_predictions_per_seq",
+                        "--max_pred",
+                        default=80,
+                        type=int,
+                        help="The maximum number of masked tokens in a sequence to be predicted.")
+    parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation_steps")
+    parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size")
+    parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size")
+    parser.add_argument("--num_workers", default=8, type=int, help="")
+    parser.add_argument("--async_worker", action='store_true', help="")
+    parser.add_argument("--bert_config", required=True, type=str, help="location of config.json")
+    parser.add_argument("--wandb", action='store_true', help="use wandb to watch model")
+    parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name")
+    parser.add_argument("--log_interval", default=100, type=int, help="report interval")
+    parser.add_argument("--log_path", type=str, required=True, help="log file which records train step")
+    parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file")
+    parser.add_argument("--colossal_config",
+                        type=str,
+                        required=True,
+                        help="colossal config, which contains zero config and so on")
+    parser.add_argument("--ckpt_path",
+                        type=str,
+                        required=True,
+                        help="location of saving checkpoint, which contains model and optimizer")
+    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
+    parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug")
+    parser.add_argument('--load_pretrain_model', default='', type=str, help="location of model's checkpoin")
+    parser.add_argument(
+        '--load_optimizer_lr',
+        default='',
+        type=str,
+        help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
+    parser.add_argument('--resume_train', action='store_true', help="whether resume training from a early checkpoint")
+    parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta")
+    parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing")
+
+    args = parser.parse_args()
+    return args
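For orientation, the sketch below shows one plausible way the new `parse_args()` is consumed together with the Colossal-AI launcher. It is a minimal sketch under assumptions (an empty launch config, a torchrun-style launch), not the repository's actual `run_pretrain.py`.

```python
# Hedged sketch: parse_args() also carries the launcher arguments injected by
# colossalai.get_default_parser(), so it pairs naturally with launch_from_torch().
import colossalai
from arguments import parse_args


def main():
    args = parse_args()
    colossalai.launch_from_torch(config={})    # assumed empty config for illustration
    print(f'distplan={args.distplan}, tp_degree={args.tp_degree}, placement={args.placement}')


if __name__ == '__main__':
    main()
```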
@@ -1,4 +1,5 @@
 class BertDatasetProviderInterface:

     def get_shard(self, index, shuffle=True):
         raise NotImplementedError

@@ -1,9 +1,11 @@
-import os
 import math
+import os

 import torch
+from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
 from tqdm import tqdm
-from utils.global_vars import get_timers, get_tensorboard_writer
-from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
+from utils.global_vars import get_tensorboard_writer, get_timers


 def evaluate(model, args, logger, global_step, criterion):
     evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True)
@@ -20,16 +22,19 @@ def evaluate(model, args, logger, global_step, criterion):

     for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))):

         timers('eval_shard_time').start()

         dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard)
         # evaluate_dataset_provider.prefetch_shard(shard + 1)
         if torch.distributed.get_rank() == 0:
-            iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1)
+            iterator_data = tqdm(enumerate(dataset_iterator),
+                                 total=(total_length // args.eval_micro_batch_size_per_gpu // world_size),
+                                 colour='MAGENTA',
+                                 smoothing=1)
         else:
             iterator_data = enumerate(dataset_iterator)

         for step, batch_data in iterator_data:    #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1):

             # batch_data = pretrain_dataset_provider.get_batch(batch_index)
             eval_step += 1
@@ -40,8 +45,8 @@ def evaluate(model, args, logger, global_step, criterion):
             # nsp_label = batch_data[5].cuda()

             output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

-            loss = criterion(output.logits, mlm_label)#prediction_scores
+            loss = criterion(output.logits, mlm_label)    #prediction_scores
             evaluate_dataset_provider.prefetch_batch()

             eval_loss += loss.float().item()
@@ -54,10 +59,10 @@ def evaluate(model, args, logger, global_step, criterion):
         if args.wandb and torch.distributed.get_rank() == 0:
             tensorboard_log = get_tensorboard_writer()
             tensorboard_log.log_eval({
                 'loss': cur_loss,
                 'ppl': ppl,
                 'mins_batch': elapsed_time_per_iteration
             }, global_step)

         eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
                        f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}'
@@ -68,4 +73,4 @@ def evaluate(model, args, logger, global_step, criterion):

     evaluate_dataset_provider.release_shard()
     model.train()
     return cur_loss
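The evaluation loop above logs both `loss` and `ppl`. For readers unfamiliar with the relationship, the snippet below shows the conventional conversion, assuming `cur_loss` is the mean masked-LM cross-entropy (which is how the hunk reads, though the exact averaging lives outside the lines shown).

```python
# Conventional masked-LM perplexity from a mean cross-entropy loss; the precise
# reduction used in evaluation.py is not part of this hunk, so this is a sketch.
import math


def perplexity(mean_mlm_loss: float) -> float:
    return math.exp(mean_mlm_loss)
```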
@@ -13,5 +13,5 @@ class LossForPretraining(torch.nn.Module):
     def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None):
         masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
         # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
-        total_loss = masked_lm_loss #+ next_sentence_loss
+        total_loss = masked_lm_loss    #+ next_sentence_loss
         return total_loss
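For context on the hunk above, a minimal reconstruction of what `LossForPretraining` plausibly looks like is given below. Only `forward()` appears in the diff, so the constructor (vocabulary size, ignore index) is an assumption.

```python
# Assumed shape of LossForPretraining, consistent with the forward() shown above.
# Only the masked-LM term is active; the NSP term stays disabled as in the diff.
import torch


class LossForPretraining(torch.nn.Module):

    def __init__(self, vocab_size: int, ignore_index: int = -1):
        super().__init__()
        self.vocab_size = vocab_size
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None):
        masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
        return masked_lm_loss
```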
@@ -15,7 +15,6 @@
 # limitations under the License.
 """PyTorch BERT model."""

 import math
 import os
 import warnings
@@ -27,7 +26,6 @@ import torch.utils.checkpoint
 from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -41,8 +39,9 @@ from transformers.modeling_outputs import (
     TokenClassifierOutput,
 )
 from transformers.modeling_utils import PreTrainedModel
+from transformers.models.bert.configuration_bert import BertConfig
 from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from transformers.utils import (
     ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -50,8 +49,6 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from transformers.models.bert.configuration_bert import BertConfig


 logger = logging.get_logger(__name__)

@@ -62,8 +59,7 @@ _TOKENIZER_FOR_DOC = "BertTokenizer"
 # TokenClassification docstring
 _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
 _TOKEN_CLASS_EXPECTED_OUTPUT = (
-    "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
-)
+    "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ")
 _TOKEN_CLASS_EXPECTED_LOSS = 0.01

 # QuestionAnswering docstring
@@ -78,7 +74,6 @@ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-pol
 _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
 _SEQ_CLASS_EXPECTED_LOSS = 0.01


 BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "bert-base-uncased",
     "bert-large-uncased",
@@ -114,10 +109,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
         import numpy as np
         import tensorflow as tf
     except ImportError:
-        logger.error(
-            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
-            "https://www.tensorflow.org/install/ for installation instructions."
-        )
+        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+                     "https://www.tensorflow.org/install/ for installation instructions.")
         raise
     tf_path = os.path.abspath(tf_checkpoint_path)
     logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
@@ -135,10 +128,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if any(
-            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
-            for n in name
-        ):
+        if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+               for n in name):
             logger.info(f"Skipping {'/'.join(name)}")
             continue
         pointer = model
@@ -218,7 +209,7 @@ class BertEmbeddings(nn.Module):
             seq_length = input_shape[1]

         if position_ids is None:
-            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+            position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length]

         # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
         # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@@ -245,13 +236,12 @@ class BertEmbeddings(nn.Module):


 class BertSelfAttention(nn.Module):

     def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
-            raise ValueError(
-                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
-                f"heads ({config.num_attention_heads})"
-            )
+            raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                             f"heads ({config.num_attention_heads})")

         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -262,9 +252,7 @@ class BertSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = position_embedding_type or getattr(
-            config, "position_embedding_type", "absolute"
-        )
+        self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -332,7 +320,7 @@ class BertSelfAttention(nn.Module):
             position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)    # fp16 compatibility

             if self.position_embedding_type == "relative_key":
                 relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
@@ -372,6 +360,7 @@ class BertSelfAttention(nn.Module):


 class BertSelfOutput(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -386,6 +375,7 @@ class BertSelfOutput(nn.Module):


 class BertAttention(nn.Module):

     def __init__(self, config, position_embedding_type=None):
         super().__init__()
         self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
@@ -395,9 +385,8 @@ class BertAttention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        heads, index = find_pruneable_heads_and_indices(
-            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
-        )
+        heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads,
+                                                        self.self.attention_head_size, self.pruned_heads)

         # Prune linear layers
         self.self.query = prune_linear_layer(self.self.query, index)
@@ -430,11 +419,12 @@ class BertAttention(nn.Module):
             output_attentions,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]    # add attentions if we output them
         return outputs


 class BertIntermediate(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -450,6 +440,7 @@ class BertIntermediate(nn.Module):


 class BertOutput(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -464,6 +455,7 @@ class BertOutput(nn.Module):


 class BertLayer(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -504,15 +496,14 @@ class BertLayer(nn.Module):
             outputs = self_attention_outputs[1:-1]
             present_key_value = self_attention_outputs[-1]
         else:
             outputs = self_attention_outputs[1:]    # add self attentions if we output attention weights

         cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:
             if not hasattr(self, "crossattention"):
                 raise ValueError(
                     f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
-                    " by setting `config.add_cross_attention=True`"
-                )
+                    " by setting `config.add_cross_attention=True`")

             # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
             cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
@@ -526,15 +517,14 @@ class BertLayer(nn.Module):
                 output_attentions,
             )
             attention_output = cross_attention_outputs[0]
             outputs = outputs + cross_attention_outputs[1:-1]    # add cross attentions if we output attention weights

             # add cross-attn cache to positions 3,4 of present_key_value tuple
             cross_attn_present_key_value = cross_attention_outputs[-1]
             present_key_value = present_key_value + cross_attn_present_key_value

-        layer_output = apply_chunking_to_forward(
-            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
-        )
+        layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward,
+                                                 self.seq_len_dim, attention_output)
         outputs = (layer_output,) + outputs

         # if decoder, return the attn key/values as the last output
@@ -550,6 +540,7 @@ class BertLayer(nn.Module):


 class BertEncoder(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -585,11 +576,11 @@ class BertEncoder(nn.Module):

                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                     use_cache = False

                 def create_custom_forward(module):

                     def custom_forward(*inputs):
                         return module(*inputs, past_key_value, output_attentions)

@@ -626,17 +617,13 @@ class BertEncoder(nn.Module):
             all_hidden_states = all_hidden_states + (hidden_states,)

         if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    next_decoder_cache,
-                    all_hidden_states,
-                    all_self_attentions,
-                    all_cross_attentions,
-                ]
-                if v is not None
-            )
+            return tuple(v for v in [
+                hidden_states,
+                next_decoder_cache,
+                all_hidden_states,
+                all_self_attentions,
+                all_cross_attentions,
+            ] if v is not None)
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,
@@ -647,6 +634,7 @@ class BertEncoder(nn.Module):


 class BertPooler(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -662,6 +650,7 @@ class BertPooler(nn.Module):


 class BertPredictionHeadTransform(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -679,6 +668,7 @@ class BertPredictionHeadTransform(nn.Module):


 class BertLMPredictionHead(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.transform = BertPredictionHeadTransform(config)
@@ -699,6 +689,7 @@ class BertLMPredictionHead(nn.Module):


 class BertOnlyMLMHead(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.predictions = BertLMPredictionHead(config)
@@ -709,6 +700,7 @@ class BertOnlyMLMHead(nn.Module):


 class BertOnlyNSPHead(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.seq_relationship = nn.Linear(config.hidden_size, 2)
@@ -719,6 +711,7 @@ class BertOnlyNSPHead(nn.Module):


 class BertPreTrainingHeads(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.predictions = BertLMPredictionHead(config)
@@ -950,9 +943,8 @@ class BertModel(BertPreTrainedModel):
             `past_key_values`).
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if self.config.is_decoder:
@@ -1051,6 +1043,7 @@ class BertModel(BertPreTrainedModel):
     BERT_START_DOCSTRING,
 )
 class BertForPreTraining(BertPreTrainedModel):

     def __init__(self, config):
         super().__init__(config)

@@ -1151,9 +1144,8 @@ class BertForPreTraining(BertPreTrainedModel):
         )


-@add_start_docstrings(
-    """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
-)
+@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""",
+                      BERT_START_DOCSTRING)
 class BertLMHeadModel(BertPreTrainedModel):

     _keys_to_ignore_on_load_unexpected = [r"pooler"]
@@ -1298,10 +1290,8 @@ class BertForMaskedLM(BertPreTrainedModel):
         super().__init__(config)

         if config.is_decoder:
-            logger.warning(
-                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
-                "bi-directional self-attention."
-            )
+            logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
+                           "bi-directional self-attention.")

         self.bert = BertModel(config, add_pooling_layer=False)
         self.cls = BertOnlyMLMHead(config)
@@ -1367,7 +1357,7 @@ class BertForMaskedLM(BertPreTrainedModel):

         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()    # -100 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

         if not return_dict:
@@ -1390,9 +1380,10 @@ class BertForMaskedLM(BertPreTrainedModel):
             raise ValueError("The PAD token should be defined for generation")

         attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
-        dummy_token = torch.full(
-            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
-        )
+        dummy_token = torch.full((effective_batch_size, 1),
+                                 self.config.pad_token_id,
+                                 dtype=torch.long,
+                                 device=input_ids.device)
         input_ids = torch.cat([input_ids, dummy_token], dim=1)

         return {"input_ids": input_ids, "attention_mask": attention_mask}
@@ -1403,6 +1394,7 @@ class BertForMaskedLM(BertPreTrainedModel):
     BERT_START_DOCSTRING,
 )
 class BertForNextSentencePrediction(BertPreTrainedModel):

     def __init__(self, config):
         super().__init__(config)

@@ -1508,15 +1500,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     BERT_START_DOCSTRING,
 )
 class BertForSequenceClassification(BertPreTrainedModel):

     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config

         self.bert = BertModel(config)
-        classifier_dropout = (
-            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
-        )
+        classifier_dropout = (config.classifier_dropout
+                              if config.classifier_dropout is not None else config.hidden_dropout_prob)
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

@@ -1612,13 +1604,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
     BERT_START_DOCSTRING,
 )
 class BertForMultipleChoice(BertPreTrainedModel):

     def __init__(self, config):
         super().__init__(config)

         self.bert = BertModel(config)
-        classifier_dropout = (
-            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
-        )
+        classifier_dropout = (config.classifier_dropout
+                              if config.classifier_dropout is not None else config.hidden_dropout_prob)
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, 1)

@@ -1658,11 +1650,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
-        inputs_embeds = (
-            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
-            if inputs_embeds is not None
-            else None
-        )
+        inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+                         if inputs_embeds is not None else None)

         outputs = self.bert(
             input_ids,
@@ -1715,9 +1704,8 @@ class BertForTokenClassification(BertPreTrainedModel):
         self.num_labels = config.num_labels

         self.bert = BertModel(config, add_pooling_layer=False)
-        classifier_dropout = (
-            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
-        )
+        classifier_dropout = (config.classifier_dropout
+                              if config.classifier_dropout is not None else config.hidden_dropout_prob)
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

@@ -23,7 +23,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from transformers import FillMaskPipeline, T5ForConditionalGeneration, T5Tokenizer
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutput,
@@ -34,10 +34,14 @@ from transformers.modeling_outputs import (
     TokenClassifierOutput,
 )
 from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import softmax_backward_data
-from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config
-from transformers import T5Tokenizer, T5ForConditionalGeneration, FillMaskPipeline
+from transformers.pytorch_utils import softmax_backward_data
+from transformers.utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)

 logger = logging.get_logger(__name__)

@@ -55,6 +59,7 @@ DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [

 # Copied from transformers.models.deberta.modeling_deberta.ContextPooler
 class ContextPooler(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
@@ -133,15 +138,15 @@ class XSoftmax(torch.autograd.Function):
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
             to_i=sym_help.cast_pytorch_to_onnx["Byte"],
         )
-        output = masked_fill(
-            g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
-        )
+        output = masked_fill(g, self, r_mask,
+                             g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)))
         output = softmax(g, output, dim)
         return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))


 # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
 class DropoutContext(object):

     def __init__(self):
         self.dropout = 0
         self.mask = None
@@ -244,6 +249,7 @@ class StableDropout(nn.Module):

 # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
 class DebertaV2SelfOutput(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -259,6 +265,7 @@ class DebertaV2SelfOutput(nn.Module):

 # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
 class DebertaV2Attention(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.self = DisentangledSelfAttention(config)
@@ -296,6 +303,7 @@ class DebertaV2Attention(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
 class DebertaV2Intermediate(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -312,6 +320,7 @@ class DebertaV2Intermediate(nn.Module):

 # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
 class DebertaV2Output(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -328,6 +337,7 @@ class DebertaV2Output(nn.Module):

 # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
 class DebertaV2Layer(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.attention = DebertaV2Attention(config)
@@ -362,14 +372,17 @@ class DebertaV2Layer(nn.Module):


 class ConvLayer(nn.Module):

     def __init__(self, config):
         super().__init__()
         kernel_size = getattr(config, "conv_kernel_size", 3)
         groups = getattr(config, "conv_groups", 1)
         self.conv_act = getattr(config, "conv_act", "tanh")
-        self.conv = nn.Conv1d(
-            config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
-        )
+        self.conv = nn.Conv1d(config.hidden_size,
+                              config.hidden_size,
+                              kernel_size,
+                              padding=(kernel_size - 1) // 2,
+                              groups=groups)
         self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
         self.dropout = StableDropout(config.hidden_dropout_prob)
         self.config = config
@@ -452,9 +465,10 @@ class DebertaV2Encoder(nn.Module):
     def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
         if self.relative_attention and relative_pos is None:
             q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
-            relative_pos = build_relative_position(
-                q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
-            )
+            relative_pos = build_relative_position(q,
+                                                   hidden_states.size(-2),
+                                                   bucket_size=self.position_buckets,
+                                                   max_position=self.max_relative_positions)
         return relative_pos

     def forward(
@@ -491,6 +505,7 @@ class DebertaV2Encoder(nn.Module):
             if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):

                     def custom_forward(*inputs):
                         return module(*inputs, output_attentions)

@@ -535,9 +550,9 @@ class DebertaV2Encoder(nn.Module):

         if not return_dict:
             return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
-        )
+        return BaseModelOutput(last_hidden_state=output_states,
+                               hidden_states=all_hidden_states,
+                               attentions=all_attentions)


 def make_log_bucket_position(relative_pos, bucket_size, max_position):
@@ -610,10 +625,8 @@ class DisentangledSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0:
-            raise ValueError(
-                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
-                f"heads ({config.num_attention_heads})"
-            )
+            raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                             f"heads ({config.num_attention_heads})")
         self.num_attention_heads = config.num_attention_heads
         _attention_head_size = config.hidden_size // config.num_attention_heads
         self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
@@ -706,28 +719,22 @@ class DisentangledSelfAttention(nn.Module):
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
-            rel_att = self.disentangled_attention_bias(
-                query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
-            )
+            rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings,
+                                                       scale_factor)

         if rel_att is not None:
             attention_scores = attention_scores + rel_att
         attention_scores = attention_scores
-        attention_scores = attention_scores.view(
-            -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
-        )
+        attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2),
+                                                 attention_scores.size(-1))

         # bsz x height x length x dimension
         attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
         attention_probs = self.dropout(attention_probs)
-        context_layer = torch.bmm(
-            attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
-        )
-        context_layer = (
-            context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
-            .permute(0, 2, 1, 3)
-            .contiguous()
-        )
+        context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)),
+                                  value_layer)
+        context_layer = (context_layer.view(-1, self.num_attention_heads, context_layer.size(-2),
+                                            context_layer.size(-1)).permute(0, 2, 1, 3).contiguous())
         new_context_layer_shape = context_layer.size()[:-2] + (-1,)
         context_layer = context_layer.view(new_context_layer_shape)
         if output_attentions:
@@ -738,9 +745,10 @@ class DisentangledSelfAttention(nn.Module):
     def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
         if relative_pos is None:
             q = query_layer.size(-2)
-            relative_pos = build_relative_position(
-                q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
-            )
+            relative_pos = build_relative_position(q,
+                                                   key_layer.size(-2),
+                                                   bucket_size=self.position_buckets,
+                                                   max_position=self.max_relative_positions)
         if relative_pos.dim() == 2:
             relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
         elif relative_pos.dim() == 3:
@@ -758,25 +766,22 @@ class DisentangledSelfAttention(nn.Module):
         # rel_embeddings = rel_embeddings.unsqueeze(0)
         # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
         if self.share_att_key:
-            pos_query_layer = self.transpose_for_scores(
-                self.query_proj(rel_embeddings), self.num_attention_heads
-            ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
+            pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings),
+                                                        self.num_attention_heads).repeat(
+                                                            query_layer.size(0) // self.num_attention_heads, 1, 1)
             pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
-                query_layer.size(0) // self.num_attention_heads, 1, 1
-            )
+                query_layer.size(0) // self.num_attention_heads, 1, 1)
         else:
             if "c2p" in self.pos_att_type:
-                pos_key_layer = self.transpose_for_scores(
-                    self.pos_key_proj(rel_embeddings), self.num_attention_heads
-                ).repeat(
-                    query_layer.size(0) // self.num_attention_heads, 1, 1
-                )  # .split(self.all_head_size, dim=-1)
+                pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings),
+                                                          self.num_attention_heads).repeat(
+                                                              query_layer.size(0) // self.num_attention_heads, 1,
+                                                              1)    # .split(self.all_head_size, dim=-1)
             if "p2c" in self.pos_att_type:
-                pos_query_layer = self.transpose_for_scores(
-                    self.pos_query_proj(rel_embeddings), self.num_attention_heads
-                ).repeat(
-                    query_layer.size(0) // self.num_attention_heads, 1, 1
-                )  # .split(self.all_head_size, dim=-1)
+                pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings),
+                                                            self.num_attention_heads).repeat(
+                                                                query_layer.size(0) // self.num_attention_heads, 1,
+                                                                1)    # .split(self.all_head_size, dim=-1)

         score = 0
         # content->position
@@ -787,7 +792,9 @@ class DisentangledSelfAttention(nn.Module):
             c2p_att = torch.gather(
                 c2p_att,
                 dim=-1,
-                index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
+                index=c2p_pos.squeeze(0).expand([query_layer.size(0),
+                                                 query_layer.size(1),
+                                                 relative_pos.size(-1)]),
             )
             score += c2p_att / scale

@@ -810,7 +817,9 @@ class DisentangledSelfAttention(nn.Module):
             p2c_att = torch.gather(
                 p2c_att,
                 dim=-1,
-                index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
+                index=p2c_pos.squeeze(0).expand([query_layer.size(0),
+                                                 key_layer.size(-2),
+                                                 key_layer.size(-2)]),
             ).transpose(-1, -2)
             score += p2c_att / scale

@@ -990,6 +999,7 @@ DEBERTA_INPUTS_DOCSTRING = r"""
 )
 # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
 class DebertaV2Model(DebertaV2PreTrainedModel):

     def __init__(self, config):
         super().__init__(config)

@@ -1032,9 +1042,8 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if input_ids is not None and inputs_embeds is not None:
if input_ids is not None and inputs_embeds is not None:
|
||||||
@ -1091,7 +1100,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
|
|||||||
sequence_output = encoded_layers[-1]
|
sequence_output = encoded_layers[-1]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
|
return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2):]
|
||||||
|
|
||||||
return BaseModelOutput(
|
return BaseModelOutput(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
@ -1165,7 +1174,7 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
|
|||||||
|
|
||||||
masked_lm_loss = None
|
masked_lm_loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
@ -1182,6 +1191,7 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
|
|||||||
|
|
||||||
# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
|
# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
|
||||||
class DebertaV2PredictionHeadTransform(nn.Module):
|
class DebertaV2PredictionHeadTransform(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
@ -1200,6 +1210,7 @@ class DebertaV2PredictionHeadTransform(nn.Module):
|
|||||||
|
|
||||||
# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
|
# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
|
||||||
class DebertaV2LMPredictionHead(nn.Module):
|
class DebertaV2LMPredictionHead(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.transform = DebertaV2PredictionHeadTransform(config)
|
self.transform = DebertaV2PredictionHeadTransform(config)
|
||||||
@ -1221,6 +1232,7 @@ class DebertaV2LMPredictionHead(nn.Module):
|
|||||||
|
|
||||||
# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
|
# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
|
||||||
class DebertaV2OnlyMLMHead(nn.Module):
|
class DebertaV2OnlyMLMHead(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = DebertaV2LMPredictionHead(config)
|
self.predictions = DebertaV2LMPredictionHead(config)
|
||||||
@ -1239,6 +1251,7 @@ class DebertaV2OnlyMLMHead(nn.Module):
|
|||||||
)
|
)
|
||||||
# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2
|
# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2
|
||||||
class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
|
class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@ -1318,9 +1331,8 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
|
|||||||
label_index = (labels >= 0).nonzero()
|
label_index = (labels >= 0).nonzero()
|
||||||
labels = labels.long()
|
labels = labels.long()
|
||||||
if label_index.size(0) > 0:
|
if label_index.size(0) > 0:
|
||||||
labeled_logits = torch.gather(
|
labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0),
|
||||||
logits, 0, label_index.expand(label_index.size(0), logits.size(1))
|
logits.size(1)))
|
||||||
)
|
|
||||||
labels = torch.gather(labels, 0, label_index.view(-1))
|
labels = torch.gather(labels, 0, label_index.view(-1))
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
|
loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
|
||||||
@ -1345,9 +1357,10 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
|
|||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return SequenceClassifierOutput(
|
return SequenceClassifierOutput(loss=loss,
|
||||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
logits=logits,
|
||||||
)
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@ -1422,9 +1435,10 @@ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
|
|||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return TokenClassifierOutput(
|
return TokenClassifierOutput(loss=loss,
|
||||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
logits=logits,
|
||||||
)
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@ -1536,6 +1550,7 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
|
|||||||
DEBERTA_START_DOCSTRING,
|
DEBERTA_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
|
class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@ -1591,11 +1606,8 @@ class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
|
|||||||
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
flat_inputs_embeds = (
|
flat_inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||||
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
if inputs_embeds is not None else None)
|
||||||
if inputs_embeds is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = self.deberta(
|
outputs = self.deberta(
|
||||||
flat_input_ids,
|
flat_input_ids,
|
@@ -1,24 +1,25 @@
+import json
+import logging
 import os
 import random
-import h5py
-import logging
-import json
 import time
 from concurrent.futures import ProcessPoolExecutor

+import h5py
 import numpy as np

 import torch
 import torch.distributed as dist
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.sampler import RandomSampler
-from torch.utils.data.distributed import DistributedSampler

 from bert_dataset_provider import BertDatasetProviderInterface
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import RandomSampler

 import colossalai.utils as utils


 # Workaround because python functions are not picklable
 class WorkerInitObj(object):
+
     def __init__(self, seed):
         self.seed = seed

@@ -27,29 +28,25 @@ class WorkerInitObj(object):
         random.seed(self.seed + id)


-def create_pretraining_dataset(input_file, max_predictions_per_seq,
-                               num_workers, train_batch_size, worker_init,
+def create_pretraining_dataset(input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init,
                                data_sampler):
-    train_data = pretraining_dataset(
-        input_file=input_file, max_predictions_per_seq=max_predictions_per_seq)
+    train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq)
     train_dataloader = DataLoader(train_data,
                                   sampler=data_sampler(train_data),
                                   batch_size=train_batch_size,
                                   num_workers=num_workers,
                                   worker_init_fn=worker_init,
-                                  pin_memory=True
-                                  )
+                                  pin_memory=True)
     return train_dataloader, len(train_data)


 class pretraining_dataset(Dataset):
+
     def __init__(self, input_file, max_predictions_per_seq):
         self.input_file = input_file
         self.max_predictions_per_seq = max_predictions_per_seq
         f = h5py.File(input_file, "r")
-        keys = [
-            'input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions'
-        ]
+        keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions']
         self.inputs = [np.asarray(f[key][:]) for key in keys]
         f.close()

@@ -59,21 +56,16 @@ class pretraining_dataset(Dataset):

     def __getitem__(self, index):

-        [
-            input_ids, input_mask, segment_ids, masked_lm_labels
-        ] = [
-            torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else
-            torch.from_numpy(np.asarray(input[index].astype(np.int64)))
-            for indice, input in enumerate(self.inputs)
+        [input_ids, input_mask, segment_ids, masked_lm_labels] = [
+            torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy(
+                np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs)
         ]

-        return [
-            input_ids, input_mask,
-            segment_ids, masked_lm_labels
-        ]
+        return [input_ids, input_mask, segment_ids, masked_lm_labels]


 class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
+
     def __init__(self, args, evaluate=False):
         self.num_workers = args.num_workers
         self.max_seq_length = args.max_seq_length

@@ -85,22 +77,24 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
         else:
             self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu
         self.logger = args.logger

         self.global_rank = dist.get_rank()
         self.world_size = dist.get_world_size()

         # Initialize dataset files
         if not evaluate:
             self.dataset_files = [
-                os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) if
-                os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f
+                os.path.join(args.data_path_prefix, f)
+                for f in os.listdir(args.data_path_prefix)
+                if os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f
             ]
         else:
             self.dataset_files = [
-                os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) if
-                os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f
+                os.path.join(args.eval_data_path_prefix, f)
+                for f in os.listdir(args.eval_data_path_prefix)
+                if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f
             ]

         self.dataset_files.sort()
         # random.shuffle(self.dataset_files)
         self.num_files = len(self.dataset_files)

@@ -114,9 +108,7 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
         self.shuffle = True

         if self.global_rank == 0:
-            self.logger.info(
-                f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}"
-            )
+            self.logger.info(f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}")

     def get_shard(self, index):
         start = time.time()

@@ -130,9 +122,8 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
                                                                              worker_init=self.worker_init,
                                                                              data_sampler=self.data_sampler)
         else:
-            self.train_dataloader, sample_count = self.dataset_future.result(
-                timeout=None)
+            self.train_dataloader, sample_count = self.dataset_future.result(timeout=None)

         self.logger.info(
             f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s."
         )

@@ -145,11 +136,9 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):

     def prefetch_shard(self, index):
         self.data_file = self._get_shard_file(index)
-        self.dataset_future = self.pool.submit(
-            create_pretraining_dataset, self.data_file,
-            self.max_predictions_per_seq, self.num_workers,
-            self.train_micro_batch_size_per_gpu, self.worker_init,
-            self.data_sampler)
+        self.dataset_future = self.pool.submit(create_pretraining_dataset, self.data_file, self.max_predictions_per_seq,
+                                               self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init,
+                                               self.data_sampler)

     def get_batch(self, batch_iter):
         return batch_iter

@@ -179,4 +168,3 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
         indices = torch.randperm(self.num_files, generator=g).tolist()
         new_dataset = [self.dataset_files[i] for i in indices]
         self.dataset_files = new_dataset
-
@@ -1,35 +1,45 @@
-import transformers
 import logging
-from colossalai.nn.lr_scheduler import LinearWarmupLR
-from transformers import get_linear_schedule_with_warmup
-from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig
-from transformers import GPT2Config, GPT2LMHeadModel
-from transformers import AutoTokenizer, AutoModelForMaskedLM
-from colossalai.nn.optimizer import FusedAdam, HybridAdam
-from torch.optim import AdamW
-from colossalai.core import global_context as gpc
-import torch
 import os
 import sys
-sys.path.append(os.getcwd())
-from model.deberta_v2 import DebertaV2ForMaskedLM
-from model.bert import BertForMaskedLM
-import torch.nn as nn

+import torch
+import transformers
+from torch.optim import AdamW
+from transformers import (
+    AutoModelForMaskedLM,
+    AutoTokenizer,
+    BertForPreTraining,
+    GPT2Config,
+    GPT2LMHeadModel,
+    RobertaConfig,
+    RobertaForMaskedLM,
+    get_linear_schedule_with_warmup,
+)
+
+from colossalai.core import global_context as gpc
+from colossalai.nn.lr_scheduler import LinearWarmupLR
+from colossalai.nn.optimizer import FusedAdam, HybridAdam
+
+sys.path.append(os.getcwd())
 from collections import OrderedDict

+import torch.nn as nn
+from model.bert import BertForMaskedLM
+from model.deberta_v2 import DebertaV2ForMaskedLM

 __all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']


 def get_new_state_dict(state_dict, start_index=13):
     new_state_dict = OrderedDict()
     for k, v in state_dict.items():
         name = k[start_index:]
         new_state_dict[name] = v
     return new_state_dict


 class LMModel(nn.Module):

     def __init__(self, model, config, args):
         super().__init__()

@@ -58,16 +68,18 @@ def get_model(args, logger):
     if len(args.load_pretrain_model) > 0:
         assert os.path.exists(args.load_pretrain_model)
         # load_checkpoint(args.load_pretrain_model, model, strict=False)
-        m_state_dict = torch.load(args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
+        m_state_dict = torch.load(args.load_pretrain_model,
+                                  map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
         # new_state_dict = get_new_state_dict(m_state_dict)
-        model.load_state_dict(m_state_dict, strict=True)  # must insure that every process have identical parameters !!!!!!!
+        model.load_state_dict(m_state_dict,
+                              strict=True)    # must insure that every process have identical parameters !!!!!!!
         logger.info("load model success")

     numel = sum([p.numel() for p in model.parameters()])
     if args.checkpoint_activations:
         model.gradient_checkpointing_enable()
     # model = LMModel(model, config, args)

     return config, model, numel


@@ -89,7 +101,10 @@ def get_optimizer(model, lr):

 def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
     # warmup_steps = int(total_steps * warmup_ratio)
-    lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch)
+    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
+                                                   num_warmup_steps=warmup_steps,
+                                                   num_training_steps=total_steps,
+                                                   last_epoch=last_epoch)
     # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
     return lr_scheduler

@@ -103,10 +118,7 @@ def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
     checkpoint['epoch'] = epoch
     checkpoint['shard'] = shard
     checkpoint['global_step'] = global_step
     model_state = model.state_dict()    #each process must run model.state_dict()
     if gpc.get_global_rank() == 0:
         torch.save(checkpoint, optimizer_lr_path)
         torch.save(model_state, model_path)
-
-
-
@@ -35,4 +35,3 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
 --mlm bert \
 --wandb \
 --checkpoint_activations \
-
@@ -38,4 +38,3 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
 --resume_train \
 --load_pretrain_model /ckpt/1.pt \
 --load_optimizer_lr /ckpt/1.op_lrs \
-
@@ -4,21 +4,6 @@ import time
 from functools import partial

 import torch
-from tqdm import tqdm
-import os
-import time
-from functools import partial
-from transformers import AutoTokenizer
-
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.utils import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
-
 from arguments import parse_args
 from evaluation import evaluate
 from loss import LossForPretraining
@@ -30,6 +15,15 @@ from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calcul
 from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
 from utils.logger import Logger

+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.utils import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ZeroOptimizer
+

 def main():

@@ -39,7 +33,7 @@ def main():
     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

     # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

     logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)

     if args.vscode_debug:
@@ -52,7 +46,7 @@ def main():
         args.local_rank = -1
         args.log_interval = 1
     else:
         colossalai.launch_from_torch(config={})    #args.colossal_config
         args.local_rank = int(os.environ["LOCAL_RANK"])
     logger.info(
         f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
@@ -63,7 +57,7 @@ def main():
     args.tokenizer = tokenizer
     args.logger = logger
     set_global_variables(launch_time, args.tensorboard_path)

     world_size = torch.distributed.get_world_size()
     init_dev = get_current_device()

@@ -116,7 +110,7 @@ def main():
         optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

         logger.info(get_mem_info(prefix='After init optim, '))

     else:
         config, model, numel = get_model(args, logger)
         logger.info("no_zero")
@@ -129,7 +123,7 @@ def main():
     get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)

     # 144003367 is is the length of the entire dataset
     steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size    #len(dataloader)
     total_steps = steps_per_epoch * args.epoch

     lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
@@ -156,14 +150,15 @@ def main():
         start_epoch = o_l_state_dict['epoch']
         start_shard = o_l_state_dict['shard'] + 1
         # global_step = o_l_state_dict['global_step'] + 1
-        logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}')
+        logger.info(
+            f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
+        )

     criterion = LossForPretraining(config.vocab_size)

     # build dataloader
     pretrain_dataset_provider = NvidiaBertDatasetProvider(args)


     logger.info(get_mem_info(prefix='After init model, '))

     best_loss = None
@@ -189,8 +184,8 @@ def main():
         iterator_data = enumerate(dataset_iterator)

         model.train()

         for step, batch_data in iterator_data:

             # batch_data = pretrain_dataset_provider.get_batch(batch_index)
             input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
@@ -200,7 +195,7 @@ def main():
             # nsp_label = batch_data[5].cuda()

             output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

             loss = criterion(output.logits, mlm_label)
             pretrain_dataset_provider.prefetch_batch()

@@ -210,7 +205,7 @@ def main():
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()

                 global_step += 1

                 if global_step % args.log_interval == 0 and global_step != 0 \
@@ -242,9 +237,10 @@ def main():
             logger.info('*' * 100)

             eval_loss += evaluate(model, args, logger, global_step, criterion)
-            save_ckpt(model, optimizer, lr_scheduler, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step)
+            save_ckpt(model, optimizer, lr_scheduler,
+                      os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
+                      shard, global_step)

         eval_loss /= len(os.listdir(args.data_path_prefix))
         logger.info(
             f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
@@ -1,8 +1,10 @@
-import time
-import wandb
 import os
+import time
+
+import wandb
 from torch.utils.tensorboard import SummaryWriter


 class WandbLog:
+
     @classmethod
@@ -15,7 +17,7 @@ class WandbLog:

         if model:
             wandb.watch(model)

         if gradient:
             wandb.watch(gradient)

@@ -30,7 +32,7 @@ class TensorboardLog:
     def log_train(self, result, step):
         for k, v in result.items():
             self.writer.add_scalar(f'{k}/train', v, step)

     def log_eval(self, result, step):
         for k, v in result.items():
             self.writer.add_scalar(f'{k}/eval', v, step)
@@ -38,9 +40,3 @@ class TensorboardLog:
     def log_zeroshot(self, result, step):
         for k, v in result.items():
             self.writer.add_scalar(f'{k}_acc/eval', v, step)
-
-
-
-
-
-
@@ -1,9 +1,13 @@
 import functools
-import os, shutil
-import torch
+import os
+import shutil

 import psutil
+import torch

 from colossalai.core import global_context as gpc


 def logging(s, log_path, print_=True, log_=True):
     if print_:
         print(s)
@@ -11,9 +15,11 @@ def logging(s, log_path, print_=True, log_=True):
     with open(log_path, 'a+') as f_log:
         f_log.write(s + '\n')

+
 def get_logger(log_path, **kwargs):
     return functools.partial(logging, log_path=log_path, **kwargs)

+
 def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
     if debug:
         print('Debug Mode : no experiment dir created')
@@ -33,6 +39,7 @@ def create_exp_dir(dir_path, scripts_to_save=None, debug=False):

     return get_logger(log_path=os.path.join(dir_path, 'log.txt'))

+
 def get_cpu_mem():
     return psutil.Process().memory_info().rss / 1024**2

@@ -52,11 +59,15 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
 def get_parameters_in_billions(model, world_size=1):
     gpus_per_model = world_size

-    approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
-                                        for model_module in model])
+    approx_parameters_in_billions = sum([
+        sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement()
+             for p in model_module.parameters()])
+        for model_module in model
+    ])

     return approx_parameters_in_billions * gpus_per_model / (1e9)


 def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1):
     gpus_per_model = 1
     batch_size = args.train_micro_batch_size_per_gpu
@@ -76,10 +87,13 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations,
     # The factor of 4 is when used with activation check-pointing,
     # otherwise it will be 3.
     checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
-    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
+    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers *
+                           (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) +
+                                                (vocab_size / (16. * num_layers * hidden_size)))
     tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12))
     return samples_per_second, tflops, approx_parameters_in_billions


 def synchronize():
     if not torch.distributed.is_available():
         return
@@ -90,10 +104,11 @@ def synchronize():
         return
     torch.distributed.barrier()


 def log_args(logger, args):
     logger.info('--------args----------')
     message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()])
     message += '\n'
     message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()])
     logger.info(message)
     logger.info('--------args----------\n')
@@ -1,5 +1,7 @@
 import time
+
 import torch
+
 from .WandbLog import TensorboardLog

 _GLOBAL_TIMERS = None
@@ -10,30 +12,34 @@ def set_global_variables(launch_time, tensorboard_path):
     _set_timers()
     _set_tensorboard_writer(launch_time, tensorboard_path)

+
 def _set_timers():
     """Initialize timers."""
     global _GLOBAL_TIMERS
     _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
     _GLOBAL_TIMERS = Timers()

+
 def _set_tensorboard_writer(launch_time, tensorboard_path):
     """Set tensorboard writer."""
     global _GLOBAL_TENSORBOARD_WRITER
-    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
-                                   'tensorboard writer')
+    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer')
     if torch.distributed.get_rank() == 0:
         _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time)

+
 def get_timers():
     """Return timers."""
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS

+
 def get_tensorboard_writer():
     """Return tensorboard writer. It can be None so no need
     to check if it is initialized."""
     return _GLOBAL_TENSORBOARD_WRITER

+
 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
     assert var is not None, '{} is not initialized.'.format(name)
@@ -115,12 +121,10 @@ class Timers:
         assert normalizer > 0.0
         string = 'time (ms)'
         for name in names:
-            elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0 / normalizer
+            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
             string += ' | {}: {:.2f}'.format(name, elapsed_time)
         if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == (
-                    torch.distributed.get_world_size() - 1):
+            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                 print(string, flush=True)
         else:
             print(string, flush=True)
@@ -1,22 +1,22 @@
-import os
 import logging
+import os

 import torch.distributed as dist

-logging.basicConfig(
-    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-    datefmt='%m/%d/%Y %H:%M:%S',
-    level=logging.INFO)
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.INFO)
 logger = logging.getLogger(__name__)


 class Logger():
+
     def __init__(self, log_path, cuda=False, debug=False):
         self.logger = logging.getLogger(__name__)
         self.cuda = cuda
         self.log_path = log_path
         self.debug = debug

     def info(self, message, log_=True, print_=True, *args, **kwargs):
         if (self.cuda and dist.get_rank() == 0) or not self.cuda:
             if print_:
@@ -26,6 +26,5 @@ class Logger():
             with open(self.log_path, 'a+') as f_log:
                 f_log.write(message + '\n')

-
     def error(self, message, *args, **kwargs):
         self.logger.error(message, *args, **kwargs)
@@ -4,4 +4,4 @@ tqdm
 tensorboard
 numpy
 h5py
 wandb
@@ -1,184 +0,0 @@
-#include <algorithm>
-#include <iostream>
-#include <limits>
-#include <math.h>
-#include <stdexcept>
-#include <pybind11/pybind11.h>
-#include <pybind11/numpy.h>
-#include <random>
-#include <vector>
-#include <string>
-#include <pybind11/stl.h>
-#include <chrono>
-#include <tuple>
-#include <unordered_set>
-#include <unordered_map>
-
-namespace py = pybind11;
-
-const int32_t LONG_SENTENCE_LEN = 512;
-
-struct MaskedLMInstance {
-  int index;
-  std::string label;
-  MaskedLMInstance(int index, std::string label) {
-    this->index = index;
-    this->label = label;
-  }
-};
-
-auto get_new_segment(std::vector<std::string> segment, std::vector<std::string> segment_jieba, const std::vector<bool> chinese_vocab) { // const std::unordered_set<std::string> &chinese_vocab
-  std::unordered_set<std::string> seq_cws_dict;
-  for (auto word : segment_jieba) {
-    seq_cws_dict.insert(word);
-  }
-  int i = 0;
-  std::vector<std::string> new_segment;
-  int segment_size = segment.size();
-  while (i < segment_size) {
-    if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end()
-      new_segment.emplace_back(segment[i]);
-      i += 1;
-      continue;
-    }
-    bool has_add = false;
-    for (int length = 3; length >= 1; length--) {
-      if (i + length > segment_size) {
-        continue;
-      }
-      std::string chinese_word = "";
-      for (int j = i; j < i + length; j++) {
-        chinese_word += segment[j];
-      }
-      if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
-        new_segment.emplace_back(segment[i]);
-        for (int j = i + 1; j < i + length; j++) {
-          new_segment.emplace_back("##" + segment[j]);
-        }
-        i += length;
-        has_add = true;
-        break;
-      }
-    }
-    if (!has_add) {
-      new_segment.emplace_back(segment[i]);
-      i += 1;
-    }
-  }
-
-  return new_segment;
-}
-
-bool startsWith(const std::string& s, const std::string& sub) {
-  return s.find(sub) == 0 ? true : false;
-}
-
-auto create_whole_masked_lm_predictions(std::vector<std::string> &tokens,
-                                        const std::vector<std::string> &original_tokens,
-                                        const std::vector<std::string> &vocab_words,
-                                        std::map<std::string, int> &vocab,
-                                        const int max_predictions_per_seq,
-                                        const double masked_lm_prob) {
-  // for (auto item : vocab) {
-  //   std::cout << "key=" << std::string(py::str(item.first)) << ", "
-  //   << "value=" << std::string(py::str(item.second)) << std::endl;
-  // }
-  std::vector<std::vector<int> > cand_indexes;
-  std::vector<int> cand_temp;
-  int tokens_size = tokens.size();
-  std::string prefix = "##";
-  bool do_whole_masked = true;
-
-  for (int i = 0; i < tokens_size; i++) {
-    if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
-      continue;
-    }
-    if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) {
-      cand_temp.emplace_back(i);
-    }
-    else {
-      if (cand_temp.size() > 0) {
-        cand_indexes.emplace_back(cand_temp);
-      }
-      cand_temp.clear();
-      cand_temp.emplace_back(i);
-    }
-  }
-  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
-  std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed));
-  // for (auto i : cand_indexes) {
-  //   for (auto j : i) {
-  //     std::cout << tokens[j] << " ";
-  //   }
-  //   std::cout << std::endl;
-  // }
-  // for (auto i : output_tokens) {
-  //   std::cout << i;
-  // }
-  // std::cout << std::endl;
-
-  int num_to_predict = std::min(max_predictions_per_seq,
-                                std::max(1, int(tokens_size * masked_lm_prob)));
-  // std::cout << num_to_predict << std::endl;
-
-  std::set<int> covered_indexes;
-  std::vector<int> masked_lm_output(tokens_size, -1);
-  int vocab_words_len = vocab_words.size();
-  std::default_random_engine e(seed);
-  std::uniform_real_distribution<double> u1(0.0, 1.0);
-  std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
-  int mask_cnt = 0;
-  std::vector<std::string> output_tokens;
-  output_tokens = original_tokens;
-
-  for (auto index_set : cand_indexes) {
-    if (mask_cnt > num_to_predict) {
-      break;
-    }
-    int index_set_size = index_set.size();
-    if (mask_cnt + index_set_size > num_to_predict) {
-      continue;
-    }
-    bool is_any_index_covered = false;
-    for (auto index : index_set) {
-      if (covered_indexes.find(index) != covered_indexes.end()) {
-        is_any_index_covered = true;
-        break;
-      }
-    }
-    if (is_any_index_covered) {
-      continue;
-    }
-    for (auto index : index_set) {
-
-      covered_indexes.insert(index);
-      std::string masked_token;
-      if (u1(e) < 0.8) {
-        masked_token = "[MASK]";
-      }
-      else {
-        if (u1(e) < 0.5) {
-          masked_token = output_tokens[index];
-        }
-        else {
-          int random_index = u2(e);
-          masked_token = vocab_words[random_index];
-        }
-      }
-      // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
-      masked_lm_output[index] = vocab[output_tokens[index]];
-      output_tokens[index] = masked_token;
-      mask_cnt++;
-    }
-  }
-
-  // for (auto p : masked_lms) {
-  //   masked_lm_output[p.index] = vocab[p.label];
-  // }
-  return std::make_tuple(output_tokens, masked_lm_output);
-}
-
-PYBIND11_MODULE(mask, m) {
-  m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions);
-  m.def("get_new_segment", &get_new_segment);
-}
@ -1,176 +0,0 @@
import colossalai

__all__ = ['parse_args']


def parse_args():
    parser = colossalai.get_default_parser()

    parser.add_argument(
        "--distplan",
        type=str,
        default='CAI_Gemini',
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action='store_true',
        help="Shard the tensors when init the model to shrink peak memory size on the assigned device. "
        "Valid when using colossalai as dist plan.",
    )

    parser.add_argument(
        '--lr',
        type=float,
        required=True,
        help='initial learning rate')
    parser.add_argument(
        '--epoch',
        type=int,
        required=True,
        help='number of epochs')
    parser.add_argument(
        '--data_path_prefix',
        type=str,
        required=True,
        help="location of the train data corpus")
    parser.add_argument(
        '--eval_data_path_prefix',
        type=str,
        required=True,
        help='location of the evaluation data corpus')
    parser.add_argument(
        '--tokenizer_path',
        type=str,
        required=True,
        help='location of the tokenizer')
    parser.add_argument(
        '--max_seq_length',
        type=int,
        default=512,
        help='sequence length')
    parser.add_argument(
        '--refresh_bucket_size',
        type=int,
        default=1,
        help="This param makes sure that a certain task is repeated for these time steps to "
        "optimise the back-propagation speed with APEX's DistributedDataParallel")
    parser.add_argument(
        "--max_predictions_per_seq",
        "--max_pred",
        default=80,
        type=int,
        help="The maximum number of masked tokens in a sequence to be predicted.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        default=1,
        type=int,
        help="number of gradient accumulation steps")
    parser.add_argument(
        "--train_micro_batch_size_per_gpu",
        default=2,
        type=int,
        required=True,
        help="train batch size")
    parser.add_argument(
        "--eval_micro_batch_size_per_gpu",
        default=2,
        type=int,
        required=True,
        help="eval batch size")
    parser.add_argument(
        "--num_workers",
        default=8,
        type=int,
        help="")
    parser.add_argument(
        "--async_worker",
        action='store_true',
        help="")
    parser.add_argument(
        "--bert_config",
        required=True,
        type=str,
        help="location of config.json")
    parser.add_argument(
        "--wandb",
        action='store_true',
        help="use wandb to watch model")
    parser.add_argument(
        "--wandb_project_name",
        default='roberta',
        help="wandb project name")
    parser.add_argument(
        "--log_interval",
        default=100,
        type=int,
        help="report interval")
    parser.add_argument(
        "--log_path",
        type=str,
        required=True,
        help="log file which records train steps")
    parser.add_argument(
        "--tensorboard_path",
        type=str,
        required=True,
        help="location of the tensorboard file")
    parser.add_argument(
        "--colossal_config",
        type=str,
        required=True,
        help="colossal config, which contains the zero config and so on")
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
        help="location for saving checkpoints, which contain the model and optimizer")
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help="random seed for initialization")
    parser.add_argument(
        '--vscode_debug',
        action='store_true',
        help="use vscode to debug")
    parser.add_argument(
        '--load_pretrain_model',
        default='',
        type=str,
        help="location of the model's checkpoint")
    parser.add_argument(
        '--load_optimizer_lr',
        default='',
        type=str,
        help="location of the checkpoint, which contains the optimizer, learning rate, epoch, shard and global_step")
    parser.add_argument(
        '--resume_train',
        action='store_true',
        help="whether to resume training from an earlier checkpoint")
    parser.add_argument(
        '--mlm',
        default='bert',
        type=str,
        help="model type, bert or deberta")
    parser.add_argument(
        '--checkpoint_activations',
        action='store_true',
        help="whether to use gradient checkpointing")

    args = parser.parse_args()
    return args
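For context, here is a hypothetical sketch of how a training entry point might consume `parse_args()`. The module name `arguments` and the `colossalai.launch_from_torch` call are assumptions for illustration; the actual training script is not shown in this diff.

```python
# Hypothetical usage sketch; not taken from this diff.
import colossalai

from arguments import parse_args  # assumed module name for the file above


def main():
    args = parse_args()
    # One plausible way to use --colossal_config: initialize the distributed
    # environment from the user-supplied Colossal-AI config (the real script may differ).
    colossalai.launch_from_torch(config=args.colossal_config)
    print(f"epochs={args.epoch}, lr={args.lr}, "
          f"micro_batch={args.train_micro_batch_size_per_gpu}")


if __name__ == '__main__':
    main()
```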