diff --git a/examples/language/roberta/README.md b/examples/language/roberta/README.md
new file mode 100644
index 000000000..c119d23b5
--- /dev/null
+++ b/examples/language/roberta/README.md
@@ -0,0 +1,58 @@
+# Introduction
+This repo introduces how to pretrain a Chinese RoBERTa-Large from scratch, covering preprocessing, pretraining and finetuning. It can help you quickly train a high-quality BERT-style model.
+
+## 0. Prerequisite
+- Install Colossal-AI
+- Edit the port in /etc/ssh/sshd_config and /etc/ssh/ssh_config so that every host exposes the same ssh port for both server and client. If you are the root user, also set **PermitRootLogin** in /etc/ssh/sshd_config to "yes"
+- Ensure that each host can log in to every other host without a password. With n hosts, the commands below need to be executed n^2 times
+
+```
+ssh-keygen
+ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination
+```
+
+- On every host, edit /etc/hosts to record all hosts' names and IPs. An example is shown below.
+
+```bash
+192.168.2.1 GPU001
+192.168.2.2 GPU002
+192.168.2.3 GPU003
+192.168.2.4 GPU004
+192.168.2.5 GPU005
+192.168.2.6 GPU006
+192.168.2.7 GPU007
+...
+```
+
+- Restart ssh
+```
+service ssh restart
+```
+
+## 1. Corpus Preprocessing
+```bash
+cd preprocessing
+```
+Following its `README.md`, preprocess the original corpus into h5py+numpy files.
+
+## 2. Pretrain
+
+```bash
+cd pretraining
+```
+Following its `README.md`, load the h5py files generated by the preprocessing in step 1 to pretrain the model.
+
+## 3. Finetune
+
+The checkpoint produced by this repo can directly replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main). Then use HuggingFace transformers to finetune downstream applications (a minimal loading sketch is given at the end of this README).
+
+## Contributors
+The repo is contributed by the AI team from [Moore Threads](https://www.mthreads.com/). If you find any problem during pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. Any form of contribution is welcome!
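+
+## Finetune Example (Sketch)
+A minimal sketch of step 3, assuming `transformers` is installed; `./chinese-roberta-wwm-ext-large` is a hypothetical local copy of the files from hfl/chinese-roberta-wwm-ext-large with `pytorch_model.bin` replaced by the checkpoint trained here, and the task, labels and hyperparameters are placeholders.
+
+```python
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+# Hypothetical local directory: config.json, vocab.txt and tokenizer files from
+# the hfl repo, plus the pytorch_model.bin produced by this example.
+model_dir = "./chinese-roberta-wwm-ext-large"
+
+tokenizer = AutoTokenizer.from_pretrained(model_dir)
+model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=2)
+
+# One toy training step on a dummy labeled sentence.
+batch = tokenizer(["我今天去打篮球。"], padding=True, truncation=True,
+                  max_length=512, return_tensors="pt")
+labels = torch.tensor([1])
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
+model.train()
+loss = model(**batch, labels=labels).loss
+loss.backward()
+optimizer.step()
+```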
+ +``` +@misc{ + title={A simple Chinese RoBERTa Example for Whole Word Masked}, + author={Yehua Zhang, Chen Zhang}, + year={2022} +} +``` \ No newline at end of file diff --git a/examples/language/roberta/configs/colossalai_ddp.py b/examples/language/roberta/configs/colossalai_ddp.py new file mode 100644 index 000000000..c3c59aa40 --- /dev/null +++ b/examples/language/roberta/configs/colossalai_ddp.py @@ -0,0 +1,4 @@ +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.nn.optimizer import FusedAdam + +clip_grad_norm = 1.0 diff --git a/examples/language/roberta/configs/colossalai_zero.py b/examples/language/roberta/configs/colossalai_zero.py new file mode 100644 index 000000000..c5debdce0 --- /dev/null +++ b/examples/language/roberta/configs/colossalai_zero.py @@ -0,0 +1,32 @@ +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.nn.optimizer import FusedAdam + +# fp16 = dict( +# mode=AMP_TYPE.TORCH, +# ) + +# seed = 2 +zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), + reduce_scatter_bucket_size_mb=25, + fp32_reduce_scatter=False, + tensor_placement_policy="cuda", + gradient_predivide_factor=1.0, + reuse_fp16_shard=False), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, + initial_scale=2**5, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2, + max_scale=2**32)) + +# gradient_accumulation = 4 +clip_grad_norm = 1.0 +optimizer = dict( + type=FusedAdam, + lr=0.00015, + weight_decay=1e-2, +) + +# 64433 \ No newline at end of file diff --git a/examples/language/roberta/preprocessing/Makefile b/examples/language/roberta/preprocessing/Makefile new file mode 100644 index 000000000..82ee4e1c5 --- /dev/null +++ b/examples/language/roberta/preprocessing/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = mask +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/examples/language/roberta/preprocessing/README.md b/examples/language/roberta/preprocessing/README.md new file mode 100644 index 000000000..1dbd745ab --- /dev/null +++ b/examples/language/roberta/preprocessing/README.md @@ -0,0 +1,105 @@ +# Data PreProcessing for chinese Whole Word Masked + + + +## Catalogue: +* 1. Introduction +* 2. Quick Start Guide: + * 2.1. Split Sentence + * 2.2.Tokenizer & Whole Word Masked + + + + +## 1. Introduction: [Back to Top] +This folder is used to preprocess chinese corpus with Whole Word Masked. You can obtain corpus from [WuDao](https://resource.wudaoai.cn/home?ind&name=WuDaoCorpora%202.0&id=1394901288847716352). Moreover, data preprocessing is flexible, and you can modify the code based on your needs, hardware or parallel framework(Open MPI, Spark, Dask). + + + +## 2. Quick Start Guide: [Back to Top] + + + +### 2.1. Split Sentence & Split data into multiple shard: +Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentence through punctuation, such as `。!`. **Secondly, split data into multiple shard based on server hardware (cpu, cpu memory, hard disk) and corpus size.** Each shard contains a part of corpus, and the model needs to train all the shards as one epoch. +In this example, split 200G Corpus into 100 shard, and each shard is about 2G. 
The shard size is memory-dependent: take into account the number of servers, the memory used by the tokenizer, and the memory used by the multiple processes that read a shard during training (a data parallel degree of n needs n\*shard_size memory). **To sum up, data preprocessing and model pretraining have to be planned around the whole machine, not just the GPU.**
+
+```python
+python sentence_split.py --input_path /original_corpus --output_path /shard --shard 100
+# This step takes a short time
+```
+* `--input_path`: directory of the original corpus, e.g., /original_corpus/0.json, /original_corpus/1.json ...
+* `--output_path`: directory of the shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
+* `--shard`: number of shards, e.g., 10, 50, or 100
+
+Input json:
+
+```
+[
+    {
+        "id": 0,
+        "title": "打篮球",
+        "content": "我今天去打篮球。不回来吃饭。"
+    },
+    {
+        "id": 1,
+        "title": "旅游",
+        "content": "我后天去旅游。下周请假。"
+    }
+]
+```
+
+Output txt:
+
+```
+我今天去打篮球。
+不回来吃饭。
+]]
+我后天去旅游。
+下周请假。
+```
+
+
+### 2.2. Tokenizer & Whole Word Masked:
+
+```python
+python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python
+# This step is time consuming and is mainly spent on masking
+```
+
+**[Optional but recommended]**: the C++ backend built with `pybind11` provides faster masking
+
+```shell
+make
+```
+
+* `--input_path`: location of the shards with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
+* `--output_path`: location of the h5 files containing input_ids, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ...
+* `--tokenizer_path`: tokenizer directory containing the HuggingFace tokenizer files. Download config.json, special_tokens_map.json, vocab.txt and tokenizer.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main)
+* `--backend`: python or c++; specifying c++ gives a faster preprocessing speed
+* `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document
+* `--worker`: number of processes
+
+Input txt:
+
+```
+我今天去打篮球。
+不回来吃饭。
+]]
+我后天去旅游。
+下周请假。
+```
+
+Output h5+numpy:
+
+```
+'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..],
+                ...]
+'input_mask': [[1,1,1,1,1,1,0,0..],
+                ...]
+'segment_ids': [[0,0,0,0,0,...],
+                ...]
+'masked_lm_positions': [[label1,-1,-1,label2,-1...],
+                ...]
+``` \ No newline at end of file diff --git a/examples/language/roberta/preprocessing/get_mask.py b/examples/language/roberta/preprocessing/get_mask.py new file mode 100644 index 000000000..da297f98e --- /dev/null +++ b/examples/language/roberta/preprocessing/get_mask.py @@ -0,0 +1,266 @@ +import torch +import os +from enum import IntEnum +from random import choice +import random +import collections +import time +import logging +import jieba +jieba.setLogLevel(logging.CRITICAL) +import re +import numpy as np +import mask + +PAD = 0 +MaskedLMInstance = collections.namedtuple("MaskedLMInstance", + ["index", "label"]) + + +def map_to_numpy(data): + return np.asarray(data) + + +class PreTrainingDataset(): + def __init__(self, + tokenizer, + max_seq_length, + backend='python', + max_predictions_per_seq: int = 80, + do_whole_word_mask: bool = True): + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.masked_lm_prob = 0.15 + self.backend = backend + self.do_whole_word_mask = do_whole_word_mask + self.max_predictions_per_seq = max_predictions_per_seq + self.vocab_words = list(tokenizer.vocab.keys()) + self.rec = re.compile('[\u4E00-\u9FA5]') + self.whole_rec = re.compile('##[\u4E00-\u9FA5]') + + self.mlm_p = 0.15 + self.mlm_mask_p = 0.8 + self.mlm_tamper_p = 0.05 + self.mlm_maintain_p = 0.1 + + + def tokenize(self, doc): + temp = [] + for d in doc: + temp.append(self.tokenizer.tokenize(d)) + return temp + + + def create_training_instance(self, instance): + is_next = 1 + raw_text_list = self.get_new_segment(instance) + tokens_a = raw_text_list + assert len(tokens_a) == len(instance) + # tokens_a, tokens_b, is_next = instance.get_values() + # print(f'is_next label:{is_next}') + # Create mapper + tokens = [] + original_tokens = [] + segment_ids = [] + tokens.append("[CLS]") + original_tokens.append('[CLS]') + segment_ids.append(0) + for index, token in enumerate(tokens_a): + tokens.append(token) + original_tokens.append(instance[index]) + segment_ids.append(0) + + tokens.append("[SEP]") + original_tokens.append('[SEP]') + segment_ids.append(0) + + # for token in tokens_b: + # tokens.append(token) + # segment_ids.append(1) + + # tokens.append("[SEP]") + # segment_ids.append(1) + + # Get Masked LM predictions + if self.backend == 'c++': + output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words, + self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob) + elif self.backend == 'python': + output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens) + + # Convert to Ids + input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens) + input_mask = [1] * len(input_ids) + + while len(input_ids) < self.max_seq_length: + input_ids.append(PAD) + segment_ids.append(PAD) + input_mask.append(PAD) + masked_lm_output.append(-1) + return ([ + map_to_numpy(input_ids), + map_to_numpy(input_mask), + map_to_numpy(segment_ids), + map_to_numpy(masked_lm_output), + map_to_numpy([is_next]) + ]) + + + def create_masked_lm_predictions(self, tokens): + cand_indexes = [] + for i, token in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and + token.startswith("##")): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + # cand_indexes.append(i) + + random.shuffle(cand_indexes) + output_tokens = list(tokens) + + num_to_predict = min( + self.max_predictions_per_seq, + max(1, int(round(len(tokens) * 
self.masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + if index in covered_indexes: + continue + covered_indexes.add(index) + + masked_token = None + # 80% mask + if random.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% Keep Original + if random.random() < 0.5: + masked_token = tokens[index] + # 10% replace w/ random word + else: + masked_token = self.vocab_words[random.randint( + 0, + len(self.vocab_words) - 1)] + + output_tokens[index] = masked_token + masked_lms.append( + MaskedLMInstance(index=index, label=tokens[index])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + masked_lm_output = [-1] * len(output_tokens) + for p in masked_lms: + masked_lm_output[p.index] = self.tokenizer.vocab[p.label] + + return (output_tokens, masked_lm_output) + + + def get_new_segment(self, segment): + """ + 输入一句话,返回一句经过处理的话: 为了支持中文全称mask,将被分开的词,将上特殊标记("#"),使得后续处理模块,能够知道哪些字是属于同一个词的。 + :param segment: 一句话 + :return: 一句处理过的话 + """ + seq_cws = jieba.lcut(''.join(segment)) + seq_cws_dict = {x: 1 for x in seq_cws} + new_segment = [] + i = 0 + while i < len(segment): + if len(self.rec.findall(segment[i])) == 0: # 不是中文的,原文加进去。 + new_segment.append(segment[i]) + i += 1 + continue + + has_add = False + for length in range(3, 0, -1): + if i + length > len(segment): + continue + if ''.join(segment[i: i+length]) in seq_cws_dict: + new_segment.append(segment[i]) + for l in range(1, length): + new_segment.append('##' + segment[i+l]) + i += length + has_add = True + break + if not has_add: + new_segment.append(segment[i]) + i += 1 + return new_segment + + + def create_whole_masked_lm_predictions(self, tokens): + """Creates the predictions for the masked LM objective.""" + + cand_indexes = [] + for (i, token) in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and + token.startswith("##")): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + random.shuffle(cand_indexes) + + output_tokens = [t[2:] if len(self.whole_rec.findall(t))>0 else t for t in tokens] # 去掉"##" + + num_to_predict = min(self.max_predictions_per_seq, + max(1, int(round(len(tokens) * self.masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index_set in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. 
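+            # Each index_set holds all WordPiece positions of one word, so a word is
+            # masked (or skipped) as a whole; unmasked positions keep -1 in
+            # masked_lm_output, matching ignore_index=-1 in the pretraining loss.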
+ if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if random.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep original + if random.random() < 0.5: + masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index] # 去掉"##" + # 10% of the time, replace with random word + else: + masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] + + output_tokens[index] = masked_token + + masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index])) + assert len(masked_lms) <= num_to_predict + masked_lms = sorted(masked_lms, key=lambda x: x.index) + masked_lm_output = [-1] * len(output_tokens) + for p in masked_lms: + masked_lm_output[p.index] = self.tokenizer.vocab[p.label] + + return (output_tokens, masked_lm_output) diff --git a/examples/language/roberta/preprocessing/mask.cpp b/examples/language/roberta/preprocessing/mask.cpp new file mode 100644 index 000000000..8355c45cf --- /dev/null +++ b/examples/language/roberta/preprocessing/mask.cpp @@ -0,0 +1,184 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +const int32_t LONG_SENTENCE_LEN = 512; + +struct MaskedLMInstance { + int index; + std::string label; + MaskedLMInstance(int index, std::string label) { + this->index = index; + this->label = label; + } +}; + +auto get_new_segment(std::vector segment, std::vector segment_jieba, const std::vector chinese_vocab) { // const std::unordered_set &chinese_vocab + std::unordered_set seq_cws_dict; + for (auto word : segment_jieba) { + seq_cws_dict.insert(word); + } + int i = 0; + std::vector new_segment; + int segment_size = segment.size(); + while (i < segment_size) { + if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end() + new_segment.emplace_back(segment[i]); + i += 1; + continue; + } + bool has_add = false; + for (int length = 3; length >= 1; length--) { + if (i + length > segment_size) { + continue; + } + std::string chinese_word = ""; + for (int j = i; j < i + length; j++) { + chinese_word += segment[j]; + } + if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) { + new_segment.emplace_back(segment[i]); + for (int j = i + 1; j < i + length; j++) { + new_segment.emplace_back("##" + segment[j]); + } + i += length; + has_add = true; + break; + } + } + if (!has_add) { + new_segment.emplace_back(segment[i]); + i += 1; + } + } + + return new_segment; +} + +bool startsWith(const std::string& s, const std::string& sub) { + return s.find(sub) == 0 ? 
true : false; +} + +auto create_whole_masked_lm_predictions(std::vector &tokens, + const std::vector &original_tokens, + const std::vector &vocab_words, + std::map &vocab, + const int max_predictions_per_seq, + const double masked_lm_prob) { + // for (auto item : vocab) { + // std::cout << "key=" << std::string(py::str(item.first)) << ", " + // << "value=" << std::string(py::str(item.second)) << std::endl; + // } + std::vector > cand_indexes; + std::vector cand_temp; + int tokens_size = tokens.size(); + std::string prefix = "##"; + bool do_whole_masked = true; + + for (int i = 0; i < tokens_size; i++) { + if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") { + continue; + } + if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) { + cand_temp.emplace_back(i); + } + else { + if (cand_temp.size() > 0) { + cand_indexes.emplace_back(cand_temp); + } + cand_temp.clear(); + cand_temp.emplace_back(i); + } + } + auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed)); + // for (auto i : cand_indexes) { + // for (auto j : i) { + // std::cout << tokens[j] << " "; + // } + // std::cout << std::endl; + // } + // for (auto i : output_tokens) { + // std::cout << i; + // } + // std::cout << std::endl; + + int num_to_predict = std::min(max_predictions_per_seq, + std::max(1, int(tokens_size * masked_lm_prob))); + // std::cout << num_to_predict << std::endl; + + std::set covered_indexes; + std::vector masked_lm_output(tokens_size, -1); + int vocab_words_len = vocab_words.size(); + std::default_random_engine e(seed); + std::uniform_real_distribution u1(0.0, 1.0); + std::uniform_int_distribution u2(0, vocab_words_len - 1); + int mask_cnt = 0; + std::vector output_tokens; + output_tokens = original_tokens; + + for (auto index_set : cand_indexes) { + if (mask_cnt > num_to_predict) { + break; + } + int index_set_size = index_set.size(); + if (mask_cnt + index_set_size > num_to_predict) { + continue; + } + bool is_any_index_covered = false; + for (auto index : index_set) { + if (covered_indexes.find(index) != covered_indexes.end()) { + is_any_index_covered = true; + break; + } + } + if (is_any_index_covered) { + continue; + } + for (auto index : index_set) { + + covered_indexes.insert(index); + std::string masked_token; + if (u1(e) < 0.8) { + masked_token = "[MASK]"; + } + else { + if (u1(e) < 0.5) { + masked_token = output_tokens[index]; + } + else { + int random_index = u2(e); + masked_token = vocab_words[random_index]; + } + } + // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index])); + masked_lm_output[index] = vocab[output_tokens[index]]; + output_tokens[index] = masked_token; + mask_cnt++; + } + } + + // for (auto p : masked_lms) { + // masked_lm_output[p.index] = vocab[p.label]; + // } + return std::make_tuple(output_tokens, masked_lm_output); +} + +PYBIND11_MODULE(mask, m) { + m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions); + m.def("get_new_segment", &get_new_segment); +} diff --git a/examples/language/roberta/preprocessing/sentence_split.py b/examples/language/roberta/preprocessing/sentence_split.py new file mode 100644 index 000000000..231be152b --- /dev/null +++ b/examples/language/roberta/preprocessing/sentence_split.py @@ -0,0 +1,163 @@ + +import multiprocessing +import os +import re +from tqdm import tqdm +from typing import List +import json +import time +import argparse +import functools + +def 
split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]: + """ + Args: + document: + flag: Type:str, "all" 中英文标点分句,"zh" 中文标点分句,"en" 英文标点分句 + limit: 默认单句最大长度为510个字符 + Returns: Type:list + """ + sent_list = [] + try: + if flag == "zh": + document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) # 单字符断句符 + document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) # 特殊引号 + elif flag == "en": + document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) # 英文单字符断句符 + document = re.sub('(?P([?!.]["\']))', r'\g\n', document) # 特殊引号 + else: + document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) # 单字符断句符 + + document = re.sub('(?P(([。?!.!?]|…{1,2})[”’"\']))', r'\g\n', + document) # 特殊引号 + + sent_list_ori = document.splitlines() + for sent in sent_list_ori: + sent = sent.strip() + if not sent: + continue + elif len(sent) <= 2: + continue + else: + while len(sent) > limit: + temp = sent[0:limit] + sent_list.append(temp) + sent = sent[limit:] + sent_list.append(sent) + except: + sent_list.clear() + sent_list.append(document) + return sent_list + + +def get_sent(output_path, + input_path, + fin_list=[], host=-1, seq_len=512) -> None: + + workers = 32 + + if input_path[-1] == '/': + input_path = input_path[:-1] + + cur_path = os.path.join(output_path, str(host) + '.txt') + new_split_sentence = functools.partial(split_sentence, limit=seq_len-2) + with open(cur_path, 'w', encoding='utf-8') as f: + for fi, fin_path in enumerate(fin_list): + if not os.path.exists(os.path.join(input_path, fin_path[0])): + continue + if '.json' not in fin_path[0]: + continue + + print("Processing ", fin_path[0], " ", fi) + + with open(os.path.join(input_path, fin_path[0]), 'r') as fin: + f_data = [l['content'] for l in json.load(fin)] + + pool = multiprocessing.Pool(workers) + all_sent = pool.imap_unordered(new_split_sentence, f_data, 32) + pool.close() + print('finished..') + + cnt = 0 + for d in tqdm(all_sent): + for i in d: + f.write(i.strip() + '\n') + f.write(']]' + '\n') + cnt += 1 + # if cnt >= 2: + # exit() + + +def getFileSize(filepath, shard): + all_data = [] + for i in os.listdir(filepath): + all_data.append(os.path.join(filepath, i)) + all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data]) + ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] + ans = sorted(ans, key=lambda x: x[1], reverse=True) + per_size = all_size / shard + real_shard = [] + temp = [] + accu_size = 0 + for i in ans: + accu_size += i[1] + temp.append(i) + if accu_size > per_size: + real_shard.append(temp) + accu_size = 0 + temp = [] + + if len(temp) > 0: + real_shard.append(temp) + + return real_shard + + +def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): + import socket + host = int(socket.gethostname().split(server_name)[-1]) + + fin_list = real_shard[server_num * base + host - 1] + print(fin_list) + print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}') + return fin_list, host + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--server_num', type=int, default=10, help='number of servers') + parser.add_argument('--seq_len', type=int, default=512, help='sequence length') + parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100') + parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus') + parser.add_argument('--output_path', type=str, 
required=True, help='output path of shard which has split sentence') + args = parser.parse_args() + + server_num = args.server_num + seq_len = args.seq_len + shard = args.shard + input_path = args.input_path + output_path = args.output_path + + real_shard = getFileSize(input_path, shard) + + start = time.time() + for index, shard in enumerate(real_shard): + get_sent(output_path, + input_path, + fin_list=shard, + host=index, + seq_len=seq_len) + print(f'cost {str(time.time() - start)}') + + # if you have multiple server, you can use code below or modify code to openmpi + + # for i in range(len(real_shard) // server_num + 1): + # fin_list, host = get_start_end(real_shard, i) + + # start = time.time() + # get_sent(output_path, + # input_path, + # fin_list=fin_list, host= 10 * i + host - 1) + + # print(f'cost {str(time.time() - start)}') diff --git a/examples/language/roberta/preprocessing/tokenize_mask.py b/examples/language/roberta/preprocessing/tokenize_mask.py new file mode 100644 index 000000000..b33871d5d --- /dev/null +++ b/examples/language/roberta/preprocessing/tokenize_mask.py @@ -0,0 +1,275 @@ +import time +import os +import psutil +import h5py +import socket +import argparse +import numpy as np +import multiprocessing +from tqdm import tqdm +from random import shuffle +from transformers import AutoTokenizer +from get_mask import PreTrainingDataset + + +def get_raw_instance(document, max_sequence_length=512): + + """ + 获取初步的训练实例,将整段按照max_sequence_length切分成多个部分,并以多个处理好的实例的形式返回。 + :param document: 一整段 + :param max_sequence_length: + :return: a list. each element is a sequence of text + """ + # document = self.documents[index] + max_sequence_length_allowed = max_sequence_length - 2 + # document = [seq for seq in document if len(seq)= max_sequence_length_allowed: + if len(curr_seq) > 0: + result_list.append(curr_seq) + curr_seq = [] + result_list.append(document[sz_idx][ : max_sequence_length_allowed]) + sz_idx += 1 + else: + result_list.append(curr_seq) + curr_seq = [] + # 对最后一个序列进行处理,如果太短的话,丢弃掉。 + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 + result_list.append(curr_seq) + + # # 计算总共可以得到多少份 + # num_instance=int(len(big_list)/max_sequence_length_allowed)+1 + # print("num_instance:",num_instance) + # # 切分成多份,添加到列表中 + # result_list=[] + # for j in range(num_instance): + # index=j*max_sequence_length_allowed + # end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1 + # result_list.append(big_list[index:end_index]) + return result_list + + +def split_numpy_chunk(path, tokenizer, pretrain_data, host): + + documents = [] + instances = [] + + s = time.time() + with open(path, encoding='utf-8') as fd: + document = [] + for i, line in enumerate(tqdm(fd)): + line = line.strip() + # document = line + # if len(document.split("")) <= 3: + # continue + if len(line + ) > 0 and line[:2] == "]]": # This is end of document + documents.append(document) + document = [] + elif len(line) >= 2: + document.append(line) + if len(document) > 0: + documents.append(document) + print('read_file ', time.time() - s) + + # documents = [x for x in documents if x] + # print(len(documents)) + # print(len(documents[0])) + # print(documents[0][0:10]) + from typing import List + import multiprocessing + + ans = [] + for docs in tqdm(documents): + ans.append(pretrain_data.tokenize(docs)) + print(time.time() - s) + del documents + + instances = [] + for a in tqdm(ans): + raw_ins = get_raw_instance(a) + instances.extend(raw_ins) + del ans + + print('len instance', len(instances)) + + sen_num = 
len(instances) + seq_len = 512 + input_ids = np.zeros([sen_num, seq_len], dtype=np.int32) + input_mask = np.zeros([sen_num, seq_len], dtype=np.int32) + segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32) + masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32) + + for index, ins in tqdm(enumerate(instances)): + mask_dict = pretrain_data.create_training_instance(ins) + input_ids[index] = mask_dict[0] + input_mask[index] = mask_dict[1] + segment_ids[index] = mask_dict[2] + masked_lm_output[index] = mask_dict[3] + + with h5py.File(f'/output/{host}.h5', 'w') as hf: + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_ids) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) + + del instances + + +def split_numpy_chunk_pool(input_path, + output_path, + pretrain_data, + worker, + dupe_factor, + seq_len, + file_name): + + if os.path.exists(os.path.join(output_path, f'{file_name}.h5')): + print(f'{file_name}.h5 exists') + return + + documents = [] + instances = [] + + s = time.time() + with open(input_path, 'r', encoding='utf-8') as fd: + document = [] + for i, line in enumerate(tqdm(fd)): + line = line.strip() + if len(line + ) > 0 and line[:2] == "]]": # This is end of document + documents.append(document) + document = [] + elif len(line) >= 2: + document.append(line) + if len(document) > 0: + documents.append(document) + print(f'read_file cost {time.time() - s}, length is {len(documents)}') + + ans = [] + s = time.time() + pool = multiprocessing.Pool(worker) + encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100) + for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'): + ans.append(res) + pool.close() + print((time.time() - s) / 60) + del documents + + instances = [] + for a in tqdm(ans, colour='MAGENTA'): + raw_ins = get_raw_instance(a, max_sequence_length=seq_len) + instances.extend(raw_ins) + del ans + + print('len instance', len(instances)) + + new_instances = [] + for _ in range(dupe_factor): + for ins in instances: + new_instances.append(ins) + + shuffle(new_instances) + instances = new_instances + print('after dupe_factor, len instance', len(instances)) + + sentence_num = len(instances) + input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) + input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32) + segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) + masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32) + + s = time.time() + pool = multiprocessing.Pool(worker) + encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32) + for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'): + input_ids[index] = mask_dict[0] + input_mask[index] = mask_dict[1] + segment_ids[index] = mask_dict[2] + masked_lm_output[index] = mask_dict[3] + pool.close() + print((time.time() - s) / 60) + + with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf: + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_mask) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) + + del instances + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer') + parser.add_argument('--seq_len', type=int, 
default=512, help='sequence length')
+    parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='maximum number of masked tokens per sequence')
+    parser.add_argument('--input_path', type=str, required=True, help='input path of the shards with split sentences')
+    parser.add_argument('--output_path', type=str, required=True, help='output path of the h5 files containing token ids')
+    parser.add_argument('--backend', type=str, default='python', help='backend used for masking: python or c++')
+    parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document')
+    parser.add_argument('--worker', type=int, default=32, help='number of processes')
+    parser.add_argument('--server_num', type=int, default=10, help='number of servers')
+    args = parser.parse_args()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+    pretrain_data = PreTrainingDataset(tokenizer,
+                                       args.seq_len,
+                                       args.backend,
+                                       max_predictions_per_seq=args.max_predictions_per_seq)
+
+    data_len = len(os.listdir(args.input_path))
+
+    for i in range(data_len):
+        input_path = os.path.join(args.input_path, f'{i}.txt')
+        if os.path.exists(input_path):
+            start = time.time()
+            print(f'process {input_path}')
+            split_numpy_chunk_pool(input_path,
+                                   args.output_path,
+                                   pretrain_data,
+                                   args.worker,
+                                   args.dupe_factor,
+                                   args.seq_len,
+                                   i)
+            end_ = time.time()
+            print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
+            print(f'has cost {(end_ - start) / 60}')
+            print('-' * 100)
+            print('')
+
+    # if you have multiple servers, you can use the code below or adapt it to Open MPI
+
+    # host = int(socket.gethostname().split('GPU')[-1])
+    # for i in range(data_len // args.server_num + 1):
+    #     h = args.server_num * i + host - 1
+    #     input_path = os.path.join(args.input_path, f'{h}.txt')
+    #     if os.path.exists(input_path):
+    #         start = time.time()
+    #         print(f'I am server {host}, process {input_path}')
+    #         split_numpy_chunk_pool(input_path,
+    #                                args.output_path,
+    #                                pretrain_data,
+    #                                args.worker,
+    #                                args.dupe_factor,
+    #                                args.seq_len,
+    #                                h)
+    #         end_ = time.time()
+    #         print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
+    #         print(f'has cost {(end_ - start) / 60}')
+    #         print('-' * 100)
+    #         print('')
+
+
diff --git a/examples/language/roberta/pretraining/README.md b/examples/language/roberta/pretraining/README.md
new file mode 100644
index 000000000..055d69696
--- /dev/null
+++ b/examples/language/roberta/pretraining/README.md
@@ -0,0 +1,24 @@
+# Pretraining
+1. Pretrain RoBERTa by running the script below. Detailed parameter descriptions can be found in arguments.py. `data_path_prefix` is the absolute path of the preprocessing output. **You have to modify the *hostfile* according to your cluster.**
+
+```bash
+bash run_pretrain.sh
+```
+* `--hostfile`: servers' host names from /etc/hosts
+* `--include`: servers which will be used
+* `--nproc_per_node`: number of processes (GPUs) on each server
+* `--data_path_prefix`: absolute location of the train data, e.g., /h5/0.h5
+* `--eval_data_path_prefix`: absolute location of the eval data
+* `--tokenizer_path`: tokenizer directory containing the HuggingFace tokenizer.json, e.g., /tokenizer/tokenizer.json
+* `--bert_config`: config.json which defines the model
+* `--mlm`: backbone model type, bert or deberta_v2
+
+2. To resume training from an earlier checkpoint, run the script below.
+ +```shell +bash run_pretrain_resume.sh +``` +* `--resume_train`: whether to resume training +* `--load_pretrain_model`: absolute path which contains model checkpoint +* `--load_optimizer_lr`: absolute path which contains optimizer checkpoint + diff --git a/examples/language/roberta/pretraining/arguments.py b/examples/language/roberta/pretraining/arguments.py new file mode 100644 index 000000000..3a9370e00 --- /dev/null +++ b/examples/language/roberta/pretraining/arguments.py @@ -0,0 +1,152 @@ +import colossalai +from numpy import require + +__all__ = ['parse_args'] + + +def parse_args(): + parser = colossalai.get_default_parser() + + parser.add_argument( + '--lr', + type=float, + required=True, + help='initial learning rate') + parser.add_argument( + '--epoch', + type=int, + required=True, + help='number of epoch') + parser.add_argument( + '--data_path_prefix', + type=str, + required=True, + help="location of the train data corpus") + parser.add_argument( + '--eval_data_path_prefix', + type=str, + required=True, + help='location of the evaluation data corpus') + parser.add_argument( + '--tokenizer_path', + type=str, + required=True, + help='location of the tokenizer') + parser.add_argument( + '--max_seq_length', + type=int, + default=512, + help='sequence length') + parser.add_argument( + '--refresh_bucket_size', + type=int, + default=1, + help= + "This param makes sure that a certain task is repeated for this time steps to \ + optimise on the back propogation speed with APEX's DistributedDataParallel") + parser.add_argument( + "--max_predictions_per_seq", + "--max_pred", + default=80, + type=int, + help= + "The maximum number of masked tokens in a sequence to be predicted.") + parser.add_argument( + "--gradient_accumulation_steps", + default=1, + type=int, + help="accumulation_steps") + parser.add_argument( + "--train_micro_batch_size_per_gpu", + default=2, + type=int, + required=True, + help="train batch size") + parser.add_argument( + "--eval_micro_batch_size_per_gpu", + default=2, + type=int, + required=True, + help="eval batch size") + parser.add_argument( + "--num_workers", + default=8, + type=int, + help="") + parser.add_argument( + "--async_worker", + action='store_true', + help="") + parser.add_argument( + "--bert_config", + required=True, + type=str, + help="location of config.json") + parser.add_argument( + "--wandb", + action='store_true', + help="use wandb to watch model") + parser.add_argument( + "--wandb_project_name", + default='roberta', + help="wandb project name") + parser.add_argument( + "--log_interval", + default=100, + type=int, + help="report interval") + parser.add_argument( + "--log_path", + type=str, + required=True, + help="log file which records train step") + parser.add_argument( + "--tensorboard_path", + type=str, + required=True, + help="location of tensorboard file") + parser.add_argument( + "--colossal_config", + type=str, + required=True, + help="colossal config, which contains zero config and so on") + parser.add_argument( + "--ckpt_path", + type=str, + required=True, + help="location of saving checkpoint, which contains model and optimizer") + parser.add_argument( + '--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument( + '--vscode_debug', + action='store_true', + help="use vscode to debug") + parser.add_argument( + '--load_pretrain_model', + default='', + type=str, + help="location of model's checkpoin") + parser.add_argument( + '--load_optimizer_lr', + default='', + type=str, + help="location of 
checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step") + parser.add_argument( + '--resume_train', + action='store_true', + help="whether resume training from a early checkpoint") + parser.add_argument( + '--mlm', + default='bert', + type=str, + help="model type, bert or deberta") + parser.add_argument( + '--checkpoint_activations', + action='store_true', + help="whether to use gradient checkpointing") + + args = parser.parse_args() + return args diff --git a/examples/language/roberta/pretraining/bert_dataset_provider.py b/examples/language/roberta/pretraining/bert_dataset_provider.py new file mode 100644 index 000000000..1d8cf2a91 --- /dev/null +++ b/examples/language/roberta/pretraining/bert_dataset_provider.py @@ -0,0 +1,15 @@ +class BertDatasetProviderInterface: + def get_shard(self, index, shuffle=True): + raise NotImplementedError + + def release_shard(self, index): + raise NotImplementedError + + def prefetch_shard(self, index): + raise NotImplementedError + + def get_batch(self, batch_iter): + raise NotImplementedError + + def prefetch_batch(self): + raise NotImplementedError diff --git a/examples/language/roberta/pretraining/evaluation.py b/examples/language/roberta/pretraining/evaluation.py new file mode 100644 index 000000000..83f94082f --- /dev/null +++ b/examples/language/roberta/pretraining/evaluation.py @@ -0,0 +1,71 @@ +import os +import math +import torch +from tqdm import tqdm +from utils.global_vars import get_timers, get_tensorboard_writer +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider + +def evaluate(engine, args, logger, global_step): + evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True) + start_shard = 0 + + engine.eval() + timers = get_timers() + eval_step = 0 + eval_loss = 0 + cur_loss = 0 + world_size = torch.distributed.get_world_size() + + with torch.no_grad(): + + for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))): + + timers('eval_shard_time').start() + + dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard) + # evaluate_dataset_provider.prefetch_shard(shard + 1) + if torch.distributed.get_rank() == 0: + iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1) + else: + iterator_data = enumerate(dataset_iterator) + + for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): + + # batch_data = pretrain_dataset_provider.get_batch(batch_index) + eval_step += 1 + input_ids = batch_data[0].cuda() + attention_mask = batch_data[1].cuda() + token_type_ids = batch_data[2].cuda() + mlm_label = batch_data[3].cuda() + # nsp_label = batch_data[5].cuda() + + output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = engine.criterion(output.logits, mlm_label)#prediction_scores + evaluate_dataset_provider.prefetch_batch() + + eval_loss += loss.float().item() + + cur_loss = eval_loss / eval_step + elapsed_time = timers("eval_shard_time").elapsed() + elapsed_time_per_iteration = elapsed_time / eval_step + ppl = math.exp(cur_loss) + + if args.wandb and torch.distributed.get_rank() == 0: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_eval({ + 'loss': cur_loss, + 'ppl': ppl, + 'mins_batch': elapsed_time_per_iteration + }, global_step) + + eval_log_str = f'evaluation shard: 
{shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ + f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' + + logger.info(eval_log_str) + logger.info('-' * 100) + logger.info('') + + evaluate_dataset_provider.release_shard() + engine.train() + return cur_loss diff --git a/examples/language/roberta/pretraining/hostfile b/examples/language/roberta/pretraining/hostfile new file mode 100644 index 000000000..f4e047f01 --- /dev/null +++ b/examples/language/roberta/pretraining/hostfile @@ -0,0 +1,10 @@ +GPU001 +GPU002 +GPU003 +GPU004 +GPU005 +GPU006 +GPU007 +GPU008 +GPU009 +GPU010 diff --git a/examples/language/roberta/pretraining/loss.py b/examples/language/roberta/pretraining/loss.py new file mode 100644 index 000000000..dc4f872a7 --- /dev/null +++ b/examples/language/roberta/pretraining/loss.py @@ -0,0 +1,17 @@ +import torch + +__all__ = ['LossForPretraining'] + + +class LossForPretraining(torch.nn.Module): + + def __init__(self, vocab_size): + super(LossForPretraining, self).__init__() + self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1) + self.vocab_size = vocab_size + + def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None): + masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1)) + # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1)) + total_loss = masked_lm_loss #+ next_sentence_loss + return total_loss diff --git a/examples/language/roberta/pretraining/model/bert.py b/examples/language/roberta/pretraining/model/bert.py new file mode 100644 index 000000000..67c85f760 --- /dev/null +++ b/examples/language/roberta/pretraining/model/bert.py @@ -0,0 +1,1893 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
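+# NOTE: the model code below closely follows HuggingFace Transformers' `modeling_bert.py`;
+# class names, docstring constants and the checkpoint list are kept as in the upstream file.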
+"""PyTorch BERT model.""" + + +import math +import os +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bert-base-uncased" +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +# TokenClassification docstring +_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" +_TOKEN_CLASS_EXPECTED_OUTPUT = ( + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " +) +_TOKEN_CLASS_EXPECTED_LOSS = 0.01 + +# QuestionAnswering docstring +_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2" +_QA_EXPECTED_OUTPUT = "'a nice puppet'" +_QA_EXPECTED_LOSS = 7.41 +_QA_TARGET_START_INDEX = 14 +_QA_TARGET_END_INDEX = 15 + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity" +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" +_SEQ_CLASS_EXPECTED_LOSS = 0.01 + + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse("1.6.0"): + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + 
input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + 
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BertAttention(config, position_embedding_type="absolute") + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) 
+ + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
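+        # For BERT-style inputs this first token is the [CLS] token; its hidden state
+        # is passed through a dense layer and a tanh activation below.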
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """ + Output type of [`BertForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +BERT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BertConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 
+ """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
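+        # get_extended_attention_mask expands the 2D/3D mask to a 4D shape that
+        # broadcasts over attention heads and converts masked positions into large
+        # negative additive biases, which are later added to the raw attention scores.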
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. 
+ """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. 
+ + Returns: + + Example: + + ```python + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForPreTraining.from_pretrained("bert-base-uncased") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING +) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + 
return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'paris'", + expected_loss=0.88, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top.""", 
+ BERT_START_DOCSTRING, +) +class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import BertTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """, + BERT_START_DOCSTRING, +) +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BERT_START_DOCSTRING, +) +class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BERT_START_DOCSTRING, +) +class BertForTokenClassification(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, + expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BERT_START_DOCSTRING, +) +class BertForQuestionAnswering(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_QA, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=_QA_TARGET_START_INDEX, + qa_target_end_index=_QA_TARGET_END_INDEX, + expected_output=_QA_EXPECTED_OUTPUT, + expected_loss=_QA_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/language/roberta/pretraining/model/deberta_v2.py b/examples/language/roberta/pretraining/model/deberta_v2.py new file mode 100644 index 000000000..c6ce82847 --- /dev/null +++ b/examples/language/roberta/pretraining/model/deberta_v2.py @@ -0,0 +1,1631 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch DeBERTa-v2 model.""" + +import math +from collections.abc import Sequence +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import softmax_backward_data +from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config +from transformers import T5Tokenizer, T5ForConditionalGeneration, FillMaskPipeline + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DebertaV2Config" +_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge" + +DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/deberta-v2-xlarge", + "microsoft/deberta-v2-xxlarge", + "microsoft/deberta-v2-xlarge-mnli", + "microsoft/deberta-v2-xxlarge-mnli", +] + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (`torch.tensor`): The input tensor that will apply softmax. + mask (`torch.IntTensor`): + The mask matrix where 0 indicate that element will be ignored in the softmax calculation. 
+ dim (int): The dimension that will apply softmax + + Example: + + ```python + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4, 20, 100]) + + >>> # Create a mask + >>> mask = (x > 0).int() + + >>> # Specify the dimension to apply softmax + >>> dim = -1 + + >>> y = XSoftmax.apply(x, mask, dim) + ```""" + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.to(torch.bool)) + + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output,) = self.saved_tensors + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output) + return inputGrad, None, None + + @staticmethod + def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"]) + r_mask = g.op( + "Cast", + g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx["Byte"], + ) + output = masked_fill( + g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) + ) + output = softmax(g, output, dim) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) + + +# Copied from transformers.models.deberta.modeling_deberta.DropoutContext +class DropoutContext(object): + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +# Copied from transformers.models.deberta.modeling_deberta.get_mask +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.deberta.modeling_deberta.XDropout +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + +# Copied from transformers.models.deberta.modeling_deberta.StableDropout +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (`torch.tensor`): The input tensor to apply 
dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm +class DebertaV2SelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 +class DebertaV2Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaV2SelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 +class DebertaV2Intermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from 
transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions=False, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if output_attentions: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + # rel = nn.Parameter(torch.empty((pos_ebd_size, config.hidden_size))) + # self.rel_embeddings = nn.init.normal_(rel, mean=0.0, std=config.initializer_range) + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + 
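+        # When relative attention is enabled, `rel_embeddings` above holds one `hidden_size`
+        # vector per signed relative-position bucket: 2 * position_buckets rows when bucketing
+        # is configured, otherwise 2 * max_relative_positions.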
self.gradient_checkpointing = False + + def get_rel_embedding(self): + att_span = self.position_buckets + rel_index = torch.arange(0, att_span * 2).long().to(self.rel_embeddings.weight.device) + rel_embeddings = self.rel_embeddings(rel_index) + # rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + # rel_embeddings = self.rel_embeddings if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + output_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + next_kv, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + ) + else: + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, 
attentions=all_attentions
+        )
+
+
+# Map signed relative offsets into log buckets: offsets with |offset| <= bucket_size // 2 are
+# kept as-is, larger offsets are compressed logarithmically so that anything up to max_position
+# still lands inside roughly +/- bucket_size.
+def make_log_bucket_position(relative_pos, bucket_size, max_position):
+    sign = np.sign(relative_pos)
+    mid = bucket_size // 2
+    abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
+    log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
+    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(int)
+    return bucket_pos
+
+
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
+    """
+    Build relative position according to the query and key
+
+    We assume the absolute position of query \\(P_q\\) is in the range (0, query_size) and the absolute position of
+    key \\(P_k\\) is in the range (0, key_size). The relative position from query to key is \\(R_{q \\rightarrow k} =
+    P_q - P_k\\)
+
+    Args:
+        query_size (int): the length of query
+        key_size (int): the length of key
+        bucket_size (int): the size of position bucket
+        max_position (int): the maximum allowed absolute position
+
+    Return:
+        `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
+
+    """
+    q_ids = np.arange(0, query_size)
+    k_ids = np.arange(0, key_size)
+    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+    if bucket_size > 0 and max_position > 0:
+        rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
+    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+    rel_pos_ids = rel_pos_ids[:query_size, :]
+    rel_pos_ids = rel_pos_ids.unsqueeze(0)
+    return rel_pos_ids
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
+def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
+def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
+def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+    return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+
+
+class DisentangledSelfAttention(nn.Module):
+    """
+    Disentangled self-attention module
+
+    Parameters:
+        config (`DebertaV2Config`):
+            A model config class instance with the configuration to build a new model.
The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if "c2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + # self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.ByteTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, optional): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, optional): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. 
+ + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + scale = math.sqrt(query_layer.size(-1) * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) + + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(new_context_layer_shape) + if output_attentions: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. 
{relative_pos.dim()}") + + att_span = self.pos_ebd_size + relative_pos = relative_pos.long().to(query_layer.device) + + # rel_index = torch.arange(0, att_span * 2).long().to(query_layer.device) + # rel_embeddings = rel_embeddings(rel_index).unsqueeze(0) + rel_embeddings = rel_embeddings.unsqueeze(0) + # rel_embeddings = rel_embeddings.unsqueeze(0) + # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) + else: + if "c2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale + + # position->content + if "p2c" in self.pos_att_type: + scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + ).transpose(-1, -2) + score += p2c_att / scale + + return score + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, "pad_token_id", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = 
LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 +class DebertaV2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DebertaV2Config + base_model_prefix = "deberta" + _keys_to_ignore_on_load_missing = ["position_ids"] + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DebertaV2Encoder): + module.gradient_checkpointing = value + + +DEBERTA_START_DOCSTRING = r""" + The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled + Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build + on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior.``` + + + Parameters: + config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DEBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + output_attentions=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = 
[r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.deberta = DebertaV2Model(config) + self.cls = DebertaV2OnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an 
output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 +class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
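+            Note: when `config.problem_type` is unset and `labels` is one-dimensional, entries with a negative
+            label are excluded from the cross-entropy loss (see the `label_index` handling below).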
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather( + logits, 0, label_index.expand(label_index.size(0), logits.size(1)) + ) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + elif self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 +class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 +class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + DEBERTA_START_DOCSTRING, +) +class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, 1) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.deberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py new file mode 100644 index 000000000..cce836913 --- /dev/null +++ b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -0,0 +1,182 @@ +import os +import random +import h5py +import logging +import json +import time +from concurrent.futures import ProcessPoolExecutor + +import numpy as np + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.sampler import RandomSampler +from torch.utils.data.distributed import DistributedSampler + +from bert_dataset_provider import BertDatasetProviderInterface +import colossalai.utils as utils + +# Workaround because python functions are not picklable +class WorkerInitObj(object): + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + np.random.seed(seed=self.seed + id) + random.seed(self.seed + id) + + +def create_pretraining_dataset(input_file, max_predictions_per_seq, + num_workers, train_batch_size, worker_init, + data_sampler): + train_data = pretraining_dataset( + input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) + train_dataloader = DataLoader(train_data, + sampler=data_sampler(train_data), + batch_size=train_batch_size, + num_workers=num_workers, + worker_init_fn=worker_init, + pin_memory=True + ) + return train_dataloader, len(train_data) + + +class pretraining_dataset(Dataset): + def __init__(self, input_file, max_predictions_per_seq): + self.input_file = input_file + self.max_predictions_per_seq = max_predictions_per_seq + f = h5py.File(input_file, "r") + keys = [ + 'input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions' + ] + self.inputs = [np.asarray(f[key][:]) for key in keys] + 
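+        # Note: each shard is a single HDF5 file (h5py + numpy) written by the
+        # preprocessing step; every key above maps to an array with one row per
+        # training sample, and the whole shard is materialised in host memory
+        # here, so shard size has to fit in CPU RAM.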
f.close() + + def __len__(self): + 'Denotes the total number of samples' + return len(self.inputs[0]) + + def __getitem__(self, index): + + [ + input_ids, input_mask, segment_ids, masked_lm_labels + ] = [ + torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else + torch.from_numpy(np.asarray(input[index].astype(np.int64))) + for indice, input in enumerate(self.inputs) + ] + + return [ + input_ids, input_mask, + segment_ids, masked_lm_labels + ] + + +class NvidiaBertDatasetProvider(BertDatasetProviderInterface): + def __init__(self, args, evaluate=False): + self.num_workers = args.num_workers + self.max_seq_length = args.max_seq_length + self.max_predictions_per_seq = args.max_predictions_per_seq + + self.gradient_accumulation_steps = args.gradient_accumulation_steps + if not evaluate: + self.train_micro_batch_size_per_gpu = args.train_micro_batch_size_per_gpu + else: + self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu + self.logger = args.logger + + self.global_rank = dist.get_rank() + self.world_size = dist.get_world_size() + + # Initialize dataset files + if not evaluate: + self.dataset_files = [ + os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) if + os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f + ] + else: + self.dataset_files = [ + os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) if + os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f + ] + + self.dataset_files.sort() + # random.shuffle(self.dataset_files) + self.num_files = len(self.dataset_files) + # self.data_sampler = RandomSampler + self.data_sampler = DistributedSampler + + self.worker_init = WorkerInitObj(args.seed + args.local_rank) + self.dataset_future = None + self.pool = ProcessPoolExecutor(1) + self.data_file = None + self.shuffle = True + + if self.global_rank == 0: + self.logger.info( + f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}" + ) + + def get_shard(self, index): + start = time.time() + if self.dataset_future is None: + self.data_file = self._get_shard_file(index) + self.train_dataloader, sample_count = create_pretraining_dataset( + input_file=self.data_file, + max_predictions_per_seq=self.max_predictions_per_seq, + num_workers=self.num_workers, + train_batch_size=self.train_micro_batch_size_per_gpu, + worker_init=self.worker_init, + data_sampler=self.data_sampler) + else: + self.train_dataloader, sample_count = self.dataset_future.result( + timeout=None) + + self.logger.info( + f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s." 
+        )
+
+        return self.train_dataloader, sample_count
+
+    def release_shard(self):
+        del self.train_dataloader
+        self.pool.shutdown()
+
+    def prefetch_shard(self, index):
+        self.data_file = self._get_shard_file(index)
+        self.dataset_future = self.pool.submit(
+            create_pretraining_dataset, self.data_file,
+            self.max_predictions_per_seq, self.num_workers,
+            self.train_micro_batch_size_per_gpu, self.worker_init,
+            self.data_sampler)
+
+    def get_batch(self, batch_iter):
+        return batch_iter
+
+    def prefetch_batch(self):
+        pass
+
+    def _get_shard_file(self, shard_index):
+        file_index = self._get_shard_file_index(shard_index, self.global_rank)
+        return self.dataset_files[file_index]
+
+    def _get_shard_file_index(self, shard_index, global_rank):
+        # if dist.is_initialized() and self.world_size > self.num_files:
+        #     remainder = self.world_size % self.num_files
+        #     file_index = (shard_index * self.world_size) + global_rank + (
+        #         remainder * shard_index)
+        # else:
+        #     file_index = shard_index * self.world_size + global_rank
+
+        return shard_index % self.num_files
+
+    def shuffle_dataset(self, epoch):
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(epoch)
+            indices = torch.randperm(self.num_files, generator=g).tolist()
+            new_dataset = [self.dataset_files[i] for i in indices]
+            self.dataset_files = new_dataset
+
\ No newline at end of file
diff --git a/examples/language/roberta/pretraining/pretrain_utils.py b/examples/language/roberta/pretraining/pretrain_utils.py
new file mode 100644
index 000000000..ba17b0f5e
--- /dev/null
+++ b/examples/language/roberta/pretraining/pretrain_utils.py
@@ -0,0 +1,112 @@
+import transformers
+import logging
+from colossalai.nn.lr_scheduler import LinearWarmupLR
+from transformers import get_linear_schedule_with_warmup
+from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig
+from transformers import GPT2Config, GPT2LMHeadModel
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+from colossalai.nn.optimizer import FusedAdam
+from torch.optim import AdamW
+from colossalai.core import global_context as gpc
+import torch
+import os
+import sys
+sys.path.append(os.getcwd())
+from model.deberta_v2 import DebertaV2ForMaskedLM
+from model.bert import BertForMaskedLM
+import torch.nn as nn
+
+from collections import OrderedDict
+
+__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']
+
+
+def get_new_state_dict(state_dict, start_index=13):
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = k[start_index:]
+        new_state_dict[name] = v
+    return new_state_dict
+
+
+class LMModel(nn.Module):
+    def __init__(self, model, config, args):
+        super().__init__()
+
+        self.checkpoint = args.checkpoint_activations
+        self.config = config
+        self.model = model
+        if self.checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
+        # Only return lm_logits
+        return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+def get_model(args, logger):
+
+    if args.mlm == 'bert':
+        config = transformers.BertConfig.from_json_file(args.bert_config)
+        model = BertForMaskedLM(config)
+    elif args.mlm == 'deberta_v2':
+        config = transformers.DebertaV2Config.from_json_file(args.bert_config)
+        model = DebertaV2ForMaskedLM(config)
+    else:
+        raise Exception("Invalid mlm!")
+
+    if len(args.load_pretrain_model) > 0:
+        assert
os.path.exists(args.load_pretrain_model) + # load_checkpoint(args.load_pretrain_model, model, strict=False) + m_state_dict = torch.load(args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) + # new_state_dict = get_new_state_dict(m_state_dict) + model.load_state_dict(m_state_dict, strict=True) # must insure that every process have identical parameters !!!!!!! + logger.info("load model success") + + numel = sum([p.numel() for p in model.parameters()]) + if args.checkpoint_activations: + model.gradient_checkpointing_enable() + # model = LMModel(model, config, args) + + return config, model, numel + + +def get_optimizer(model, lr): + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] + + # configure the weight decay for bert models + optimizer_grouped_parameters = [{ + 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], + 'weight_decay': 0.1 + }, { + 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0 + }] + optimizer = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) + return optimizer + + +def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1): + # warmup_steps = int(total_steps * warmup_ratio) + lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch) + # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps) + return lr_scheduler + + +def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step): + model_path = path + '_pytorch_model.bin' + optimizer_lr_path = path + '.op_lrs' + checkpoint = {} + checkpoint['optimizer'] = optimizer.state_dict() + checkpoint['lr_scheduler'] = lr_scheduler.state_dict() + checkpoint['epoch'] = epoch + checkpoint['shard'] = shard + checkpoint['global_step'] = global_step + model_state = model.state_dict() #each process must run model.state_dict() + if gpc.get_global_rank() == 0: + torch.save(checkpoint, optimizer_lr_path) + torch.save(model_state, model_path) + + + diff --git a/examples/language/roberta/pretraining/run_pretrain.sh b/examples/language/roberta/pretraining/run_pretrain.sh new file mode 100644 index 000000000..144cd0ab9 --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretrain.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env sh + +root_path=$PWD +PY_FILE_PATH="$root_path/run_pretraining.py" + +tensorboard_path="$root_path/tensorboard" +log_path="$root_path/exp_log" +ckpt_path="$root_path/ckpt" + +colossal_config="$root_path/../configs/colossalai_ddp.py" + +mkdir -p $tensorboard_path +mkdir -p $log_path +mkdir -p $ckpt_path + +export PYTHONPATH=$PWD + +env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ + --include GPU002,GPU003,GPU004,GPU007 \ + --nproc_per_node=8 \ + $PY_FILE_PATH \ + --master_addr GPU007 \ + --master_port 20024 \ + --lr 2.0e-4 \ + --train_micro_batch_size_per_gpu 190 \ + --eval_micro_batch_size_per_gpu 20 \ + --epoch 15 \ + --data_path_prefix /h5 \ + --eval_data_path_prefix /eval_h5 \ + --tokenizer_path /roberta \ + --bert_config /roberta/config.json \ + --tensorboard_path $tensorboard_path \ + --log_path $log_path \ + --ckpt_path $ckpt_path \ + --colossal_config $colossal_config \ + --log_interval 50 \ + --mlm bert \ + --wandb \ + --checkpoint_activations \ + \ No newline at end of file diff --git a/examples/language/roberta/pretraining/run_pretrain_resume.sh 
b/examples/language/roberta/pretraining/run_pretrain_resume.sh new file mode 100644 index 000000000..a0704cf7c --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretrain_resume.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env sh + +root_path=$PWD +PY_FILE_PATH="$root_path/run_pretraining.py" + +tensorboard_path="$root_path/tensorboard" +log_path="$root_path/exp_log" +ckpt_path="$root_path/ckpt" + +colossal_config="$root_path/../configs/colossalai_ddp.py" + +mkdir -p $tensorboard_path +mkdir -p $log_path +mkdir -p $ckpt_path + +export PYTHONPATH=$PWD + +env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ + --include GPU002,GPU003,GPU004,GPU007 \ + --nproc_per_node=8 \ + $PY_FILE_PATH \ + --master_addr GPU007 \ + --master_port 20024 \ + --lr 2.0e-4 \ + --train_micro_batch_size_per_gpu 190 \ + --eval_micro_batch_size_per_gpu 20 \ + --epoch 15 \ + --data_path_prefix /h5 \ + --eval_data_path_prefix /eval_h5 \ + --tokenizer_path /roberta \ + --bert_config /roberta/config.json \ + --tensorboard_path $tensorboard_path \ + --log_path $log_path \ + --ckpt_path $ckpt_path \ + --colossal_config $colossal_config \ + --log_interval 50 \ + --mlm bert \ + --wandb \ + --checkpoint_activations \ + --resume_train \ + --load_pretrain_model /ckpt/1.pt \ + --load_optimizer_lr /ckpt/1.op_lrs \ + \ No newline at end of file diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py new file mode 100644 index 000000000..9840a122c --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretraining.py @@ -0,0 +1,226 @@ +import colossalai +import math +import torch +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +import colossalai.nn as col_nn +from arguments import parse_args +from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt +from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args +from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer +from utils.logger import Logger +from evaluation import evaluate +from loss import LossForPretraining + +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from tqdm import tqdm +import os +import time +from functools import partial + +from transformers import AutoTokenizer + +from colossalai.gemini import ChunkManager, GeminiManager +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.utils import get_current_device +from colossalai.nn.parallel import ZeroDDP +from colossalai.zero import ZeroOptimizer +from colossalai.tensor import ProcessGroup +from colossalai.nn.optimizer import HybridAdam + + +def main(): + + args = parse_args() + launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + + os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + + logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) + + if args.vscode_debug: + colossalai.launch(config={}, + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend) + args.local_rank = -1 + args.log_interval = 1 + else: + 
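+        # Non-debug path: the worker processes are spawned by `colossalai run`
+        # (see run_pretrain.sh), so rank, world size and master address come from
+        # the torch.distributed environment variables and only the Colossal-AI
+        # config file is passed in here.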
colossalai.launch_from_torch(args.colossal_config) #args.colossal_config + args.local_rank = int(os.environ["LOCAL_RANK"]) + logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + + f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}') + + log_args(logger, args) + args.tokenizer = tokenizer + args.logger = logger + set_global_variables(launch_time, args.tensorboard_path) + + use_zero = hasattr(gpc.config, 'zero') + world_size = torch.distributed.get_world_size() + + # build model, optimizer and criterion + if use_zero: + shard_strategy = TensorShardStrategy() + with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, + shard_param=True): + + config, model, numel = get_model(args, logger) + # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) + else: + config, model, numel = get_model(args, logger) + logger.info("no_zero") + if torch.distributed.get_rank() == 0: + os.mkdir(os.path.join(args.ckpt_path, launch_time)) + + logger.info(f'Model numel: {numel}') + + get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + total_steps = steps_per_epoch * args.epoch + + # build optimizer and lr_scheduler + + start_epoch = 0 + start_shard = 0 + global_step = 0 + if args.resume_train: + assert os.path.exists(args.load_optimizer_lr) + o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu') + o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 + optimizer = get_optimizer(model, lr=args.lr) + optimizer.load_state_dict(o_l_state_dict['optimizer']) + lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") + # if you want delete the above three code, have to move the model to gpu, because in optimizer.step() + lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) + + start_epoch = o_l_state_dict['epoch'] + start_shard = o_l_state_dict['shard'] + 1 + # global_step = o_l_state_dict['global_step'] + 1 + logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') + else: + optimizer = get_optimizer(model, lr=args.lr) + lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) + + # optimizer = gpc.config.optimizer.pop('type')( + # model.parameters(), **gpc.config.optimizer) + # optimizer = ShardedOptimizerV2(model, optimizer, initial_scale=2**5) + criterion = LossForPretraining(config.vocab_size) + + # build dataloader + pretrain_dataset_provider = NvidiaBertDatasetProvider(args) + + # initialize with colossalai + engine, _, _, lr_scheduelr = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler) + + logger.info(get_mem_info(prefix='After init model, ')) + + + best_loss = None + eval_loss = 0 + train_loss = 0 + timers = get_timers() + timers('interval_time').start() + timers('epoch_time').start() + 
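+    # The loops below walk the corpus shard by shard: every HDF5 file under
+    # args.data_path_prefix gets its own dataloader from
+    # NvidiaBertDatasetProvider.get_shard(), and one pass over all shards counts
+    # as one epoch.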
timers('shard_time').start() + + for epoch in range(start_epoch, args.epoch): + + for shard in range(start_shard, len(os.listdir(args.data_path_prefix))): + + dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) + # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload + if torch.distributed.get_rank() == 0: + iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1) + else: + iterator_data = enumerate(dataset_iterator) + + engine.train() + + for step, batch_data in iterator_data: + + # batch_data = pretrain_dataset_provider.get_batch(batch_index) + input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") + attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}") + token_type_ids = batch_data[2].cuda(f"cuda:{torch.cuda.current_device()}") + mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}") + # nsp_label = batch_data[5].cuda() + + output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = engine.criterion(output.logits, mlm_label) + pretrain_dataset_provider.prefetch_batch() + + engine.backward(loss) + train_loss += loss.float().item() + # if (step + 1) % args.accumulation_step == 0: + engine.step() + lr_scheduelr.step() + engine.zero_grad() + + global_step += 1 + + if global_step % args.log_interval == 0 and global_step != 0 \ + and torch.distributed.get_rank() == 0: + elapsed_time = timers('interval_time').elapsed(reset=False) + elapsed_time_per_iteration = elapsed_time / global_step + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size) + + cur_loss = train_loss / args.log_interval + current_lr = lr_scheduelr.get_last_lr()[0] + log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ + f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}' + logger.info(log_str, print_=False) + + if args.wandb: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_train({ + 'lr': current_lr, + 'loss': cur_loss, + 'ppl': math.exp(cur_loss), + 'mins_batch': elapsed_time_per_iteration + }, global_step) + + train_loss = 0 + + logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') + logger.info('*' * 100) + + eval_loss += evaluate(engine, args, logger, global_step) + save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) + + + eval_loss /= len(os.listdir(args.data_path_prefix)) + logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \ + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') + logger.info('-' * 100) + if args.wandb and torch.distributed.get_rank() == 0: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_eval({ + 'all_eval_shard_loss': eval_loss, + }, epoch) + start_shard = 0 + eval_loss = 0 + + pretrain_dataset_provider.release_shard() + + logger.info('Congratulation, training has finished!!!') + + +if __name__ == '__main__': + main() diff --git 
a/examples/language/roberta/pretraining/utils/WandbLog.py b/examples/language/roberta/pretraining/utils/WandbLog.py new file mode 100644 index 000000000..9dd28a981 --- /dev/null +++ b/examples/language/roberta/pretraining/utils/WandbLog.py @@ -0,0 +1,46 @@ +import time +import wandb +import os +from torch.utils.tensorboard import SummaryWriter + +class WandbLog: + + @classmethod + def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): + wandb.init(project=project, notes=notes, name=name, config=config) + + @classmethod + def log(cls, result, model=None, gradient=None): + wandb.log(result) + + if model: + wandb.watch(model) + + if gradient: + wandb.watch(gradient) + + +class TensorboardLog: + + def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): + if not os.path.exists(location): + os.mkdir(location) + self.writer = SummaryWriter(location, comment=name) + + def log_train(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}/train', v, step) + + def log_eval(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}/eval', v, step) + + def log_zeroshot(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}_acc/eval', v, step) + + + + + + diff --git a/examples/language/roberta/pretraining/utils/exp_util.py b/examples/language/roberta/pretraining/utils/exp_util.py new file mode 100644 index 000000000..a02b0872a --- /dev/null +++ b/examples/language/roberta/pretraining/utils/exp_util.py @@ -0,0 +1,99 @@ +import functools +import os, shutil +import torch +import psutil +from colossalai.core import global_context as gpc + +def logging(s, log_path, print_=True, log_=True): + if print_: + print(s) + if log_: + with open(log_path, 'a+') as f_log: + f_log.write(s + '\n') + +def get_logger(log_path, **kwargs): + return functools.partial(logging, log_path=log_path, **kwargs) + +def create_exp_dir(dir_path, scripts_to_save=None, debug=False): + if debug: + print('Debug Mode : no experiment dir created') + return functools.partial(logging, log_path=None, log_=False) + + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + print('Experiment dir : {}'.format(dir_path)) + if scripts_to_save is not None: + script_path = os.path.join(dir_path, 'scripts') + if not os.path.exists(script_path): + os.makedirs(script_path) + for script in scripts_to_save: + dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) + shutil.copyfile(script, dst_file) + + return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def get_parameters_in_billions(model, world_size=1): + gpus_per_model = world_size + + approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()]) + for model_module in model]) + + return approx_parameters_in_billions * gpus_per_model / (1e9) + +def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1): + gpus_per_model = 1 + batch_size = 
args.train_micro_batch_size_per_gpu
+    samples_per_model = batch_size * args.max_seq_length
+    model_replica_count = world_size / gpus_per_model
+    approx_parameters_in_billions = numel
+    elapsed_time_per_iter = iteration_time / total_iterations
+    samples_per_second = batch_size / elapsed_time_per_iter
+
+    #flops calculator
+    hidden_size = config.hidden_size
+    num_layers = config.num_hidden_layers
+    vocab_size = config.vocab_size
+
+    # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of
+    # https://arxiv.org/pdf/2104.04473.pdf).
+    # The factor of 4 is when used with activation check-pointing,
+    # otherwise it will be 3.
+    checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
+    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
+    tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12))
+    return samples_per_second, tflops, approx_parameters_in_billions
+
+def synchronize():
+    if not torch.distributed.is_available():
+        return
+    if not torch.distributed.is_initialized():
+        return
+    world_size = torch.distributed.get_world_size()
+    if world_size == 1:
+        return
+    torch.distributed.barrier()
+
+def log_args(logger, args):
+    logger.info('--------args----------')
+    message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()])
+    message += '\n'
+    message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()])
+    logger.info(message)
+    logger.info('--------args----------\n')
\ No newline at end of file
diff --git a/examples/language/roberta/pretraining/utils/global_vars.py b/examples/language/roberta/pretraining/utils/global_vars.py
new file mode 100644
index 000000000..363cbf91c
--- /dev/null
+++ b/examples/language/roberta/pretraining/utils/global_vars.py
@@ -0,0 +1,126 @@
+import time
+import torch
+from .WandbLog import TensorboardLog
+
+_GLOBAL_TIMERS = None
+_GLOBAL_TENSORBOARD_WRITER = None
+
+
+def set_global_variables(launch_time, tensorboard_path):
+    _set_timers()
+    _set_tensorboard_writer(launch_time, tensorboard_path)
+
+def _set_timers():
+    """Initialize timers."""
+    global _GLOBAL_TIMERS
+    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
+    _GLOBAL_TIMERS = Timers()
+
+def _set_tensorboard_writer(launch_time, tensorboard_path):
+    """Set tensorboard writer."""
+    global _GLOBAL_TENSORBOARD_WRITER
+    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
+                                   'tensorboard writer')
+    if torch.distributed.get_rank() == 0:
+        _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time)
+
+def get_timers():
+    """Return timers."""
+    _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
+    return _GLOBAL_TIMERS
+
+def get_tensorboard_writer():
+    """Return tensorboard writer.
It can be None so no need + to check if it is initialized.""" + return _GLOBAL_TENSORBOARD_WRITER + +def _ensure_var_is_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is not None, '{} is not initialized.'.format(name) + + +def _ensure_var_is_not_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is None, '{} is already initialized.'.format(name) + + +class _Timer: + """Timer.""" + + def __init__(self, name): + self.name_ = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + # assert not self.started_, 'timer has already been started' + torch.cuda.synchronize() + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, 'timer is not started' + torch.cuda.synchronize() + self.elapsed_ += (time.time() - self.start_time) + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If the timing in progress, end it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if started_: + self.start() + return elapsed_ + + +class Timers: + """Group of timers.""" + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = _Timer(name) + return self.timers[name] + + def write(self, names, writer, iteration, normalizer=1.0, reset=False): + """Write timers to a tensorboard writer""" + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + for name in names: + value = self.timers[name].elapsed(reset=reset) / normalizer + writer.add_scalar(name + '-time', value, iteration) + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + string = 'time (ms)' + for name in names: + elapsed_time = self.timers[name].elapsed( + reset=reset) * 1000.0 / normalizer + string += ' | {}: {:.2f}'.format(name, elapsed_time) + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == ( + torch.distributed.get_world_size() - 1): + print(string, flush=True) + else: + print(string, flush=True) diff --git a/examples/language/roberta/pretraining/utils/logger.py b/examples/language/roberta/pretraining/utils/logger.py new file mode 100644 index 000000000..481c4c6ce --- /dev/null +++ b/examples/language/roberta/pretraining/utils/logger.py @@ -0,0 +1,31 @@ +import os +import logging +import torch.distributed as dist + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Logger(): + def __init__(self, log_path, cuda=False, debug=False): + self.logger = logging.getLogger(__name__) + self.cuda = cuda + self.log_path = log_path + self.debug = debug + + + def info(self, message, log_=True, print_=True, *args, **kwargs): + if (self.cuda and dist.get_rank() == 0) or not self.cuda: + if print_: + self.logger.info(message, *args, **kwargs) + + if log_: + with open(self.log_path, 'a+') as f_log: + f_log.write(message + '\n') + + + def 
error(self, message, *args, **kwargs): + self.logger.error(message, *args, **kwargs)
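
As a closing note, the checkpoint pair written by `save_ckpt` in pretrain_utils.py (a `*_pytorch_model.bin` state dict plus a `.op_lrs` bundle) can be inspected or reused outside the training script. The sketch below is not part of the patch; the file names are hypothetical placeholders for whatever your run wrote under `ckpt_path/<launch_time>/`.

```python
import torch

# Hypothetical paths for illustration -- substitute the files produced by your run.
model_path = "ckpt/2022-11-01-10:00:00/epoch-0_shard-0_2022-11-01-10:00:00_pytorch_model.bin"
state_path = "ckpt/2022-11-01-10:00:00/epoch-0_shard-0_2022-11-01-10:00:00.op_lrs"

# The *_pytorch_model.bin file is a plain state_dict, so it can be loaded back into
# the same BertForMaskedLM / DebertaV2ForMaskedLM configuration used for pretraining
# (this is what --load_pretrain_model consumes).
model_state = torch.load(model_path, map_location="cpu")
print(f"{len(model_state)} tensors in the model state dict")

# The .op_lrs file bundles optimizer and LR-scheduler state together with the resume
# counters that run_pretraining.py reads when --resume_train is set.
train_state = torch.load(state_path, map_location="cpu")
print(sorted(train_state.keys()))  # ['epoch', 'global_step', 'lr_scheduler', 'optimizer', 'shard']
print(train_state["epoch"], train_state["shard"], train_state["global_step"])
```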