diff --git a/examples/language/roberta/README.md b/examples/language/roberta/README.md
new file mode 100644
index 000000000..c119d23b5
--- /dev/null
+++ b/examples/language/roberta/README.md
@@ -0,0 +1,58 @@
+# Introduction
+This repo shows how to pretrain a Chinese RoBERTa-large from scratch, including preprocessing, pretraining and finetuning. It can help you quickly train a high-quality BERT-style model.
+
+## 0. Prerequisite
+- Install Colossal-AI
+- Edit the port in /etc/ssh/sshd_config and /etc/ssh/ssh_config so that every host exposes the same SSH port for both server and client. If you are the root user, also set **PermitRootLogin** in /etc/ssh/sshd_config to "yes"
+- Ensure that every host can log in to every other host without a password. With n hosts, the two commands below need to be executed n^2 times (as sketched after the code block)
+
+```
+ssh-keygen
+ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination
+```
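+
+For example, after generating the key you can loop over the host list on every host (the host names below follow the /etc/hosts example further down; adjust them to your cluster):
+
+```bash
+# a minimal sketch: copy the public key to every other host
+for host in GPU001 GPU002 GPU003 GPU004 GPU005 GPU006 GPU007; do
+    ssh-copy-id -i ~/.ssh/id_rsa.pub "$host"
+done
+```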
+
+- On every host, edit /etc/hosts to record all hosts' names and IPs. An example is shown below.
+
+```bash
+192.168.2.1 GPU001
+192.168.2.2 GPU002
+192.168.2.3 GPU003
+192.168.2.4 GPU004
+192.168.2.5 GPU005
+192.168.2.6 GPU006
+192.168.2.7 GPU007
+...
+```
+
+- Restart the SSH service
+```
+service ssh restart
+```
+
+## 1. Corpus Preprocessing
+```bash
+cd preprocessing
+```
+Following the `README.md` in that folder, preprocess the original corpus into h5py+numpy.
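+
+As a rough sketch (placeholder paths, default options; see `preprocessing/README.md` for details), the two preprocessing steps look like this:
+
+```bash
+# 1. split sentences and shard the corpus
+python sentence_split.py --input_path /original_corpus --output_path /shard --shard 100
+# 2. tokenize and apply whole word masking, writing one .h5 file per shard
+python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python
+```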
+
+## 2. Pretrain
+
+```bash
+cd pretraining
+```
+Following the `README.md` in that folder, load the h5 files generated in step 1 to pretrain the model.
+
+## 3. Finetune
+
+The checkpoint produced by this repo can directly replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main). You can then use HuggingFace Transformers to finetune downstream applications.
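+
+A minimal finetuning setup sketch, assuming the checkpoint has been copied into a local directory together with the tokenizer files (the directory name, task and label count below are placeholders):
+
+```python
+from transformers import BertTokenizerFast, BertForSequenceClassification
+
+# the local directory holds config.json, vocab.txt, tokenizer.json and the
+# pytorch_model.bin produced by this repo
+model_dir = "./chinese-roberta-wwm-ext-large"
+tokenizer = BertTokenizerFast.from_pretrained(model_dir)
+model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)
+
+inputs = tokenizer("我今天去打篮球。", return_tensors="pt")
+outputs = model(**inputs)
+print(outputs.logits.shape)  # torch.Size([1, 2])
+```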
+
+## Contributors
+This example is contributed by the AI team from [Moore Threads](https://www.mthreads.com/). If you encounter any problem with pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. Finally, any form of contribution is welcome!
+
+```
+@misc{
+ title={A simple Chinese RoBERTa Example for Whole Word Masked},
+ author={Yehua Zhang, Chen Zhang},
+ year={2022}
+}
+```
\ No newline at end of file
diff --git a/examples/language/roberta/configs/colossalai_ddp.py b/examples/language/roberta/configs/colossalai_ddp.py
new file mode 100644
index 000000000..c3c59aa40
--- /dev/null
+++ b/examples/language/roberta/configs/colossalai_ddp.py
@@ -0,0 +1,4 @@
+from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.nn.optimizer import FusedAdam
+
+clip_grad_norm = 1.0
diff --git a/examples/language/roberta/configs/colossalai_zero.py b/examples/language/roberta/configs/colossalai_zero.py
new file mode 100644
index 000000000..c5debdce0
--- /dev/null
+++ b/examples/language/roberta/configs/colossalai_zero.py
@@ -0,0 +1,32 @@
+from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.nn.optimizer import FusedAdam
+
+# fp16 = dict(
+# mode=AMP_TYPE.TORCH,
+# )
+
+# seed = 2
+zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
+ reduce_scatter_bucket_size_mb=25,
+ fp32_reduce_scatter=False,
+ tensor_placement_policy="cuda",
+ gradient_predivide_factor=1.0,
+ reuse_fp16_shard=False),
+ optimizer_config=dict(gpu_margin_mem_ratio=0.8,
+ initial_scale=2**5,
+ min_scale=1,
+ growth_factor=2,
+ backoff_factor=0.5,
+ growth_interval=1000,
+ hysteresis=2,
+ max_scale=2**32))
+
+# gradient_accumulation = 4
+clip_grad_norm = 1.0
+optimizer = dict(
+ type=FusedAdam,
+ lr=0.00015,
+ weight_decay=1e-2,
+)
+
+# 64433
\ No newline at end of file
diff --git a/examples/language/roberta/preprocessing/Makefile b/examples/language/roberta/preprocessing/Makefile
new file mode 100644
index 000000000..82ee4e1c5
--- /dev/null
+++ b/examples/language/roberta/preprocessing/Makefile
@@ -0,0 +1,9 @@
+CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color
+CPPFLAGS += $(shell python3 -m pybind11 --includes)
+LIBNAME = mask
+LIBEXT = $(shell python3-config --extension-suffix)
+
+default: $(LIBNAME)$(LIBEXT)
+
+%$(LIBEXT): %.cpp
+ $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
diff --git a/examples/language/roberta/preprocessing/README.md b/examples/language/roberta/preprocessing/README.md
new file mode 100644
index 000000000..1dbd745ab
--- /dev/null
+++ b/examples/language/roberta/preprocessing/README.md
@@ -0,0 +1,105 @@
+# Data Preprocessing for Chinese Whole Word Masking
+
+
+
+## Catalogue:
+* 1. Introduction
+* 2. Quick Start Guide:
+ * 2.1. Split Sentence
+ * 2.2. Tokenizer & Whole Word Masking
+
+
+
+
+## 1. Introduction: [Back to Top]
+This folder is used to preprocess a Chinese corpus with Whole Word Masking. You can obtain a corpus from [WuDao](https://resource.wudaoai.cn/home?ind&name=WuDaoCorpora%202.0&id=1394901288847716352). The preprocessing is flexible: you can modify the code to fit your needs, hardware or parallel framework (Open MPI, Spark, Dask).
+
+
+
+## 2. Quick Start Guide: [Back to Top]
+
+
+
+### 2.1. Split Sentences & Split Data into Multiple Shards:
+Firstly, each file contains multiple documents, and each document contains multiple sentences. Split sentences on punctuation such as `。!`. **Secondly, split the data into multiple shards based on the server hardware (CPU, CPU memory, hard disk) and the corpus size.** Each shard contains a part of the corpus, and the model treats all shards together as one epoch.
+In this example, a 200G corpus is split into 100 shards of about 2G each. The shard size is memory-dependent: take into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training that reads the shards (n-way data parallelism requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining require fighting with the hardware, not just the GPU.**
+
+```bash
+python sentence_split.py --input_path /original_corpus --output_path /shard --shard 100
+# This step takes a short time
+```
+* `--input_path`: directory of the original corpus, e.g., /original_corpus/0.json /original_corpus/1.json ...
+* `--output_path`: directory of the shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
+* `--shard`: number of shards, e.g., 10, 50, or 100
+
+Input json:
+
+```
+[
+ {
+ "id": 0,
+ "title": "打篮球",
+ "content": "我今天去打篮球。不回来吃饭。"
+    },
+ {
+ "id": 1,
+ "title": "旅游",
+ "content": "我后天去旅游。下周请假。"
+ }
+]
+```
+
+Output txt:
+
+```
+我今天去打篮球。
+不回来吃饭。
+]]
+我后天去旅游。
+下周请假。
+```
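+
+If you want to check in advance how your corpus files will be grouped, the sketch below mirrors the size-based grouping that `sentence_split.py` performs internally: files are sorted from largest to smallest and packed greedily into buckets of roughly `total_size / shard` bytes (the directory path is a placeholder).
+
+```python
+import os
+
+def plan_shards(corpus_dir: str, shard: int):
+    """Preview how files would be grouped into shards by total size."""
+    files = [(name, os.path.getsize(os.path.join(corpus_dir, name))) for name in os.listdir(corpus_dir)]
+    files.sort(key=lambda item: item[1], reverse=True)      # largest files first
+    per_size = sum(size for _, size in files) / shard
+    buckets, bucket, acc = [], [], 0
+    for name, size in files:
+        bucket.append(name)
+        acc += size
+        if acc > per_size:                                   # bucket is full, start a new one
+            buckets.append(bucket)
+            bucket, acc = [], 0
+    if bucket:
+        buckets.append(bucket)
+    return buckets
+
+print(len(plan_shards('/original_corpus', 100)))             # roughly 100 groups of files
+```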
+
+
+
+### 2.2. Tokenizer & Whole Word Masking:
+
+```bash
+python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python
+# This step is time-consuming; most of the time is spent on masking
+```
+
+**[optional but recommended]**: the C++ backend built with `pybind11` provides a faster masking speed. Build it first, then pass `--backend c++`:
+
+```shell
+make
+```
+
+* `--input_path`: directory of the shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
+* `--output_path`: directory of the h5 files with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ...
+* `--tokenizer_path`: tokenizer directory containing the HuggingFace tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenizer.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main)
+* `--backend`: python or c++; **c++ gives a faster preprocessing speed**
+* `--dupe_factor`: how many times the preprocessor repeats creating inputs from the same article/document
+* `--worker`: number of worker processes
+
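+Putting these options together, a full invocation might look like this (paths and values are placeholders):
+
+```bash
+make    # optional: build the C++ backend first
+python tokenize_mask.py \
+    --input_path /shard \
+    --output_path /h5 \
+    --tokenizer_path /roberta \
+    --backend c++ \
+    --dupe_factor 1 \
+    --worker 32
+```
+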
+Input txt:
+
+```
+我今天去打篮球。
+不回来吃饭。
+]]
+我后天去旅游。
+下周请假。
+```
+
+Output h5+numpy:
+
+```
+'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..],
+ ...]
+'input_mask': [[1,1,1,1,1,1,0,0..],
+ ...]
+'segment_ids': [[0,0,0,0,0,...],
+ ...]
+'masked_lm_positions': [[label1,-1,-1,label2,-1...],
+ ...]
+```
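+
+Each output file can be inspected (or later consumed by the pretraining data loader) with `h5py`; a small sketch, assuming shard 0 was written to /h5:
+
+```python
+import h5py
+
+with h5py.File('/h5/0.h5', 'r') as hf:
+    input_ids = hf['input_ids'][:]          # (num_sequences, seq_len) token ids, 0-padded
+    input_mask = hf['input_mask'][:]        # 1 for real tokens, 0 for padding
+    segment_ids = hf['segment_ids'][:]      # all zeros for single-segment MLM
+    labels = hf['masked_lm_positions'][:]   # original token id at masked positions, -1 elsewhere
+print(input_ids.shape, labels.shape)
+```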
\ No newline at end of file
diff --git a/examples/language/roberta/preprocessing/get_mask.py b/examples/language/roberta/preprocessing/get_mask.py
new file mode 100644
index 000000000..da297f98e
--- /dev/null
+++ b/examples/language/roberta/preprocessing/get_mask.py
@@ -0,0 +1,266 @@
+import torch
+import os
+from enum import IntEnum
+from random import choice
+import random
+import collections
+import time
+import logging
+import jieba
+jieba.setLogLevel(logging.CRITICAL)
+import re
+import numpy as np
+import mask
+
+PAD = 0
+MaskedLMInstance = collections.namedtuple("MaskedLMInstance",
+ ["index", "label"])
+
+
+def map_to_numpy(data):
+ return np.asarray(data)
+
+
+class PreTrainingDataset():
+ def __init__(self,
+ tokenizer,
+ max_seq_length,
+ backend='python',
+ max_predictions_per_seq: int = 80,
+ do_whole_word_mask: bool = True):
+ self.tokenizer = tokenizer
+ self.max_seq_length = max_seq_length
+ self.masked_lm_prob = 0.15
+ self.backend = backend
+ self.do_whole_word_mask = do_whole_word_mask
+ self.max_predictions_per_seq = max_predictions_per_seq
+ self.vocab_words = list(tokenizer.vocab.keys())
+ self.rec = re.compile('[\u4E00-\u9FA5]')
+ self.whole_rec = re.compile('##[\u4E00-\u9FA5]')
+
+ self.mlm_p = 0.15
+ self.mlm_mask_p = 0.8
+ self.mlm_tamper_p = 0.05
+ self.mlm_maintain_p = 0.1
+
+
+ def tokenize(self, doc):
+ temp = []
+ for d in doc:
+ temp.append(self.tokenizer.tokenize(d))
+ return temp
+
+
+ def create_training_instance(self, instance):
+ is_next = 1
+ raw_text_list = self.get_new_segment(instance)
+ tokens_a = raw_text_list
+ assert len(tokens_a) == len(instance)
+ # tokens_a, tokens_b, is_next = instance.get_values()
+ # print(f'is_next label:{is_next}')
+ # Create mapper
+ tokens = []
+ original_tokens = []
+ segment_ids = []
+ tokens.append("[CLS]")
+ original_tokens.append('[CLS]')
+ segment_ids.append(0)
+ for index, token in enumerate(tokens_a):
+ tokens.append(token)
+ original_tokens.append(instance[index])
+ segment_ids.append(0)
+
+ tokens.append("[SEP]")
+ original_tokens.append('[SEP]')
+ segment_ids.append(0)
+
+ # for token in tokens_b:
+ # tokens.append(token)
+ # segment_ids.append(1)
+
+ # tokens.append("[SEP]")
+ # segment_ids.append(1)
+
+ # Get Masked LM predictions
+ if self.backend == 'c++':
+ output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words,
+ self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob)
+ elif self.backend == 'python':
+ output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)
+
+ # Convert to Ids
+ input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens)
+ input_mask = [1] * len(input_ids)
+
+ while len(input_ids) < self.max_seq_length:
+ input_ids.append(PAD)
+ segment_ids.append(PAD)
+ input_mask.append(PAD)
+ masked_lm_output.append(-1)
+ return ([
+ map_to_numpy(input_ids),
+ map_to_numpy(input_mask),
+ map_to_numpy(segment_ids),
+ map_to_numpy(masked_lm_output),
+ map_to_numpy([is_next])
+ ])
+
+
+ def create_masked_lm_predictions(self, tokens):
+        # Token-level masking: every WordPiece (except [CLS]/[SEP]) is an independent
+        # candidate, so the candidate indexes are kept flat here; whole-word grouping
+        # is handled by create_whole_masked_lm_predictions below.
+        cand_indexes = []
+        for i, token in enumerate(tokens):
+            if token == "[CLS]" or token == "[SEP]":
+                continue
+            cand_indexes.append(i)
+
+ random.shuffle(cand_indexes)
+ output_tokens = list(tokens)
+
+ num_to_predict = min(
+ self.max_predictions_per_seq,
+ max(1, int(round(len(tokens) * self.masked_lm_prob))))
+
+ masked_lms = []
+ covered_indexes = set()
+ for index in cand_indexes:
+ if len(masked_lms) >= num_to_predict:
+ break
+ if index in covered_indexes:
+ continue
+ covered_indexes.add(index)
+
+ masked_token = None
+ # 80% mask
+ if random.random() < 0.8:
+ masked_token = "[MASK]"
+ else:
+ # 10% Keep Original
+ if random.random() < 0.5:
+ masked_token = tokens[index]
+ # 10% replace w/ random word
+ else:
+ masked_token = self.vocab_words[random.randint(
+ 0,
+ len(self.vocab_words) - 1)]
+
+ output_tokens[index] = masked_token
+ masked_lms.append(
+ MaskedLMInstance(index=index, label=tokens[index]))
+
+ masked_lms = sorted(masked_lms, key=lambda x: x.index)
+ masked_lm_output = [-1] * len(output_tokens)
+ for p in masked_lms:
+ masked_lm_output[p.index] = self.tokenizer.vocab[p.label]
+
+ return (output_tokens, masked_lm_output)
+
+
+ def get_new_segment(self, segment):
+ """
+ 输入一句话,返回一句经过处理的话: 为了支持中文全称mask,将被分开的词,将上特殊标记("#"),使得后续处理模块,能够知道哪些字是属于同一个词的。
+ :param segment: 一句话
+ :return: 一句处理过的话
+ """
+ seq_cws = jieba.lcut(''.join(segment))
+ seq_cws_dict = {x: 1 for x in seq_cws}
+ new_segment = []
+ i = 0
+ while i < len(segment):
+            if len(self.rec.findall(segment[i])) == 0:    # not a Chinese character, keep the original text
+ new_segment.append(segment[i])
+ i += 1
+ continue
+
+ has_add = False
+ for length in range(3, 0, -1):
+ if i + length > len(segment):
+ continue
+ if ''.join(segment[i: i+length]) in seq_cws_dict:
+ new_segment.append(segment[i])
+ for l in range(1, length):
+ new_segment.append('##' + segment[i+l])
+ i += length
+ has_add = True
+ break
+ if not has_add:
+ new_segment.append(segment[i])
+ i += 1
+ return new_segment
+
+
+ def create_whole_masked_lm_predictions(self, tokens):
+ """Creates the predictions for the masked LM objective."""
+
+ cand_indexes = []
+ for (i, token) in enumerate(tokens):
+ if token == "[CLS]" or token == "[SEP]":
+ continue
+ # Whole Word Masking means that if we mask all of the wordpieces
+ # corresponding to an original word. When a word has been split into
+ # WordPieces, the first token does not have any marker and any subsequence
+ # tokens are prefixed with ##. So whenever we see the ## token, we
+ # append it to the previous set of word indexes.
+ #
+ # Note that Whole Word Masking does *not* change the training code
+ # at all -- we still predict each WordPiece independently, softmaxed
+ # over the entire vocabulary.
+ if (self.do_whole_word_mask and len(cand_indexes) >= 1 and
+ token.startswith("##")):
+ cand_indexes[-1].append(i)
+ else:
+ cand_indexes.append([i])
+
+ random.shuffle(cand_indexes)
+
+        output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens]    # strip the "##" prefix
+
+ num_to_predict = min(self.max_predictions_per_seq,
+ max(1, int(round(len(tokens) * self.masked_lm_prob))))
+
+ masked_lms = []
+ covered_indexes = set()
+ for index_set in cand_indexes:
+ if len(masked_lms) >= num_to_predict:
+ break
+ # If adding a whole-word mask would exceed the maximum number of
+ # predictions, then just skip this candidate.
+ if len(masked_lms) + len(index_set) > num_to_predict:
+ continue
+ is_any_index_covered = False
+ for index in index_set:
+ if index in covered_indexes:
+ is_any_index_covered = True
+ break
+ if is_any_index_covered:
+ continue
+ for index in index_set:
+ covered_indexes.add(index)
+
+ masked_token = None
+ # 80% of the time, replace with [MASK]
+ if random.random() < 0.8:
+ masked_token = "[MASK]"
+ else:
+ # 10% of the time, keep original
+ if random.random() < 0.5:
+                        masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index]    # strip the "##" prefix
+ # 10% of the time, replace with random word
+ else:
+ masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
+
+ output_tokens[index] = masked_token
+
+ masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index]))
+ assert len(masked_lms) <= num_to_predict
+ masked_lms = sorted(masked_lms, key=lambda x: x.index)
+ masked_lm_output = [-1] * len(output_tokens)
+ for p in masked_lms:
+ masked_lm_output[p.index] = self.tokenizer.vocab[p.label]
+
+ return (output_tokens, masked_lm_output)
diff --git a/examples/language/roberta/preprocessing/mask.cpp b/examples/language/roberta/preprocessing/mask.cpp
new file mode 100644
index 000000000..8355c45cf
--- /dev/null
+++ b/examples/language/roberta/preprocessing/mask.cpp
@@ -0,0 +1,184 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <algorithm>
+#include <chrono>
+#include <cstdint>
+#include <iostream>
+#include <map>
+#include <random>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+#include <vector>
+
+namespace py = pybind11;
+
+const int32_t LONG_SENTENCE_LEN = 512;
+
+struct MaskedLMInstance {
+ int index;
+ std::string label;
+ MaskedLMInstance(int index, std::string label) {
+ this->index = index;
+ this->label = label;
+ }
+};
+
+auto get_new_segment(std::vector<std::string> segment, std::vector<std::string> segment_jieba, const std::vector<bool> chinese_vocab) { // const std::unordered_set<std::string> &chinese_vocab
+    std::unordered_set<std::string> seq_cws_dict;
+    for (auto word : segment_jieba) {
+        seq_cws_dict.insert(word);
+    }
+    int i = 0;
+    std::vector<std::string> new_segment;
+ int segment_size = segment.size();
+ while (i < segment_size) {
+ if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end()
+ new_segment.emplace_back(segment[i]);
+ i += 1;
+ continue;
+ }
+ bool has_add = false;
+ for (int length = 3; length >= 1; length--) {
+ if (i + length > segment_size) {
+ continue;
+ }
+ std::string chinese_word = "";
+ for (int j = i; j < i + length; j++) {
+ chinese_word += segment[j];
+ }
+ if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
+ new_segment.emplace_back(segment[i]);
+ for (int j = i + 1; j < i + length; j++) {
+ new_segment.emplace_back("##" + segment[j]);
+ }
+ i += length;
+ has_add = true;
+ break;
+ }
+ }
+ if (!has_add) {
+ new_segment.emplace_back(segment[i]);
+ i += 1;
+ }
+ }
+
+ return new_segment;
+}
+
+bool startsWith(const std::string& s, const std::string& sub) {
+ return s.find(sub) == 0 ? true : false;
+}
+
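+// Whole Word Masking: WordPieces that belong to the same Chinese word (marked with a
+// leading "##" by get_new_segment) are always masked together. For every selected word,
+// each piece is replaced by [MASK] ~80% of the time, kept unchanged ~10% of the time, or
+// swapped for a random vocabulary token ~10% of the time. masked_lm_output stores the
+// original token id at masked positions and -1 everywhere else.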
+auto create_whole_masked_lm_predictions(std::vector<std::string> &tokens,
+                                        const std::vector<std::string> &original_tokens,
+                                        const std::vector<std::string> &vocab_words,
+                                        std::map<std::string, int> &vocab,
+                                        const int max_predictions_per_seq,
+                                        const double masked_lm_prob) {
+ // for (auto item : vocab) {
+ // std::cout << "key=" << std::string(py::str(item.first)) << ", "
+ // << "value=" << std::string(py::str(item.second)) << std::endl;
+ // }
+    std::vector<std::vector<int> > cand_indexes;
+    std::vector<int> cand_temp;
+ int tokens_size = tokens.size();
+ std::string prefix = "##";
+ bool do_whole_masked = true;
+
+ for (int i = 0; i < tokens_size; i++) {
+ if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
+ continue;
+ }
+ if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) {
+ cand_temp.emplace_back(i);
+ }
+ else {
+ if (cand_temp.size() > 0) {
+ cand_indexes.emplace_back(cand_temp);
+ }
+ cand_temp.clear();
+ cand_temp.emplace_back(i);
+ }
+    }
+    // make sure the last collected word is also added as a candidate
+    if (cand_temp.size() > 0) {
+        cand_indexes.emplace_back(cand_temp);
+    }
+ auto seed = std::chrono::system_clock::now().time_since_epoch().count();
+ std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed));
+ // for (auto i : cand_indexes) {
+ // for (auto j : i) {
+ // std::cout << tokens[j] << " ";
+ // }
+ // std::cout << std::endl;
+ // }
+ // for (auto i : output_tokens) {
+ // std::cout << i;
+ // }
+ // std::cout << std::endl;
+
+ int num_to_predict = std::min(max_predictions_per_seq,
+ std::max(1, int(tokens_size * masked_lm_prob)));
+ // std::cout << num_to_predict << std::endl;
+
+    std::set<int> covered_indexes;
+    std::vector<int> masked_lm_output(tokens_size, -1);
+    int vocab_words_len = vocab_words.size();
+    std::default_random_engine e(seed);
+    std::uniform_real_distribution<double> u1(0.0, 1.0);
+    std::uniform_int_distribution<int> u2(0, vocab_words_len - 1);
+    int mask_cnt = 0;
+    std::vector<std::string> output_tokens;
+    output_tokens = original_tokens;
+
+ for (auto index_set : cand_indexes) {
+ if (mask_cnt > num_to_predict) {
+ break;
+ }
+ int index_set_size = index_set.size();
+ if (mask_cnt + index_set_size > num_to_predict) {
+ continue;
+ }
+ bool is_any_index_covered = false;
+ for (auto index : index_set) {
+ if (covered_indexes.find(index) != covered_indexes.end()) {
+ is_any_index_covered = true;
+ break;
+ }
+ }
+ if (is_any_index_covered) {
+ continue;
+ }
+ for (auto index : index_set) {
+
+ covered_indexes.insert(index);
+ std::string masked_token;
+ if (u1(e) < 0.8) {
+ masked_token = "[MASK]";
+ }
+ else {
+ if (u1(e) < 0.5) {
+ masked_token = output_tokens[index];
+ }
+ else {
+ int random_index = u2(e);
+ masked_token = vocab_words[random_index];
+ }
+ }
+ // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
+ masked_lm_output[index] = vocab[output_tokens[index]];
+ output_tokens[index] = masked_token;
+ mask_cnt++;
+ }
+ }
+
+ // for (auto p : masked_lms) {
+ // masked_lm_output[p.index] = vocab[p.label];
+ // }
+ return std::make_tuple(output_tokens, masked_lm_output);
+}
+
+PYBIND11_MODULE(mask, m) {
+ m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions);
+ m.def("get_new_segment", &get_new_segment);
+}
diff --git a/examples/language/roberta/preprocessing/sentence_split.py b/examples/language/roberta/preprocessing/sentence_split.py
new file mode 100644
index 000000000..231be152b
--- /dev/null
+++ b/examples/language/roberta/preprocessing/sentence_split.py
@@ -0,0 +1,163 @@
+
+import multiprocessing
+import os
+import re
+from tqdm import tqdm
+from typing import List
+import json
+import time
+import argparse
+import functools
+
+def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
+ """
+ Args:
+ document:
+ flag: Type:str, "all" 中英文标点分句,"zh" 中文标点分句,"en" 英文标点分句
+ limit: 默认单句最大长度为510个字符
+ Returns: Type:list
+ """
+ sent_list = []
+ try:
+ if flag == "zh":
+ document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) # 单字符断句符
+ document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) # 特殊引号
+ elif flag == "en":
+ document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) # 英文单字符断句符
+ document = re.sub('(?P([?!.]["\']))', r'\g\n', document) # 特殊引号
+ else:
+ document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) # 单字符断句符
+
+ document = re.sub('(?P(([。?!.!?]|…{1,2})[”’"\']))', r'\g\n',
+ document) # 特殊引号
+
+ sent_list_ori = document.splitlines()
+ for sent in sent_list_ori:
+ sent = sent.strip()
+ if not sent:
+ continue
+ elif len(sent) <= 2:
+ continue
+ else:
+ while len(sent) > limit:
+ temp = sent[0:limit]
+ sent_list.append(temp)
+ sent = sent[limit:]
+ sent_list.append(sent)
+    except Exception:
+ sent_list.clear()
+ sent_list.append(document)
+ return sent_list
+
+
+def get_sent(output_path,
+ input_path,
+ fin_list=[], host=-1, seq_len=512) -> None:
+
+ workers = 32
+
+ if input_path[-1] == '/':
+ input_path = input_path[:-1]
+
+ cur_path = os.path.join(output_path, str(host) + '.txt')
+ new_split_sentence = functools.partial(split_sentence, limit=seq_len-2)
+ with open(cur_path, 'w', encoding='utf-8') as f:
+ for fi, fin_path in enumerate(fin_list):
+ if not os.path.exists(os.path.join(input_path, fin_path[0])):
+ continue
+ if '.json' not in fin_path[0]:
+ continue
+
+ print("Processing ", fin_path[0], " ", fi)
+
+ with open(os.path.join(input_path, fin_path[0]), 'r') as fin:
+ f_data = [l['content'] for l in json.load(fin)]
+
+ pool = multiprocessing.Pool(workers)
+ all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
+ pool.close()
+ print('finished..')
+
+ cnt = 0
+ for d in tqdm(all_sent):
+ for i in d:
+ f.write(i.strip() + '\n')
+ f.write(']]' + '\n')
+ cnt += 1
+ # if cnt >= 2:
+ # exit()
+
+
+def getFileSize(filepath, shard):
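+    # Greedily pack the input files (largest first) into `shard` groups of roughly
+    # equal total size; each group later becomes one output shard.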
+ all_data = []
+ for i in os.listdir(filepath):
+ all_data.append(os.path.join(filepath, i))
+ all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data])
+ ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data]
+ ans = sorted(ans, key=lambda x: x[1], reverse=True)
+ per_size = all_size / shard
+ real_shard = []
+ temp = []
+ accu_size = 0
+ for i in ans:
+ accu_size += i[1]
+ temp.append(i)
+ if accu_size > per_size:
+ real_shard.append(temp)
+ accu_size = 0
+ temp = []
+
+ if len(temp) > 0:
+ real_shard.append(temp)
+
+ return real_shard
+
+
+def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
+ import socket
+ host = int(socket.gethostname().split(server_name)[-1])
+
+ fin_list = real_shard[server_num * base + host - 1]
+ print(fin_list)
+ print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')
+ return fin_list, host
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--server_num', type=int, default=10, help='number of servers')
+ parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
+ parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100')
+ parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus')
+ parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
+ args = parser.parse_args()
+
+ server_num = args.server_num
+ seq_len = args.seq_len
+ shard = args.shard
+ input_path = args.input_path
+ output_path = args.output_path
+
+ real_shard = getFileSize(input_path, shard)
+
+ start = time.time()
+ for index, shard in enumerate(real_shard):
+ get_sent(output_path,
+ input_path,
+ fin_list=shard,
+ host=index,
+ seq_len=seq_len)
+ print(f'cost {str(time.time() - start)}')
+
+    # if you have multiple servers, you can use the code below or adapt it to Open MPI
+
+ # for i in range(len(real_shard) // server_num + 1):
+ # fin_list, host = get_start_end(real_shard, i)
+
+ # start = time.time()
+ # get_sent(output_path,
+ # input_path,
+ # fin_list=fin_list, host= 10 * i + host - 1)
+
+ # print(f'cost {str(time.time() - start)}')
diff --git a/examples/language/roberta/preprocessing/tokenize_mask.py b/examples/language/roberta/preprocessing/tokenize_mask.py
new file mode 100644
index 000000000..b33871d5d
--- /dev/null
+++ b/examples/language/roberta/preprocessing/tokenize_mask.py
@@ -0,0 +1,275 @@
+import time
+import os
+import psutil
+import h5py
+import socket
+import argparse
+import numpy as np
+import multiprocessing
+from tqdm import tqdm
+from random import shuffle
+from transformers import AutoTokenizer
+from get_mask import PreTrainingDataset
+
+
+def get_raw_instance(document, max_sequence_length=512):
+
+ """
+ 获取初步的训练实例,将整段按照max_sequence_length切分成多个部分,并以多个处理好的实例的形式返回。
+ :param document: 一整段
+ :param max_sequence_length:
+ :return: a list. each element is a sequence of text
+ """
+ # document = self.documents[index]
+ max_sequence_length_allowed = max_sequence_length - 2
+ # document = [seq for seq in document if len(seq)= max_sequence_length_allowed:
+ if len(curr_seq) > 0:
+ result_list.append(curr_seq)
+ curr_seq = []
+ result_list.append(document[sz_idx][ : max_sequence_length_allowed])
+ sz_idx += 1
+ else:
+ result_list.append(curr_seq)
+ curr_seq = []
+ # 对最后一个序列进行处理,如果太短的话,丢弃掉。
+ if len(curr_seq) > max_sequence_length_allowed / 2: # /2
+ result_list.append(curr_seq)
+
+ # # 计算总共可以得到多少份
+ # num_instance=int(len(big_list)/max_sequence_length_allowed)+1
+ # print("num_instance:",num_instance)
+ # # 切分成多份,添加到列表中
+ # result_list=[]
+ # for j in range(num_instance):
+ # index=j*max_sequence_length_allowed
+ # end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1
+ # result_list.append(big_list[index:end_index])
+ return result_list
+
+
+def split_numpy_chunk(path, tokenizer, pretrain_data, host):
+
+ documents = []
+ instances = []
+
+ s = time.time()
+ with open(path, encoding='utf-8') as fd:
+ document = []
+ for i, line in enumerate(tqdm(fd)):
+ line = line.strip()
+ # document = line
+ # if len(document.split("")) <= 3:
+ # continue
+            if len(line) > 0 and line[:2] == "]]":  # This is the end of a document
+ documents.append(document)
+ document = []
+ elif len(line) >= 2:
+ document.append(line)
+ if len(document) > 0:
+ documents.append(document)
+ print('read_file ', time.time() - s)
+
+ # documents = [x for x in documents if x]
+ # print(len(documents))
+ # print(len(documents[0]))
+ # print(documents[0][0:10])
+ from typing import List
+ import multiprocessing
+
+ ans = []
+ for docs in tqdm(documents):
+ ans.append(pretrain_data.tokenize(docs))
+ print(time.time() - s)
+ del documents
+
+ instances = []
+ for a in tqdm(ans):
+ raw_ins = get_raw_instance(a)
+ instances.extend(raw_ins)
+ del ans
+
+ print('len instance', len(instances))
+
+ sen_num = len(instances)
+ seq_len = 512
+ input_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
+ input_mask = np.zeros([sen_num, seq_len], dtype=np.int32)
+ segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
+ masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32)
+
+ for index, ins in tqdm(enumerate(instances)):
+ mask_dict = pretrain_data.create_training_instance(ins)
+ input_ids[index] = mask_dict[0]
+ input_mask[index] = mask_dict[1]
+ segment_ids[index] = mask_dict[2]
+ masked_lm_output[index] = mask_dict[3]
+
+ with h5py.File(f'/output/{host}.h5', 'w') as hf:
+ hf.create_dataset("input_ids", data=input_ids)
+ hf.create_dataset("input_mask", data=input_ids)
+ hf.create_dataset("segment_ids", data=segment_ids)
+ hf.create_dataset("masked_lm_positions", data=masked_lm_output)
+
+ del instances
+
+
+def split_numpy_chunk_pool(input_path,
+ output_path,
+ pretrain_data,
+ worker,
+ dupe_factor,
+ seq_len,
+ file_name):
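+    # Tokenize one shard of split sentences, build whole-word-masked training instances
+    # (duplicated dupe_factor times and shuffled), and write the padded arrays to
+    # <output_path>/<file_name>.h5.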
+
+ if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
+ print(f'{file_name}.h5 exists')
+ return
+
+ documents = []
+ instances = []
+
+ s = time.time()
+ with open(input_path, 'r', encoding='utf-8') as fd:
+ document = []
+ for i, line in enumerate(tqdm(fd)):
+ line = line.strip()
+            if len(line) > 0 and line[:2] == "]]":  # This is the end of a document
+ documents.append(document)
+ document = []
+ elif len(line) >= 2:
+ document.append(line)
+ if len(document) > 0:
+ documents.append(document)
+ print(f'read_file cost {time.time() - s}, length is {len(documents)}')
+
+ ans = []
+ s = time.time()
+ pool = multiprocessing.Pool(worker)
+ encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100)
+ for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'):
+ ans.append(res)
+ pool.close()
+ print((time.time() - s) / 60)
+ del documents
+
+ instances = []
+ for a in tqdm(ans, colour='MAGENTA'):
+ raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
+ instances.extend(raw_ins)
+ del ans
+
+ print('len instance', len(instances))
+
+ new_instances = []
+ for _ in range(dupe_factor):
+ for ins in instances:
+ new_instances.append(ins)
+
+ shuffle(new_instances)
+ instances = new_instances
+ print('after dupe_factor, len instance', len(instances))
+
+ sentence_num = len(instances)
+ input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
+ input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32)
+ segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
+ masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32)
+
+ s = time.time()
+ pool = multiprocessing.Pool(worker)
+ encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32)
+ for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'):
+ input_ids[index] = mask_dict[0]
+ input_mask[index] = mask_dict[1]
+ segment_ids[index] = mask_dict[2]
+ masked_lm_output[index] = mask_dict[3]
+ pool.close()
+ print((time.time() - s) / 60)
+
+ with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
+ hf.create_dataset("input_ids", data=input_ids)
+ hf.create_dataset("input_mask", data=input_mask)
+ hf.create_dataset("segment_ids", data=segment_ids)
+ hf.create_dataset("masked_lm_positions", data=masked_lm_output)
+
+ del instances
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+    parser.add_argument('--tokenizer_path', type=str, required=True, help='path of tokenizer')
+ parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
+    parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='maximum number of masked tokens per sequence')
+ parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
+ parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
+    parser.add_argument('--backend', type=str, default='python', help='backend used to mask tokens: python or c++')
+ parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document')
+    parser.add_argument('--worker', type=int, default=32, help='number of worker processes')
+ parser.add_argument('--server_num', type=int, default=10, help='number of servers')
+ args = parser.parse_args()
+
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+ pretrain_data = PreTrainingDataset(tokenizer,
+ args.seq_len,
+ args.backend,
+ max_predictions_per_seq=args.max_predictions_per_seq)
+
+
+ data_len = len(os.listdir(args.input_path))
+
+ for i in range(data_len):
+ input_path = os.path.join(args.input_path, f'{i}.txt')
+ if os.path.exists(input_path):
+ start = time.time()
+ print(f'process {input_path}')
+ split_numpy_chunk_pool(input_path,
+ args.output_path,
+ pretrain_data,
+ args.worker,
+ args.dupe_factor,
+ args.seq_len,
+ i)
+ end_ = time.time()
+ print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
+ print(f'has cost {(end_ - start) / 60}')
+ print('-' * 100)
+ print('')
+
+    # if you have multiple servers, you can use the code below or adapt it to Open MPI
+
+ # host = int(socket.gethostname().split('GPU')[-1])
+ # for i in range(data_len // args.server_num + 1):
+ # h = args.server_num * i + host - 1
+ # input_path = os.path.join(args.input_path, f'{h}.txt')
+ # if os.path.exists(input_path):
+ # start = time.time()
+ # print(f'I am server {host}, process {input_path}')
+ # split_numpy_chunk_pool(input_path,
+ # args.output_path,
+ # pretrain_data,
+ # args.worker,
+ # args.dupe_factor,
+ # args.seq_len,
+ # h)
+ # end_ = time.time()
+ # print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
+ # print(f'has cost {(end_ - start) / 60}')
+ # print('-' * 100)
+ # print('')
+
+
diff --git a/examples/language/roberta/pretraining/README.md b/examples/language/roberta/pretraining/README.md
new file mode 100644
index 000000000..055d69696
--- /dev/null
+++ b/examples/language/roberta/pretraining/README.md
@@ -0,0 +1,24 @@
+# Pretraining
+1. Pretrain RoBERTa by running the script below. Detailed parameter descriptions can be found in arguments.py. `data_path_prefix` is the absolute path that points to the output of preprocessing. **You have to modify the *hostfile* according to your cluster.**
+
+```bash
+bash run_pretrain.sh
+```
+* `--hostfile`: servers' host names from /etc/hosts
+* `--include`: servers which will be used
+* `--nproc_per_node`: number of processes (GPUs) on each server
+* `--data_path_prefix`: absolute location of the training data, e.g., /h5/0.h5
+* `--eval_data_path_prefix`: absolute location of the evaluation data
+* `--tokenizer_path`: tokenizer directory containing the HuggingFace tokenizer.json, e.g., /tokenizer/tokenizer.json
+* `--bert_config`: config.json which defines the model
+* `--mlm`: backbone model type, bert or deberta_v2
+
+2. To resume training from an earlier checkpoint, run the script below; the extra flags it relies on are listed next, and a rough invocation sketch follows the list.
+
+```shell
+bash run_pretrain_resume.sh
+```
+* `--resume_train`: whether to resume training
+* `--load_pretrain_model`: absolute path which contains model checkpoint
+* `--load_optimizer_lr`: absolute path which contains optimizer checkpoint
+
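+A rough sketch of how the resume flags extend the normal pretraining arguments (the entry-script name and checkpoint paths below are placeholders, not the actual contents of run_pretrain_resume.sh):
+
+```bash
+# run_pretraining.py stands in for the real training entry script launched by the shell script
+python run_pretraining.py \
+    --resume_train \
+    --load_pretrain_model /ckpt/epoch1_shard20.pt \
+    --load_optimizer_lr /ckpt/epoch1_shard20_op_lrs.pt
+    # ...plus the same data/model flags as run_pretrain.sh
+```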
diff --git a/examples/language/roberta/pretraining/arguments.py b/examples/language/roberta/pretraining/arguments.py
new file mode 100644
index 000000000..3a9370e00
--- /dev/null
+++ b/examples/language/roberta/pretraining/arguments.py
@@ -0,0 +1,152 @@
+import colossalai
+
+__all__ = ['parse_args']
+
+
+def parse_args():
+ parser = colossalai.get_default_parser()
+
+ parser.add_argument(
+ '--lr',
+ type=float,
+ required=True,
+ help='initial learning rate')
+ parser.add_argument(
+ '--epoch',
+ type=int,
+ required=True,
+ help='number of epoch')
+ parser.add_argument(
+ '--data_path_prefix',
+ type=str,
+ required=True,
+ help="location of the train data corpus")
+ parser.add_argument(
+ '--eval_data_path_prefix',
+ type=str,
+ required=True,
+ help='location of the evaluation data corpus')
+ parser.add_argument(
+ '--tokenizer_path',
+ type=str,
+ required=True,
+ help='location of the tokenizer')
+ parser.add_argument(
+ '--max_seq_length',
+ type=int,
+ default=512,
+ help='sequence length')
+ parser.add_argument(
+ '--refresh_bucket_size',
+ type=int,
+ default=1,
+ help=
+ "This param makes sure that a certain task is repeated for this time steps to \
+ optimise on the back propogation speed with APEX's DistributedDataParallel")
+ parser.add_argument(
+ "--max_predictions_per_seq",
+ "--max_pred",
+ default=80,
+ type=int,
+ help=
+ "The maximum number of masked tokens in a sequence to be predicted.")
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ default=1,
+ type=int,
+ help="accumulation_steps")
+ parser.add_argument(
+ "--train_micro_batch_size_per_gpu",
+ default=2,
+ type=int,
+ required=True,
+ help="train batch size")
+ parser.add_argument(
+ "--eval_micro_batch_size_per_gpu",
+ default=2,
+ type=int,
+ required=True,
+ help="eval batch size")
+ parser.add_argument(
+ "--num_workers",
+ default=8,
+ type=int,
+ help="")
+ parser.add_argument(
+ "--async_worker",
+ action='store_true',
+ help="")
+ parser.add_argument(
+ "--bert_config",
+ required=True,
+ type=str,
+ help="location of config.json")
+ parser.add_argument(
+ "--wandb",
+ action='store_true',
+ help="use wandb to watch model")
+ parser.add_argument(
+ "--wandb_project_name",
+ default='roberta',
+ help="wandb project name")
+ parser.add_argument(
+ "--log_interval",
+ default=100,
+ type=int,
+ help="report interval")
+ parser.add_argument(
+ "--log_path",
+ type=str,
+ required=True,
+ help="log file which records train step")
+ parser.add_argument(
+ "--tensorboard_path",
+ type=str,
+ required=True,
+ help="location of tensorboard file")
+ parser.add_argument(
+ "--colossal_config",
+ type=str,
+ required=True,
+ help="colossal config, which contains zero config and so on")
+ parser.add_argument(
+ "--ckpt_path",
+ type=str,
+ required=True,
+ help="location of saving checkpoint, which contains model and optimizer")
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=42,
+ help="random seed for initialization")
+ parser.add_argument(
+ '--vscode_debug',
+ action='store_true',
+ help="use vscode to debug")
+ parser.add_argument(
+ '--load_pretrain_model',
+ default='',
+ type=str,
+ help="location of model's checkpoin")
+ parser.add_argument(
+ '--load_optimizer_lr',
+ default='',
+ type=str,
+ help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
+ parser.add_argument(
+ '--resume_train',
+ action='store_true',
+ help="whether resume training from a early checkpoint")
+ parser.add_argument(
+ '--mlm',
+ default='bert',
+ type=str,
+ help="model type, bert or deberta")
+ parser.add_argument(
+ '--checkpoint_activations',
+ action='store_true',
+ help="whether to use gradient checkpointing")
+
+ args = parser.parse_args()
+ return args
diff --git a/examples/language/roberta/pretraining/bert_dataset_provider.py b/examples/language/roberta/pretraining/bert_dataset_provider.py
new file mode 100644
index 000000000..1d8cf2a91
--- /dev/null
+++ b/examples/language/roberta/pretraining/bert_dataset_provider.py
@@ -0,0 +1,15 @@
+class BertDatasetProviderInterface:
+ def get_shard(self, index, shuffle=True):
+ raise NotImplementedError
+
+ def release_shard(self, index):
+ raise NotImplementedError
+
+ def prefetch_shard(self, index):
+ raise NotImplementedError
+
+ def get_batch(self, batch_iter):
+ raise NotImplementedError
+
+ def prefetch_batch(self):
+ raise NotImplementedError
diff --git a/examples/language/roberta/pretraining/evaluation.py b/examples/language/roberta/pretraining/evaluation.py
new file mode 100644
index 000000000..83f94082f
--- /dev/null
+++ b/examples/language/roberta/pretraining/evaluation.py
@@ -0,0 +1,71 @@
+import os
+import math
+import torch
+from tqdm import tqdm
+from utils.global_vars import get_timers, get_tensorboard_writer
+from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
+
+def evaluate(engine, args, logger, global_step):
+ evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True)
+ start_shard = 0
+
+ engine.eval()
+ timers = get_timers()
+ eval_step = 0
+ eval_loss = 0
+ cur_loss = 0
+ world_size = torch.distributed.get_world_size()
+
+ with torch.no_grad():
+
+ for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))):
+
+ timers('eval_shard_time').start()
+
+ dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard)
+ # evaluate_dataset_provider.prefetch_shard(shard + 1)
+ if torch.distributed.get_rank() == 0:
+ iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1)
+ else:
+ iterator_data = enumerate(dataset_iterator)
+
+ for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1):
+
+ # batch_data = pretrain_dataset_provider.get_batch(batch_index)
+ eval_step += 1
+ input_ids = batch_data[0].cuda()
+ attention_mask = batch_data[1].cuda()
+ token_type_ids = batch_data[2].cuda()
+ mlm_label = batch_data[3].cuda()
+ # nsp_label = batch_data[5].cuda()
+
+ output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+ loss = engine.criterion(output.logits, mlm_label)#prediction_scores
+ evaluate_dataset_provider.prefetch_batch()
+
+ eval_loss += loss.float().item()
+
+ cur_loss = eval_loss / eval_step
+ elapsed_time = timers("eval_shard_time").elapsed()
+ elapsed_time_per_iteration = elapsed_time / eval_step
+ ppl = math.exp(cur_loss)
+
+ if args.wandb and torch.distributed.get_rank() == 0:
+ tensorboard_log = get_tensorboard_writer()
+ tensorboard_log.log_eval({
+ 'loss': cur_loss,
+ 'ppl': ppl,
+ 'mins_batch': elapsed_time_per_iteration
+ }, global_step)
+
+ eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
+ f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}'
+
+ logger.info(eval_log_str)
+ logger.info('-' * 100)
+ logger.info('')
+
+ evaluate_dataset_provider.release_shard()
+ engine.train()
+ return cur_loss
diff --git a/examples/language/roberta/pretraining/hostfile b/examples/language/roberta/pretraining/hostfile
new file mode 100644
index 000000000..f4e047f01
--- /dev/null
+++ b/examples/language/roberta/pretraining/hostfile
@@ -0,0 +1,10 @@
+GPU001
+GPU002
+GPU003
+GPU004
+GPU005
+GPU006
+GPU007
+GPU008
+GPU009
+GPU010
diff --git a/examples/language/roberta/pretraining/loss.py b/examples/language/roberta/pretraining/loss.py
new file mode 100644
index 000000000..dc4f872a7
--- /dev/null
+++ b/examples/language/roberta/pretraining/loss.py
@@ -0,0 +1,17 @@
+import torch
+
+__all__ = ['LossForPretraining']
+
+
+class LossForPretraining(torch.nn.Module):
+
+ def __init__(self, vocab_size):
+ super(LossForPretraining, self).__init__()
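+        # positions that were not masked are labelled -1 in the preprocessed data
+        # (masked_lm_positions), so ignore_index=-1 drops them from the loss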
+ self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
+ self.vocab_size = vocab_size
+
+ def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None):
+ masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
+ # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
+ total_loss = masked_lm_loss #+ next_sentence_loss
+ return total_loss
diff --git a/examples/language/roberta/pretraining/model/bert.py b/examples/language/roberta/pretraining/model/bert.py
new file mode 100644
index 000000000..67c85f760
--- /dev/null
+++ b/examples/language/roberta/pretraining/model/bert.py
@@ -0,0 +1,1893 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT model."""
+
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from packaging import version
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from transformers.utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "bert-base-uncased"
+_CONFIG_FOR_DOC = "BertConfig"
+_TOKENIZER_FOR_DOC = "BertTokenizer"
+
+# TokenClassification docstring
+_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
+_TOKEN_CLASS_EXPECTED_OUTPUT = (
+ "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
+)
+_TOKEN_CLASS_EXPECTED_LOSS = 0.01
+
+# QuestionAnswering docstring
+_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
+_QA_EXPECTED_OUTPUT = "'a nice puppet'"
+_QA_EXPECTED_LOSS = 7.41
+_QA_TARGET_START_INDEX = 14
+_QA_TARGET_END_INDEX = 15
+
+# SequenceClassification docstring
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.01
+
+
+BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "bert-base-uncased",
+ "bert-large-uncased",
+ "bert-base-cased",
+ "bert-large-cased",
+ "bert-base-multilingual-uncased",
+ "bert-base-multilingual-cased",
+ "bert-base-chinese",
+ "bert-base-german-cased",
+ "bert-large-uncased-whole-word-masking",
+ "bert-large-cased-whole-word-masking",
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
+ "bert-large-cased-whole-word-masking-finetuned-squad",
+ "bert-base-cased-finetuned-mrpc",
+ "bert-base-german-dbmdz-cased",
+ "bert-base-german-dbmdz-uncased",
+ "cl-tohoku/bert-base-japanese",
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
+ "cl-tohoku/bert-base-japanese-char",
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
+ "TurkuNLP/bert-base-finnish-cased-v1",
+ "TurkuNLP/bert-base-finnish-uncased-v1",
+ "wietsedv/bert-base-dutch-cased",
+ # See all BERT models at https://huggingface.co/models?filter=bert
+]
+
+
+def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
+ persistent=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values_length: int = 0,
+ ) -> torch.Tensor:
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+ # issue #5664
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = position_embedding_type or getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
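+            # distance is an (L, L) matrix of relative offsets in [-(L-1), L-1]; adding
+            # max_position_embeddings - 1 shifts it into [0, 2*max_position_embeddings - 2],
+            # i.e. valid rows of the distance_embedding table (assuming the sequence is no
+            # longer than max_position_embeddings).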
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__()
+ self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = BertAttention(config, position_embedding_type="absolute")
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ cross_attn_past_key_value,
+ output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
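+        # apply_chunking_to_forward splits attention_output along the sequence dimension
+        # (self.seq_len_dim) into chunks of self.chunk_size_feed_forward, runs feed_forward_chunk
+        # on each chunk, and concatenates the results; with chunk_size 0 it is a plain call to
+        # feed_forward_chunk. This trades a little compute for lower peak memory.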
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
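+                # Gradient checkpointing: activations inside the layer are not stored but are
+                # recomputed during the backward pass, trading compute for memory. The closure
+                # below captures past_key_value and output_attentions so that only tensor inputs
+                # are passed through torch.utils.checkpoint.checkpoint.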
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertOnlyNSPHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, pooled_output):
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return seq_relationship_score
+
+
+class BertPreTrainingHeads(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, sequence_output, pooled_output):
+ prediction_scores = self.predictions(sequence_output)
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return prediction_scores, seq_relationship_score
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ load_tf_weights = load_tf_weights_in_bert
+ base_model_prefix = "bert"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, BertEncoder):
+ module.gradient_checkpointing = value
+
+
+@dataclass
+class BertForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`BertForPreTraining`].
+
+ Args:
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
+            Total loss as the sum of the masked language modeling loss and the next sentence prediction
+ (classification) loss.
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+            Prediction scores of the next sentence prediction (classification) head (scores of True/False continuation
+ before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ prediction_logits: torch.FloatTensor = None
+ seq_relationship_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+BERT_START_DOCSTRING = r"""
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`BertConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+BERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
+ BERT_START_DOCSTRING,
+)
+class BertModel(BertPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+        Prunes heads of the model. `heads_to_prune` is a dict of {layer_num: list of heads to prune in this layer}.
+        See the base class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
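+        # get_extended_attention_mask (provided by the PreTrainedModel base class) turns the
+        # (batch, seq_len) mask of 1s/0s into an additive mask broadcastable over all heads:
+        # 0 where attention is allowed and a large negative value where it is masked, which is
+        # later added to the raw attention scores before the softmax.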
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+ sentence prediction (classification)` head.
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForPreTraining(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertPreTrainingHeads(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ next_sentence_label: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the next sentence prediction (classification) loss. Input should be a sequence
+            pair (see `input_ids` docstring). Indices should be in `[0, 1]`:
+
+ - 0 indicates sequence B is a continuation of sequence A,
+ - 1 indicates sequence B is a random sequence.
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
+ Used to hide legacy arguments that have been deprecated.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import BertTokenizer, BertForPreTraining
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ >>> model = BertForPreTraining.from_pretrained("bert-base-uncased")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.prediction_logits
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output, pooled_output = outputs[:2]
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+ total_loss = None
+ if labels is not None and next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ total_loss = masked_lm_loss + next_sentence_loss
+
+ if not return_dict:
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return BertForPreTrainingOutput(
+ loss=total_loss,
+ prediction_logits=prediction_scores,
+ seq_relationship_logits=seq_relationship_score,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
+)
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if not config.is_decoder:
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutputWithCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
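+            # (e.g. for tokens [t0, t1, t2, t3], the logits at position i are trained to predict
+            # t_{i+1}, so shifted_prediction_scores[:, i] is compared against labels[:, i + 1])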
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+        # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
+class BertForMaskedLM(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output="'paris'",
+ expected_loss=0.88,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
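+
+        Example (an illustrative sketch; the checkpoint name is only an example):
+
+        ```python
+        >>> from transformers import BertTokenizer
+        >>> import torch
+
+        >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        >>> model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+
+        >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> # pick the highest-scoring token at the [MASK] position
+        >>> mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+        >>> predicted_id = logits[0, mask_index].argmax(dim=-1)
+        >>> tokenizer.decode(predicted_id)
+        ```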
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ effective_batch_size = input_shape[0]
+
+ # add a dummy token
+ if self.config.pad_token_id is None:
+ raise ValueError("The PAD token should be defined for generation")
+
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+ dummy_token = torch.full(
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+ )
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@add_start_docstrings(
+ """Bert Model with a `next sentence prediction (classification)` head on top.""",
+ BERT_START_DOCSTRING,
+)
+class BertForNextSentencePrediction(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertOnlyNSPHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
+
+ - 0 indicates sequence B is a continuation of sequence A,
+ - 1 indicates sequence B is a random sequence.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
+
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
+ >>> logits = outputs.logits
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
+ ```
+ """
+
+ if "next_sentence_label" in kwargs:
+ warnings.warn(
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
+ FutureWarning,
+ )
+ labels = kwargs.pop("next_sentence_label")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ seq_relationship_scores = self.cls(pooled_output)
+
+ next_sentence_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
+
+ if not return_dict:
+ output = (seq_relationship_scores,) + outputs[2:]
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
+
+ return NextSentencePredictorOutput(
+ loss=next_sentence_loss,
+ logits=seq_relationship_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForSequenceClassification(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.bert = BertModel(config)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForMultipleChoice(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
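+        # Inputs of shape (batch_size, num_choices, seq_len) are flattened to
+        # (batch_size * num_choices, seq_len) so that every choice is encoded independently;
+        # the per-choice logits are reshaped back to (batch_size, num_choices) below before
+        # computing the cross-entropy over choices.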
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForTokenClassification(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
+ expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForQuestionAnswering(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_QA,
+ output_type=QuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ qa_target_start_index=_QA_TARGET_START_INDEX,
+ qa_target_end_index=_QA_TARGET_END_INDEX,
+ expected_output=_QA_EXPECTED_OUTPUT,
+ expected_loss=_QA_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
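+        # qa_outputs maps every token to two scores; splitting on the last dimension yields the
+        # span-start and span-end logits.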
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, gathering may have added an extra dimension to the positions; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+            # Sometimes the start/end positions are outside our model inputs; clamp them so the loss ignores them
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/examples/language/roberta/pretraining/model/deberta_v2.py b/examples/language/roberta/pretraining/model/deberta_v2.py
new file mode 100644
index 000000000..c6ce82847
--- /dev/null
+++ b/examples/language/roberta/pretraining/model/deberta_v2.py
@@ -0,0 +1,1631 @@
+# coding=utf-8
+# Copyright 2020 Microsoft and the Hugging Face Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch DeBERTa-v2 model."""
+
+import math
+from collections.abc import Sequence
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import softmax_backward_data
+from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config
+from transformers import T5Tokenizer, T5ForConditionalGeneration, FillMaskPipeline
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DebertaV2Config"
+_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer"
+_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
+
+DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/deberta-v2-xlarge",
+ "microsoft/deberta-v2-xxlarge",
+ "microsoft/deberta-v2-xlarge-mnli",
+ "microsoft/deberta-v2-xxlarge-mnli",
+]
+
+
+# Copied from transformers.models.deberta.modeling_deberta.ContextPooler
+class ContextPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
+ self.dropout = StableDropout(config.pooler_dropout)
+ self.config = config
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+
+ context_token = hidden_states[:, 0]
+ context_token = self.dropout(context_token)
+ pooled_output = self.dense(context_token)
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
+ return pooled_output
+
+ @property
+ def output_dim(self):
+ return self.config.hidden_size
+
+
+# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
+class XSoftmax(torch.autograd.Function):
+ """
+ Masked Softmax which is optimized for saving memory
+
+ Args:
+ input (`torch.tensor`): The input tensor that will apply softmax.
+ mask (`torch.IntTensor`):
+            The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
+ dim (int): The dimension that will apply softmax
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
+
+ >>> # Make a tensor
+ >>> x = torch.randn([4, 20, 100])
+
+ >>> # Create a mask
+ >>> mask = (x > 0).int()
+
+ >>> # Specify the dimension to apply softmax
+ >>> dim = -1
+
+ >>> y = XSoftmax.apply(x, mask, dim)
+ ```"""
+
+ @staticmethod
+ def forward(self, input, mask, dim):
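+        # rmask flags the positions excluded by `mask`; they are pushed to the dtype minimum before the
+        # softmax and zeroed afterwards, so they receive no probability mass.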
+ self.dim = dim
+ rmask = ~(mask.to(torch.bool))
+
+ output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
+ output = torch.softmax(output, self.dim)
+ output.masked_fill_(rmask, 0)
+ self.save_for_backward(output)
+ return output
+
+ @staticmethod
+ def backward(self, grad_output):
+ (output,) = self.saved_tensors
+ inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
+ return inputGrad, None, None
+
+ @staticmethod
+ def symbolic(g, self, mask, dim):
+ import torch.onnx.symbolic_helper as sym_help
+ from torch.onnx.symbolic_opset9 import masked_fill, softmax
+
+ mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
+ r_mask = g.op(
+ "Cast",
+ g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
+ to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+ )
+ output = masked_fill(
+ g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
+ )
+ output = softmax(g, output, dim)
+ return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
+class DropoutContext(object):
+ def __init__(self):
+ self.dropout = 0
+ self.mask = None
+ self.scale = 1
+ self.reuse_mask = True
+
+
+# Copied from transformers.models.deberta.modeling_deberta.get_mask
+def get_mask(input, local_context):
+ if not isinstance(local_context, DropoutContext):
+ dropout = local_context
+ mask = None
+ else:
+ dropout = local_context.dropout
+ dropout *= local_context.scale
+ mask = local_context.mask if local_context.reuse_mask else None
+
+ if dropout > 0 and mask is None:
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
+
+ if isinstance(local_context, DropoutContext):
+ if local_context.mask is None:
+ local_context.mask = mask
+
+ return mask, dropout
+
+
+# Copied from transformers.models.deberta.modeling_deberta.XDropout
+class XDropout(torch.autograd.Function):
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
+
+ @staticmethod
+ def forward(ctx, input, local_ctx):
+ mask, dropout = get_mask(input, local_ctx)
+ ctx.scale = 1.0 / (1 - dropout)
+ if dropout > 0:
+ ctx.save_for_backward(mask)
+ return input.masked_fill(mask, 0) * ctx.scale
+ else:
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.scale > 1:
+ (mask,) = ctx.saved_tensors
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
+ else:
+ return grad_output, None
+
+
+# Copied from transformers.models.deberta.modeling_deberta.StableDropout
+class StableDropout(nn.Module):
+ """
+ Optimized dropout module for stabilizing the training
+
+ Args:
+        drop_prob (float): the dropout probability
+ """
+
+ def __init__(self, drop_prob):
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.count = 0
+ self.context_stack = None
+
+ def forward(self, x):
+ """
+ Call the module
+
+ Args:
+ x (`torch.tensor`): The input tensor to apply dropout
+ """
+ if self.training and self.drop_prob > 0:
+ return XDropout.apply(x, self.get_context())
+ return x
+
+ def clear_context(self):
+ self.count = 0
+ self.context_stack = None
+
+ def init_context(self, reuse_mask=True, scale=1):
+ if self.context_stack is None:
+ self.context_stack = []
+ self.count = 0
+ for c in self.context_stack:
+ c.reuse_mask = reuse_mask
+ c.scale = scale
+
+ def get_context(self):
+ if self.context_stack is not None:
+ if self.count >= len(self.context_stack):
+ self.context_stack.append(DropoutContext())
+ ctx = self.context_stack[self.count]
+ ctx.dropout = self.drop_prob
+ self.count += 1
+ return ctx
+ else:
+ return self.drop_prob
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
+class DebertaV2SelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
+class DebertaV2Attention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = DisentangledSelfAttention(config)
+ self.output = DebertaV2SelfOutput(config)
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ output_attentions=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ ):
+ self_output = self.self(
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ )
+ if output_attentions:
+ self_output, att_matrix = self_output
+ if query_states is None:
+ query_states = hidden_states
+ attention_output = self.output(self_output, query_states)
+
+ if output_attentions:
+ return (attention_output, att_matrix)
+ else:
+ return attention_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
+class DebertaV2Intermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
+class DebertaV2Output(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
+class DebertaV2Layer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = DebertaV2Attention(config)
+ self.intermediate = DebertaV2Intermediate(config)
+ self.output = DebertaV2Output(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ output_attentions=False,
+ ):
+ attention_output = self.attention(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ )
+ if output_attentions:
+ attention_output, att_matrix = attention_output
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ if output_attentions:
+ return (layer_output, att_matrix)
+ else:
+ return layer_output
+
+
+class ConvLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ kernel_size = getattr(config, "conv_kernel_size", 3)
+ groups = getattr(config, "conv_groups", 1)
+ self.conv_act = getattr(config, "conv_act", "tanh")
+ self.conv = nn.Conv1d(
+ config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
+ )
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def forward(self, hidden_states, residual_states, input_mask):
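+        # Run a 1D convolution over the token dimension (Conv1d expects channels first, hence the permutes),
+        # zero out padded positions, then add the activated result to the residual states and LayerNorm the sum.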
+ out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
+ rmask = (1 - input_mask).bool()
+ out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
+ out = ACT2FN[self.conv_act](self.dropout(out))
+
+ layer_norm_input = residual_states + out
+ output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
+
+ if input_mask is None:
+ output_states = output
+ else:
+ if input_mask.dim() != layer_norm_input.dim():
+ if input_mask.dim() == 4:
+ input_mask = input_mask.squeeze(1).squeeze(1)
+ input_mask = input_mask.unsqueeze(2)
+
+ input_mask = input_mask.to(output.dtype)
+ output_states = output * input_mask
+
+ return output_states
+
+
+class DebertaV2Encoder(nn.Module):
+ """Modified BertEncoder with relative position bias support"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
+ self.relative_attention = getattr(config, "relative_attention", False)
+
+ if self.relative_attention:
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+ if self.max_relative_positions < 1:
+ self.max_relative_positions = config.max_position_embeddings
+
+ self.position_buckets = getattr(config, "position_buckets", -1)
+ pos_ebd_size = self.max_relative_positions * 2
+
+ if self.position_buckets > 0:
+ pos_ebd_size = self.position_buckets * 2
+
+ # rel = nn.Parameter(torch.empty((pos_ebd_size, config.hidden_size)))
+ # self.rel_embeddings = nn.init.normal_(rel, mean=0.0, std=config.initializer_range)
+ self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
+
+ self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
+
+ if "layer_norm" in self.norm_rel_ebd:
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
+
+ self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
+ self.gradient_checkpointing = False
+
+ def get_rel_embedding(self):
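+        # Look up all 2 * position_buckets relative-position embeddings from the nn.Embedding table
+        # (this copy indexes the table instead of reading .weight directly, so it assumes position_buckets > 0)
+        # and optionally LayerNorm them before they are shared across layers.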
+ att_span = self.position_buckets
+ rel_index = torch.arange(0, att_span * 2).long().to(self.rel_embeddings.weight.device)
+ rel_embeddings = self.rel_embeddings(rel_index)
+ # rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
+ # rel_embeddings = self.rel_embeddings if self.relative_attention else None
+ if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
+ rel_embeddings = self.LayerNorm(rel_embeddings)
+ return rel_embeddings
+
+ def get_attention_mask(self, attention_mask):
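+        # Expand a [batch, seq_len] padding mask into a pairwise [batch, 1, seq_len, seq_len] byte mask where
+        # entry (i, j) is 1 only if both token i and token j are real (non-padding) tokens.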
+ if attention_mask.dim() <= 2:
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
+ attention_mask = attention_mask.byte()
+ elif attention_mask.dim() == 3:
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return attention_mask
+
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
+ if self.relative_attention and relative_pos is None:
+ q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
+ relative_pos = build_relative_position(
+ q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+ )
+ return relative_pos
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ output_hidden_states=True,
+ output_attentions=False,
+ query_states=None,
+ relative_pos=None,
+ return_dict=True,
+ ):
+ if attention_mask.dim() <= 2:
+ input_mask = attention_mask
+ else:
+ input_mask = (attention_mask.sum(-2) > 0).byte()
+ attention_mask = self.get_attention_mask(attention_mask)
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if isinstance(hidden_states, Sequence):
+ next_kv = hidden_states[0]
+ else:
+ next_kv = hidden_states
+ rel_embeddings = self.get_rel_embedding()
+ output_states = next_kv
+ for i, layer_module in enumerate(self.layer):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (output_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ output_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ next_kv,
+ attention_mask,
+ query_states,
+ relative_pos,
+ rel_embeddings,
+ )
+ else:
+ output_states = layer_module(
+ next_kv,
+ attention_mask,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ if output_attentions:
+ output_states, att_m = output_states
+
+ if i == 0 and self.conv is not None:
+ output_states = self.conv(hidden_states, output_states, input_mask)
+
+ if query_states is not None:
+ query_states = output_states
+ if isinstance(hidden_states, Sequence):
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
+ else:
+ next_kv = output_states
+
+ if output_attentions:
+ all_attentions = all_attentions + (att_m,)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (output_states,)
+
+ if not return_dict:
+ return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+
+def make_log_bucket_position(relative_pos, bucket_size, max_position):
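+    # Offsets with |relative_pos| <= bucket_size // 2 keep their exact value; larger offsets are compressed
+    # logarithmically into the remaining buckets, up to max_position.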
+ sign = np.sign(relative_pos)
+ mid = bucket_size // 2
+ abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
+ log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
+    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(int)
+ return bucket_pos
+
+
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
+ """
+ Build relative position according to the query and key
+
+    We assume the absolute position of the query \\(P_q\\) ranges over (0, query_size) and the absolute position of
+    the key \\(P_k\\) ranges over (0, key_size). The relative position from query to key is \\(R_{q \\rightarrow k} =
+    P_q - P_k\\).
+
+ Args:
+ query_size (int): the length of query
+ key_size (int): the length of key
+ bucket_size (int): the size of position bucket
+ max_position (int): the maximum allowed absolute position
+
+ Return:
+ `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
+
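+    Example (illustrative; with the default `bucket_size=-1` and `max_position=-1`, no log-bucketing is applied):
+
+    ```python
+    >>> build_relative_position(3, 4)
+    tensor([[[ 0, -1, -2, -3],
+             [ 1,  0, -1, -2],
+             [ 2,  1,  0, -1]]])
+    ```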
+ """
+ q_ids = np.arange(0, query_size)
+ k_ids = np.arange(0, key_size)
+ rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+ if bucket_size > 0 and max_position > 0:
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
+ rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+ rel_pos_ids = rel_pos_ids[:query_size, :]
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
+ return rel_pos_ids
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
+def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
+def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
+def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+
+
+class DisentangledSelfAttention(nn.Module):
+ """
+ Disentangled self-attention module
+
+ Parameters:
+ config (`DebertaV2Config`):
+ A model config class instance with the configuration to build a new model. The schema is similar to
+            *BertConfig*; for more details, please refer to [`DebertaV2Config`].
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+ self.num_attention_heads = config.num_attention_heads
+ _attention_head_size = config.hidden_size // config.num_attention_heads
+ self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+
+ self.share_att_key = getattr(config, "share_att_key", False)
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
+ self.relative_attention = getattr(config, "relative_attention", False)
+
+ if self.relative_attention:
+ self.position_buckets = getattr(config, "position_buckets", -1)
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+ if self.max_relative_positions < 1:
+ self.max_relative_positions = config.max_position_embeddings
+ self.pos_ebd_size = self.max_relative_positions
+ if self.position_buckets > 0:
+ self.pos_ebd_size = self.position_buckets
+
+ self.pos_dropout = StableDropout(config.hidden_dropout_prob)
+
+ if not self.share_att_key:
+ if "c2p" in self.pos_att_type:
+ self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ if "p2c" in self.pos_att_type:
+ self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = StableDropout(config.attention_probs_dropout_prob)
+ # self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
+
+ def transpose_for_scores(self, x, attention_heads):
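+        # Reshape (batch, seq_len, all_head_size) -> (batch * attention_heads, seq_len, head_size) so that
+        # the per-head attention products below can use torch.bmm.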
+ new_x_shape = x.size()[:-1] + (attention_heads, -1)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ output_attentions=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ ):
+ """
+ Call the module
+
+ Args:
+ hidden_states (`torch.FloatTensor`):
+                Input states to the module, usually the output from the previous layer; they serve as the Q, K and V in
+                *Attention(Q, K, V)*.
+
+ attention_mask (`torch.ByteTensor`):
+ An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
+ sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
+ th token.
+
+ output_attentions (`bool`, optional):
+                Whether to return the attention matrix.
+
+ query_states (`torch.FloatTensor`, optional):
+ The *Q* state in *Attention(Q,K,V)*.
+
+ relative_pos (`torch.LongTensor`):
+ The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
+ values ranging in [*-max_relative_positions*, *max_relative_positions*].
+
+ rel_embeddings (`torch.FloatTensor`):
+ The embedding of relative distances. It's a tensor of shape [\\(2 \\times
+ \\text{max_relative_positions}\\), *hidden_size*].
+
+
+ """
+ if query_states is None:
+ query_states = hidden_states
+ query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
+ key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
+ value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
+
+ rel_att = None
+ # Take the dot product between "query" and "key" to get the raw attention scores.
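+        # scale_factor counts the score components that are summed (content-to-content plus any enabled
+        # c2p/p2c terms), so the combined attention score is divided by sqrt(head_size * scale_factor).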
+ scale_factor = 1
+ if "c2p" in self.pos_att_type:
+ scale_factor += 1
+ if "p2c" in self.pos_att_type:
+ scale_factor += 1
+ scale = math.sqrt(query_layer.size(-1) * scale_factor)
+ attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+ if self.relative_attention:
+ rel_embeddings = self.pos_dropout(rel_embeddings)
+ rel_att = self.disentangled_attention_bias(
+ query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
+ )
+
+ if rel_att is not None:
+ attention_scores = attention_scores + rel_att
+ attention_scores = attention_scores.view(
+ -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
+ )
+
+ # bsz x height x length x dimension
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
+ attention_probs = self.dropout(attention_probs)
+ context_layer = torch.bmm(
+ attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
+ )
+ context_layer = (
+ context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
+ .permute(0, 2, 1, 3)
+ .contiguous()
+ )
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+ context_layer = context_layer.view(new_context_layer_shape)
+ if output_attentions:
+ return (context_layer, attention_probs)
+ else:
+ return context_layer
+
+ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
+ if relative_pos is None:
+ q = query_layer.size(-2)
+ relative_pos = build_relative_position(
+ q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+ )
+ if relative_pos.dim() == 2:
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
+ elif relative_pos.dim() == 3:
+ relative_pos = relative_pos.unsqueeze(1)
+ # bsz x height x query x key
+ elif relative_pos.dim() != 4:
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
+
+ att_span = self.pos_ebd_size
+ relative_pos = relative_pos.long().to(query_layer.device)
+
+ # rel_index = torch.arange(0, att_span * 2).long().to(query_layer.device)
+ # rel_embeddings = rel_embeddings(rel_index).unsqueeze(0)
+ rel_embeddings = rel_embeddings.unsqueeze(0)
+ # rel_embeddings = rel_embeddings.unsqueeze(0)
+ # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
+ if self.share_att_key:
+ pos_query_layer = self.transpose_for_scores(
+ self.query_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
+ pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ )
+ else:
+ if "c2p" in self.pos_att_type:
+ pos_key_layer = self.transpose_for_scores(
+ self.pos_key_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ ) # .split(self.all_head_size, dim=-1)
+ if "p2c" in self.pos_att_type:
+ pos_query_layer = self.transpose_for_scores(
+ self.pos_query_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ ) # .split(self.all_head_size, dim=-1)
+
+ score = 0
+ # content->position
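+        # c2p: each query token gathers a bias from the key's relative-position embedding; relative positions
+        # are shifted by att_span so the gather indices fall in [0, 2 * att_span).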
+ if "c2p" in self.pos_att_type:
+ scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
+ c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+ c2p_att = torch.gather(
+ c2p_att,
+ dim=-1,
+ index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
+ )
+ score += c2p_att / scale
+
+ # position->content
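+        # p2c: the symmetric term computed from the key side; the relative position sign is flipped (-r_pos)
+        # before shifting/clamping, and the gathered scores are transposed back to (query, key) order.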
+ if "p2c" in self.pos_att_type:
+ scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
+ if key_layer.size(-2) != query_layer.size(-2):
+ r_pos = build_relative_position(
+ key_layer.size(-2),
+ key_layer.size(-2),
+ bucket_size=self.position_buckets,
+ max_position=self.max_relative_positions,
+ ).to(query_layer.device)
+ r_pos = r_pos.unsqueeze(0)
+ else:
+ r_pos = relative_pos
+
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
+ p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
+ p2c_att = torch.gather(
+ p2c_att,
+ dim=-1,
+ index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
+ ).transpose(-1, -2)
+ score += p2c_att / scale
+
+ return score
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm
+class DebertaV2Embeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ pad_token_id = getattr(config, "pad_token_id", 0)
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
+
+ self.position_biased_input = getattr(config, "position_biased_input", True)
+ if not self.position_biased_input:
+ self.position_embeddings = None
+ else:
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
+
+ if config.type_vocab_size > 0:
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
+
+ if self.embedding_size != config.hidden_size:
+ self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if self.position_embeddings is not None:
+ position_embeddings = self.position_embeddings(position_ids.long())
+ else:
+ position_embeddings = torch.zeros_like(inputs_embeds)
+
+ embeddings = inputs_embeds
+ if self.position_biased_input:
+ embeddings += position_embeddings
+ if self.config.type_vocab_size > 0:
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+ embeddings += token_type_embeddings
+
+ if self.embedding_size != self.config.hidden_size:
+ embeddings = self.embed_proj(embeddings)
+
+ embeddings = self.LayerNorm(embeddings)
+
+ if mask is not None:
+ if mask.dim() != embeddings.dim():
+ if mask.dim() == 4:
+ mask = mask.squeeze(1).squeeze(1)
+ mask = mask.unsqueeze(2)
+ mask = mask.to(embeddings.dtype)
+
+ embeddings = embeddings * mask
+
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2
+class DebertaV2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DebertaV2Config
+ base_model_prefix = "deberta"
+ _keys_to_ignore_on_load_missing = ["position_ids"]
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, DebertaV2Encoder):
+ module.gradient_checkpointing = value
+
+
+DEBERTA_START_DOCSTRING = r"""
+    The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
+    Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built
+    on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and an enhanced mask decoder. With those
+    two improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+
+ Parameters:
+ config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEBERTA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
+class DebertaV2Model(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embeddings = DebertaV2Embeddings(config)
+ self.encoder = DebertaV2Encoder(config)
+ self.z_steps = 0
+ self.config = config
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embeddings.word_embeddings = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask,
+ output_hidden_states=True,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+ encoded_layers = encoder_outputs[1]
+
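+        # self.z_steps is 0 by default, so this extra refinement loop (re-running the last encoder layer with
+        # updated query states) is normally skipped.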
+ if self.z_steps > 1:
+ hidden_states = encoded_layers[-2]
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
+ query_states = encoded_layers[-1]
+ rel_embeddings = self.encoder.get_rel_embedding()
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
+ for layer in layers[1:]:
+ query_states = layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=False,
+ query_states=query_states,
+ relative_pos=rel_pos,
+ rel_embeddings=rel_embeddings,
+ )
+ encoded_layers.append(query_states)
+
+ sequence_output = encoded_layers[-1]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
+class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.deberta = DebertaV2Model(config)
+ self.cls = DebertaV2OnlyMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
+class DebertaV2PredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
+class DebertaV2LMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = DebertaV2PredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
+class DebertaV2OnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = DebertaV2LMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2
+class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ num_labels = getattr(config, "num_labels", 2)
+ self.num_labels = num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+
+ self.classifier = nn.Linear(output_dim, num_labels)
+ drop_out = getattr(config, "cls_dropout", None)
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+ self.dropout = StableDropout(drop_out)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.deberta.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ self.deberta.set_input_embeddings(new_embeddings)
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ encoder_layer = outputs[0]
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ # regression task
+ loss_fn = nn.MSELoss()
+ logits = logits.view(-1).to(labels.dtype)
+ loss = loss_fn(logits, labels.view(-1))
+ elif labels.dim() == 1 or labels.size(-1) == 1:
+ label_index = (labels >= 0).nonzero()
+ labels = labels.long()
+ if label_index.size(0) > 0:
+ labeled_logits = torch.gather(
+ logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+ )
+ labels = torch.gather(labels, 0, label_index.view(-1))
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+ else:
+ loss = torch.tensor(0).to(logits)
+ else:
+ log_softmax = nn.LogSoftmax(-1)
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+ elif self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
+class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2
+class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=QuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split adds a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ num_labels = getattr(config, "num_labels", 2)
+ self.num_labels = num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+
+ self.classifier = nn.Linear(output_dim, 1)
+ drop_out = getattr(config, "cls_dropout", None)
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+ self.dropout = StableDropout(drop_out)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.deberta.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ self.deberta.set_input_embeddings(new_embeddings)
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.deberta(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ encoder_layer = outputs[0]
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py
new file mode 100644
index 000000000..cce836913
--- /dev/null
+++ b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py
@@ -0,0 +1,182 @@
+import os
+import random
+import h5py
+import logging
+import json
+import time
+from concurrent.futures import ProcessPoolExecutor
+
+import numpy as np
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.sampler import RandomSampler
+from torch.utils.data.distributed import DistributedSampler
+
+from bert_dataset_provider import BertDatasetProviderInterface
+import colossalai.utils as utils
+
+# Workaround because python functions are not picklable
+class WorkerInitObj(object):
+ def __init__(self, seed):
+ self.seed = seed
+
+ def __call__(self, id):
+ np.random.seed(seed=self.seed + id)
+ random.seed(self.seed + id)
+
+
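+# Build a DataLoader over one pre-tokenized h5 shard; the sampler passed in
+# (DistributedSampler in this provider) splits the shard across data-parallel ranks.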
+def create_pretraining_dataset(input_file, max_predictions_per_seq,
+ num_workers, train_batch_size, worker_init,
+ data_sampler):
+ train_data = pretraining_dataset(
+ input_file=input_file, max_predictions_per_seq=max_predictions_per_seq)
+ train_dataloader = DataLoader(train_data,
+ sampler=data_sampler(train_data),
+ batch_size=train_batch_size,
+ num_workers=num_workers,
+ worker_init_fn=worker_init,
+ pin_memory=True
+ )
+ return train_dataloader, len(train_data)
+
+
+class pretraining_dataset(Dataset):
+ def __init__(self, input_file, max_predictions_per_seq):
+ self.input_file = input_file
+ self.max_predictions_per_seq = max_predictions_per_seq
+ f = h5py.File(input_file, "r")
+ keys = [
+ 'input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions'
+ ]
+ self.inputs = [np.asarray(f[key][:]) for key in keys]
+ f.close()
+
+ def __len__(self):
+ 'Denotes the total number of samples'
+ return len(self.inputs[0])
+
+ def __getitem__(self, index):
+
+ [
+ input_ids, input_mask, segment_ids, masked_lm_labels
+ ] = [
+ torch.from_numpy(input[index].astype(np.int64))
+ for input in self.inputs
+ ]
+
+ return [
+ input_ids, input_mask,
+ segment_ids, masked_lm_labels
+ ]
+
+
+class NvidiaBertDatasetProvider(BertDatasetProviderInterface):
+ def __init__(self, args, evaluate=False):
+ self.num_workers = args.num_workers
+ self.max_seq_length = args.max_seq_length
+ self.max_predictions_per_seq = args.max_predictions_per_seq
+
+ self.gradient_accumulation_steps = args.gradient_accumulation_steps
+ if not evaluate:
+ self.train_micro_batch_size_per_gpu = args.train_micro_batch_size_per_gpu
+ else:
+ self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu
+ self.logger = args.logger
+
+ self.global_rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+
+ # Initialize dataset files
+ if not evaluate:
+ self.dataset_files = [
+ os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) if
+ os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f
+ ]
+ else:
+ self.dataset_files = [
+ os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) if
+ os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f
+ ]
+
+ self.dataset_files.sort()
+ # random.shuffle(self.dataset_files)
+ self.num_files = len(self.dataset_files)
+ # self.data_sampler = RandomSampler
+ self.data_sampler = DistributedSampler
+
+ self.worker_init = WorkerInitObj(args.seed + args.local_rank)
+ self.dataset_future = None
+ self.pool = ProcessPoolExecutor(1)
+ self.data_file = None
+ self.shuffle = True
+
+ if self.global_rank == 0:
+ self.logger.info(
+ f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}"
+ )
+
+ def get_shard(self, index):
+ start = time.time()
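+ # Consume a previously prefetched dataloader if one exists; otherwise build it synchronously.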
+ if self.dataset_future is None:
+ self.data_file = self._get_shard_file(index)
+ self.train_dataloader, sample_count = create_pretraining_dataset(
+ input_file=self.data_file,
+ max_predictions_per_seq=self.max_predictions_per_seq,
+ num_workers=self.num_workers,
+ train_batch_size=self.train_micro_batch_size_per_gpu,
+ worker_init=self.worker_init,
+ data_sampler=self.data_sampler)
+ else:
+ self.train_dataloader, sample_count = self.dataset_future.result(
+ timeout=None)
+
+ self.logger.info(
+ f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s."
+ )
+
+ return self.train_dataloader, sample_count
+
+ def release_shard(self):
+ del self.train_dataloader
+ self.pool.shutdown()
+
+ def prefetch_shard(self, index):
+ self.data_file = self._get_shard_file(index)
+ self.dataset_future = self.pool.submit(
+ create_pretraining_dataset, self.data_file,
+ self.max_predictions_per_seq, self.num_workers,
+ self.train_micro_batch_size_per_gpu, self.worker_init,
+ self.data_sampler)
+
+ def get_batch(self, batch_iter):
+ return batch_iter
+
+ def prefetch_batch(self):
+ pass
+
+ def _get_shard_file(self, shard_index):
+ file_index = self._get_shard_file_index(shard_index, self.global_rank)
+ return self.dataset_files[file_index]
+
+ def _get_shard_file_index(self, shard_index, global_rank):
+ # if dist.is_initialized() and self.world_size > self.num_files:
+ # remainder = self.world_size % self.num_files
+ # file_index = (shard_index * self.world_size) + global_rank + (
+ # remainder * shard_index)
+ # else:
+ # file_index = shard_index * self.world_size + global_rank
+
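+ # Simple policy: every rank reads the same file for a given shard index, cycling through the
+ # sorted file list; the commented-out variant above instead spreads files across ranks.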
+ return shard_index % self.num_files
+
+ def shuffle_dataset(self, epoch):
+ if self.shuffle:
+ # deterministically shuffle based on epoch and seed
+ g = torch.Generator()
+ g.manual_seed(epoch)
+ indices = torch.randperm(self.num_files, generator=g).tolist()
+ new_dataset = [self.dataset_files[i] for i in indices]
+ self.dataset_files = new_dataset
+
\ No newline at end of file
diff --git a/examples/language/roberta/pretraining/pretrain_utils.py b/examples/language/roberta/pretraining/pretrain_utils.py
new file mode 100644
index 000000000..ba17b0f5e
--- /dev/null
+++ b/examples/language/roberta/pretraining/pretrain_utils.py
@@ -0,0 +1,112 @@
+import transformers
+import logging
+from colossalai.nn.lr_scheduler import LinearWarmupLR
+from transformers import get_linear_schedule_with_warmup
+from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig
+from transformers import GPT2Config, GPT2LMHeadModel
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+from colossalai.nn.optimizer import FusedAdam
+from torch.optim import AdamW
+from colossalai.core import global_context as gpc
+import torch
+import os
+import sys
+sys.path.append(os.getcwd())
+from model.deberta_v2 import DebertaV2ForMaskedLM
+from model.bert import BertForMaskedLM
+import torch.nn as nn
+
+from collections import OrderedDict
+
+__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']
+
+
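+# Strip a fixed-length prefix (start_index characters, presumably a wrapper module name) from
+# checkpoint keys so the state dict matches the bare model; the default of 13 is specific to
+# the checkpoint layout this example was written for.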
+def get_new_state_dict(state_dict, start_index=13):
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = k[start_index:]
+ new_state_dict[name] = v
+ return new_state_dict
+
+
+class LMModel(nn.Module):
+ def __init__(self, model, config, args):
+ super().__init__()
+
+ self.checkpoint = args.checkpoint_activations
+ self.config = config
+ self.model = model
+ if self.checkpoint:
+ self.model.gradient_checkpointing_enable()
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None):
+ # Only return lm_logits
+ return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+def get_model(args, logger):
+
+ if args.mlm == 'bert':
+ config = transformers.BertConfig.from_json_file(args.bert_config)
+ model = BertForMaskedLM(config)
+ elif args.mlm == 'deberta_v2':
+ config = transformers.DebertaV2Config.from_json_file(args.bert_config)
+ model = DebertaV2ForMaskedLM(config)
+ else:
+ raise Exception("Invalid mlm!")
+
+ if len(args.load_pretrain_model) > 0:
+ assert os.path.exists(args.load_pretrain_model)
+ # load_checkpoint(args.load_pretrain_model, model, strict=False)
+ m_state_dict = torch.load(args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
+ # new_state_dict = get_new_state_dict(m_state_dict)
+ model.load_state_dict(m_state_dict, strict=True)  # must ensure that every process has identical parameters
+ logger.info("load model success")
+
+ numel = sum([p.numel() for p in model.parameters()])
+ if args.checkpoint_activations:
+ model.gradient_checkpointing_enable()
+ # model = LMModel(model, config, args)
+
+ return config, model, numel
+
+
+def get_optimizer(model, lr):
+ param_optimizer = list(model.named_parameters())
+ no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
+
+ # configure the weight decay for bert models
+ optimizer_grouped_parameters = [{
+ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
+ 'weight_decay': 0.1
+ }, {
+ 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
+ 'weight_decay': 0.0
+ }]
+ optimizer = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95])
+ return optimizer
+
+
+def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
+ # warmup_steps = int(total_steps * warmup_ratio)
+ lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch)
+ # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
+ return lr_scheduler
+
+
+def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
+ model_path = path + '_pytorch_model.bin'
+ optimizer_lr_path = path + '.op_lrs'
+ checkpoint = {}
+ checkpoint['optimizer'] = optimizer.state_dict()
+ checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
+ checkpoint['epoch'] = epoch
+ checkpoint['shard'] = shard
+ checkpoint['global_step'] = global_step
+ model_state = model.state_dict()  # every process must call model.state_dict()
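+ # Only rank 0 writes to disk; the other ranks still build the state dict above so any
+ # collectives inside state_dict() stay in sync.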
+ if gpc.get_global_rank() == 0:
+ torch.save(checkpoint, optimizer_lr_path)
+ torch.save(model_state, model_path)
+
+
+
diff --git a/examples/language/roberta/pretraining/run_pretrain.sh b/examples/language/roberta/pretraining/run_pretrain.sh
new file mode 100644
index 000000000..144cd0ab9
--- /dev/null
+++ b/examples/language/roberta/pretraining/run_pretrain.sh
@@ -0,0 +1,40 @@
+#!/usr/bin/env sh
+
+root_path=$PWD
+PY_FILE_PATH="$root_path/run_pretraining.py"
+
+tensorboard_path="$root_path/tensorboard"
+log_path="$root_path/exp_log"
+ckpt_path="$root_path/ckpt"
+
+colossal_config="$root_path/../configs/colossalai_ddp.py"
+
+mkdir -p $tensorboard_path
+mkdir -p $log_path
+mkdir -p $ckpt_path
+
+export PYTHONPATH=$PWD
+
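+# Launch 8 processes per node on the listed hosts; adjust --include, --master_addr and the
+# data/tokenizer/config paths below to match your own cluster before running.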
+env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
+ --include GPU002,GPU003,GPU004,GPU007 \
+ --nproc_per_node=8 \
+ $PY_FILE_PATH \
+ --master_addr GPU007 \
+ --master_port 20024 \
+ --lr 2.0e-4 \
+ --train_micro_batch_size_per_gpu 190 \
+ --eval_micro_batch_size_per_gpu 20 \
+ --epoch 15 \
+ --data_path_prefix /h5 \
+ --eval_data_path_prefix /eval_h5 \
+ --tokenizer_path /roberta \
+ --bert_config /roberta/config.json \
+ --tensorboard_path $tensorboard_path \
+ --log_path $log_path \
+ --ckpt_path $ckpt_path \
+ --colossal_config $colossal_config \
+ --log_interval 50 \
+ --mlm bert \
+ --wandb \
+ --checkpoint_activations
+
\ No newline at end of file
diff --git a/examples/language/roberta/pretraining/run_pretrain_resume.sh b/examples/language/roberta/pretraining/run_pretrain_resume.sh
new file mode 100644
index 000000000..a0704cf7c
--- /dev/null
+++ b/examples/language/roberta/pretraining/run_pretrain_resume.sh
@@ -0,0 +1,43 @@
+#!/usr/bin/env sh
+
+root_path=$PWD
+PY_FILE_PATH="$root_path/run_pretraining.py"
+
+tensorboard_path="$root_path/tensorboard"
+log_path="$root_path/exp_log"
+ckpt_path="$root_path/ckpt"
+
+colossal_config="$root_path/../configs/colossalai_ddp.py"
+
+mkdir -p $tensorboard_path
+mkdir -p $log_path
+mkdir -p $ckpt_path
+
+export PYTHONPATH=$PWD
+
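+# Same launch as run_pretrain.sh, with --resume_train pointing at a previously saved model
+# checkpoint and its matching optimizer/lr-scheduler state (.op_lrs).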
+env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \
+ --include GPU002,GPU003,GPU004,GPU007 \
+ --nproc_per_node=8 \
+ $PY_FILE_PATH \
+ --master_addr GPU007 \
+ --master_port 20024 \
+ --lr 2.0e-4 \
+ --train_micro_batch_size_per_gpu 190 \
+ --eval_micro_batch_size_per_gpu 20 \
+ --epoch 15 \
+ --data_path_prefix /h5 \
+ --eval_data_path_prefix /eval_h5 \
+ --tokenizer_path /roberta \
+ --bert_config /roberta/config.json \
+ --tensorboard_path $tensorboard_path \
+ --log_path $log_path \
+ --ckpt_path $ckpt_path \
+ --colossal_config $colossal_config \
+ --log_interval 50 \
+ --mlm bert \
+ --wandb \
+ --checkpoint_activations \
+ --resume_train \
+ --load_pretrain_model /ckpt/1.pt \
+ --load_optimizer_lr /ckpt/1.op_lrs
+
\ No newline at end of file
diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py
new file mode 100644
index 000000000..9840a122c
--- /dev/null
+++ b/examples/language/roberta/pretraining/run_pretraining.py
@@ -0,0 +1,226 @@
+import colossalai
+import math
+import torch
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+import colossalai.nn as col_nn
+from arguments import parse_args
+from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt
+from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args
+from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer
+from utils.logger import Logger
+from evaluation import evaluate
+from loss import LossForPretraining
+
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_optim import ShardedOptimizerV2
+from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
+from tqdm import tqdm
+import os
+import time
+from functools import partial
+
+from transformers import AutoTokenizer
+
+from colossalai.gemini import ChunkManager, GeminiManager
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.utils import get_current_device
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.zero import ZeroOptimizer
+from colossalai.tensor import ProcessGroup
+from colossalai.nn.optimizer import HybridAdam
+
+
+def main():
+
+ args = parse_args()
+ launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
+
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+ logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)
+
+ if args.vscode_debug:
+ colossalai.launch(config={},
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend)
+ args.local_rank = -1
+ args.log_interval = 1
+ else:
+ colossalai.launch_from_torch(args.colossal_config) #args.colossal_config
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
+ f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}')
+
+ log_args(logger, args)
+ args.tokenizer = tokenizer
+ args.logger = logger
+ set_global_variables(launch_time, args.tensorboard_path)
+
+ use_zero = hasattr(gpc.config, 'zero')
+ world_size = torch.distributed.get_world_size()
+
+ # build model, optimizer and criterion
+ if use_zero:
+ shard_strategy = TensorShardStrategy()
+ with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy,
+ shard_param=True):
+
+ config, model, numel = get_model(args, logger)
+ # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
+ else:
+ config, model, numel = get_model(args, logger)
+ logger.info("no_zero")
+ if torch.distributed.get_rank() == 0:
+ os.mkdir(os.path.join(args.ckpt_path, launch_time))
+
+ logger.info(f'Model numel: {numel}')
+
+ get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
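+ # NOTE: 144003367 appears to be the hard-coded total sample count of the authors' corpus;
+ # replace it with the size of your own dataset when reusing this script.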
+ steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader)
+ total_steps = steps_per_epoch * args.epoch
+
+ # build optimizer and lr_scheduler
+
+ start_epoch = 0
+ start_shard = 0
+ global_step = 0
+ if args.resume_train:
+ assert os.path.exists(args.load_optimizer_lr)
+ o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu')
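+ # Rebuild the scheduler one step behind the saved position; its exact state is then
+ # restored by lr_scheduler.load_state_dict below.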
+ o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1
+ optimizer = get_optimizer(model, lr=args.lr)
+ optimizer.load_state_dict(o_l_state_dict['optimizer'])
+ lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch']
+ for state in optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}")
+ # If you remove the loop above, you must move the model to GPU before creating the optimizer, because optimizer.step() requires the optimizer states and parameters to be on the same device.
+ lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler'])
+
+ start_epoch = o_l_state_dict['epoch']
+ start_shard = o_l_state_dict['shard'] + 1
+ # global_step = o_l_state_dict['global_step'] + 1
+ logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}')
+ else:
+ optimizer = get_optimizer(model, lr=args.lr)
+ lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
+
+ # optimizer = gpc.config.optimizer.pop('type')(
+ # model.parameters(), **gpc.config.optimizer)
+ # optimizer = ShardedOptimizerV2(model, optimizer, initial_scale=2**5)
+ criterion = LossForPretraining(config.vocab_size)
+
+ # build dataloader
+ pretrain_dataset_provider = NvidiaBertDatasetProvider(args)
+
+ # initialize with colossalai
+ engine, _, _, lr_scheduler = colossalai.initialize(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ lr_scheduler=lr_scheduler)
+
+ logger.info(get_mem_info(prefix='After init model, '))
+
+
+ best_loss = None
+ eval_loss = 0
+ train_loss = 0
+ timers = get_timers()
+ timers('interval_time').start()
+ timers('epoch_time').start()
+ timers('shard_time').start()
+
+ for epoch in range(start_epoch, args.epoch):
+
+ for shard in range(start_shard, len(os.listdir(args.data_path_prefix))):
+
+ dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
+ # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload
+ if torch.distributed.get_rank() == 0:
+ iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1)
+ else:
+ iterator_data = enumerate(dataset_iterator)
+
+ engine.train()
+
+ for step, batch_data in iterator_data:
+
+ # batch_data = pretrain_dataset_provider.get_batch(batch_index)
+ input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
+ attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}")
+ token_type_ids = batch_data[2].cuda(f"cuda:{torch.cuda.current_device()}")
+ mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}")
+ # nsp_label = batch_data[5].cuda()
+
+ output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+ loss = engine.criterion(output.logits, mlm_label)
+ pretrain_dataset_provider.prefetch_batch()
+
+ engine.backward(loss)
+ train_loss += loss.float().item()
+ # if (step + 1) % args.accumulation_step == 0:
+ engine.step()
+ lr_scheduler.step()
+ engine.zero_grad()
+
+ global_step += 1
+
+ if global_step % args.log_interval == 0 and global_step != 0 \
+ and torch.distributed.get_rank() == 0:
+ elapsed_time = timers('interval_time').elapsed(reset=False)
+ elapsed_time_per_iteration = elapsed_time / global_step
+ samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size)
+
+ cur_loss = train_loss / args.log_interval
+ current_lr = lr_scheduler.get_last_lr()[0]
+ log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
+ f'| secs/batch: {elapsed_time_per_iteration :.3f} | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}'
+ logger.info(log_str, print_=False)
+
+ if args.wandb:
+ tensorboard_log = get_tensorboard_writer()
+ tensorboard_log.log_train({
+ 'lr': current_lr,
+ 'loss': cur_loss,
+ 'ppl': math.exp(cur_loss),
+ 'secs_per_batch': elapsed_time_per_iteration
+ }, global_step)
+
+ train_loss = 0
+
+ logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins')
+ logger.info('*' * 100)
+
+ eval_loss += evaluate(engine, args, logger, global_step)
+ save_ckpt(engine.model, optimizer, lr_scheduler, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step)
+
+
+ eval_loss /= len(os.listdir(args.data_path_prefix))
+ logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \
+ f' | eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
+ logger.info('-' * 100)
+ if args.wandb and torch.distributed.get_rank() == 0:
+ tensorboard_log = get_tensorboard_writer()
+ tensorboard_log.log_eval({
+ 'all_eval_shard_loss': eval_loss,
+ }, epoch)
+ start_shard = 0
+ eval_loss = 0
+
+ pretrain_dataset_provider.release_shard()
+
+ logger.info('Congratulations, training has finished!')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/roberta/pretraining/utils/WandbLog.py b/examples/language/roberta/pretraining/utils/WandbLog.py
new file mode 100644
index 000000000..9dd28a981
--- /dev/null
+++ b/examples/language/roberta/pretraining/utils/WandbLog.py
@@ -0,0 +1,46 @@
+import time
+import wandb
+import os
+from torch.utils.tensorboard import SummaryWriter
+
+class WandbLog:
+
+ @classmethod
+ def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None):
+ wandb.init(project=project, notes=notes, name=name, config=config)
+
+ @classmethod
+ def log(cls, result, model=None, gradient=None):
+ wandb.log(result)
+
+ if model:
+ wandb.watch(model)
+
+ if gradient:
+ wandb.watch(gradient)
+
+
+class TensorboardLog:
+
+ def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None):
+ if not os.path.exists(location):
+ os.mkdir(location)
+ self.writer = SummaryWriter(location, comment=name)
+
+ def log_train(self, result, step):
+ for k, v in result.items():
+ self.writer.add_scalar(f'{k}/train', v, step)
+
+ def log_eval(self, result, step):
+ for k, v in result.items():
+ self.writer.add_scalar(f'{k}/eval', v, step)
+
+ def log_zeroshot(self, result, step):
+ for k, v in result.items():
+ self.writer.add_scalar(f'{k}_acc/eval', v, step)
+
+
+
+
+
+
diff --git a/examples/language/roberta/pretraining/utils/exp_util.py b/examples/language/roberta/pretraining/utils/exp_util.py
new file mode 100644
index 000000000..a02b0872a
--- /dev/null
+++ b/examples/language/roberta/pretraining/utils/exp_util.py
@@ -0,0 +1,99 @@
+import functools
+import os, shutil
+import torch
+import psutil
+from colossalai.core import global_context as gpc
+
+def logging(s, log_path, print_=True, log_=True):
+ if print_:
+ print(s)
+ if log_:
+ with open(log_path, 'a+') as f_log:
+ f_log.write(s + '\n')
+
+def get_logger(log_path, **kwargs):
+ return functools.partial(logging, log_path=log_path, **kwargs)
+
+def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
+ if debug:
+ print('Debug Mode : no experiment dir created')
+ return functools.partial(logging, log_path=None, log_=False)
+
+ if not os.path.exists(dir_path):
+ os.makedirs(dir_path)
+
+ print('Experiment dir : {}'.format(dir_path))
+ if scripts_to_save is not None:
+ script_path = os.path.join(dir_path, 'scripts')
+ if not os.path.exists(script_path):
+ os.makedirs(script_path)
+ for script in scripts_to_save:
+ dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
+ shutil.copyfile(script, dst_file)
+
+ return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
+
+def get_cpu_mem():
+ return psutil.Process().memory_info().rss / 1024**2
+
+
+def get_gpu_mem():
+ return torch.cuda.memory_allocated() / 1024**2
+
+
+def get_mem_info(prefix=''):
+ return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
+
+
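+# Rough TFLOPS estimate: ~2 * numel FLOPs per token for the forward pass; the factor of 8
+# assumes forward + backward + activation recomputation (2 + 4 + 2 FLOPs per parameter per token).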
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+ return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
+def get_parameters_in_billions(model, world_size=1):
+ gpus_per_model = world_size
+
+ approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
+ for model_module in model])
+
+ return approx_parameters_in_billions * gpus_per_model / (1e9)
+
+def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1):
+ gpus_per_model = 1
+ batch_size = args.train_micro_batch_size_per_gpu
+ samples_per_model = batch_size * args.max_seq_length
+ model_replica_count = world_size / gpus_per_model
+ approx_parameters_in_billions = numel
+ elapsed_time_per_iter = iteration_time / total_iterations
+ samples_per_second = batch_size / elapsed_time_per_iter
+
+ #flops calculator
+ hidden_size = config.hidden_size
+ num_layers = config.num_hidden_layers
+ vocab_size = config.vocab_size
+
+ # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of
+ # https://arxiv.org/pdf/2104.04473.pdf).
+ # The factor of 4 is when used with activation check-pointing,
+ # otherwise it will be 3.
+ checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
+ flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
+ tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12))
+ return samples_per_second, tflops, approx_parameters_in_billions
+
+def synchronize():
+ if not torch.distributed.is_available():
+ return
+ if not torch.distributed.is_initialized():
+ return
+ world_size = torch.distributed.get_world_size()
+ if world_size == 1:
+ return
+ torch.distributed.barrier()
+
+def log_args(logger, args):
+ logger.info('--------args----------')
+ message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()])
+ message += '\n'
+ message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()])
+ logger.info(message)
+ logger.info('--------args----------\n')
\ No newline at end of file
diff --git a/examples/language/roberta/pretraining/utils/global_vars.py b/examples/language/roberta/pretraining/utils/global_vars.py
new file mode 100644
index 000000000..363cbf91c
--- /dev/null
+++ b/examples/language/roberta/pretraining/utils/global_vars.py
@@ -0,0 +1,126 @@
+import time
+import torch
+from .WandbLog import TensorboardLog
+
+_GLOBAL_TIMERS = None
+_GLOBAL_TENSORBOARD_WRITER = None
+
+
+def set_global_variables(launch_time, tensorboard_path):
+ _set_timers()
+ _set_tensorboard_writer(launch_time, tensorboard_path)
+
+def _set_timers():
+ """Initialize timers."""
+ global _GLOBAL_TIMERS
+ _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
+ _GLOBAL_TIMERS = Timers()
+
+def _set_tensorboard_writer(launch_time, tensorboard_path):
+ """Set tensorboard writer."""
+ global _GLOBAL_TENSORBOARD_WRITER
+ _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
+ 'tensorboard writer')
+ if torch.distributed.get_rank() == 0:
+ _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time)
+
+def get_timers():
+ """Return timers."""
+ _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
+ return _GLOBAL_TIMERS
+
+def get_tensorboard_writer():
+ """Return tensorboard writer. It can be None so no need
+ to check if it is initialized."""
+ return _GLOBAL_TENSORBOARD_WRITER
+
+def _ensure_var_is_initialized(var, name):
+ """Make sure the input variable is not None."""
+ assert var is not None, '{} is not initialized.'.format(name)
+
+
+def _ensure_var_is_not_initialized(var, name):
+ """Make sure the input variable is not None."""
+ assert var is None, '{} is already initialized.'.format(name)
+
+
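+# Megatron-style cumulative timer: start()/stop() bracket a region (with a CUDA sync for
+# accurate GPU timing) and elapsed() returns the accumulated wall time, optionally resetting it.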
+class _Timer:
+ """Timer."""
+
+ def __init__(self, name):
+ self.name_ = name
+ self.elapsed_ = 0.0
+ self.started_ = False
+ self.start_time = time.time()
+
+ def start(self):
+ """Start the timer."""
+ # assert not self.started_, 'timer has already been started'
+ torch.cuda.synchronize()
+ self.start_time = time.time()
+ self.started_ = True
+
+ def stop(self):
+ """Stop the timer."""
+ assert self.started_, 'timer is not started'
+ torch.cuda.synchronize()
+ self.elapsed_ += (time.time() - self.start_time)
+ self.started_ = False
+
+ def reset(self):
+ """Reset timer."""
+ self.elapsed_ = 0.0
+ self.started_ = False
+
+ def elapsed(self, reset=True):
+ """Calculate the elapsed time."""
+ started_ = self.started_
+ # If the timer is running, stop it first.
+ if self.started_:
+ self.stop()
+ # Get the elapsed time.
+ elapsed_ = self.elapsed_
+ # Reset the elapsed time
+ if reset:
+ self.reset()
+ # If timing was in progress, set it back.
+ if started_:
+ self.start()
+ return elapsed_
+
+
+class Timers:
+ """Group of timers."""
+
+ def __init__(self):
+ self.timers = {}
+
+ def __call__(self, name):
+ if name not in self.timers:
+ self.timers[name] = _Timer(name)
+ return self.timers[name]
+
+ def write(self, names, writer, iteration, normalizer=1.0, reset=False):
+ """Write timers to a tensorboard writer"""
+ # currently when using add_scalars,
+ # torch.utils.add_scalars makes each timer its own run, which
+ # pollutes the runs list, so we just add each as a scalar
+ assert normalizer > 0.0
+ for name in names:
+ value = self.timers[name].elapsed(reset=reset) / normalizer
+ writer.add_scalar(name + '-time', value, iteration)
+
+ def log(self, names, normalizer=1.0, reset=True):
+ """Log a group of timers."""
+ assert normalizer > 0.0
+ string = 'time (ms)'
+ for name in names:
+ elapsed_time = self.timers[name].elapsed(
+ reset=reset) * 1000.0 / normalizer
+ string += ' | {}: {:.2f}'.format(name, elapsed_time)
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_rank() == (
+ torch.distributed.get_world_size() - 1):
+ print(string, flush=True)
+ else:
+ print(string, flush=True)
diff --git a/examples/language/roberta/pretraining/utils/logger.py b/examples/language/roberta/pretraining/utils/logger.py
new file mode 100644
index 000000000..481c4c6ce
--- /dev/null
+++ b/examples/language/roberta/pretraining/utils/logger.py
@@ -0,0 +1,31 @@
+import os
+import logging
+import torch.distributed as dist
+
+logging.basicConfig(
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+ datefmt='%m/%d/%Y %H:%M:%S',
+ level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class Logger():
+ def __init__(self, log_path, cuda=False, debug=False):
+ self.logger = logging.getLogger(__name__)
+ self.cuda = cuda
+ self.log_path = log_path
+ self.debug = debug
+
+
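+ # Log only from rank 0 when running distributed (cuda=True); other ranks stay silent.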
+ def info(self, message, log_=True, print_=True, *args, **kwargs):
+ if (self.cuda and dist.get_rank() == 0) or not self.cuda:
+ if print_:
+ self.logger.info(message, *args, **kwargs)
+
+ if log_:
+ with open(self.log_path, 'a+') as f_log:
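+ # Flatten (batch_size, num_choices, seq_len) inputs to (batch_size * num_choices, seq_len)
+ # so the encoder scores each choice as an independent sequence.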
+ f_log.write(message + '\n')
+
+
+ def error(self, message, *args, **kwargs):
+ self.logger.error(message, *args, **kwargs)