Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-27 03:26:41 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,2 +0,0 @@

@@ -16,7 +16,6 @@ from .layers.init_method import init_normal, output_init_normal
class BertForPretrain(nn.Module):
def __init__(
self,
vocab_size,

@@ -34,7 +33,9 @@ class BertForPretrain(nn.Module):
):
super().__init__()
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
assert (
max_sequence_length % self.seq_parallel_size == 0
), "sequence length is not divisible by the sequence parallel size"
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
self.init_std = init_std
self.num_layers = num_layers

@@ -43,28 +44,32 @@ class BertForPretrain(nn.Module):
num_tokentypes = 0
self.preprocessor = PreProcessor(self.sub_seq_length)
self.embedding = Embedding(hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes)
self.embedding = Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes,
)
self.bert_layers = nn.ModuleList()
for i in range(num_layers):
bert_layer = BertLayer(layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16)
bert_layer = BertLayer(
layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16,
)
self.bert_layers.append(bert_layer)
self.layer_norm = LayerNorm(hidden_size)
self.head = BertDualHead(hidden_size,
self.embedding.word_embedding_weight.size(0),
add_binary_head=add_binary_head)
self.head = BertDualHead(
hidden_size, self.embedding.word_embedding_weight.size(0), add_binary_head=add_binary_head
)
self.reset_parameters()
def _init_normal(self, tensor):
@@ -122,27 +127,30 @@ class BertForPretrain(nn.Module):
class PipelineBertForPretrain(nn.Module):
def __init__(self,
vocab_size,
hidden_size,
max_sequence_length,
num_attention_heads,
num_layers,
add_binary_head,
is_naive_fp16,
num_tokentypes=2,
dropout_prob=0.1,
mlp_ratio=4,
init_std=0.02,
convert_fp16_to_fp32_in_softmax=False,
first_stage=True,
last_stage=True,
start_idx=None,
end_idx=None):
def __init__(
self,
vocab_size,
hidden_size,
max_sequence_length,
num_attention_heads,
num_layers,
add_binary_head,
is_naive_fp16,
num_tokentypes=2,
dropout_prob=0.1,
mlp_ratio=4,
init_std=0.02,
convert_fp16_to_fp32_in_softmax=False,
first_stage=True,
last_stage=True,
start_idx=None,
end_idx=None,
):
super().__init__()
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
assert (
max_sequence_length % self.seq_parallel_size == 0
), "sequence length is not divisible by the sequence parallel size"
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
self.init_std = init_std
self.num_layers = num_layers

@@ -156,11 +164,13 @@ class PipelineBertForPretrain(nn.Module):
self.preprocessor = PreProcessor(self.sub_seq_length)
if self.first_stage:
self.embedding = Embedding(hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes)
self.embedding = Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes,
)
# transformer layers
self.bert_layers = nn.ModuleList()

@@ -170,14 +180,16 @@ class PipelineBertForPretrain(nn.Module):
end_idx = num_layers
for i in range(start_idx, end_idx):
bert_layer = BertLayer(layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16)
bert_layer = BertLayer(
layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16,
)
self.bert_layers.append(bert_layer)
if self.last_stage:
@@ -256,7 +268,7 @@ def _filter_kwargs(func, kwargs):
return {k: v for k, v in kwargs.items() if k in sig.parameters}
def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
def build_pipeline_bert(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
logger = get_dist_logger()
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

@@ -265,12 +277,12 @@ def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **k
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
models = []
for start, end in parts:
kwargs['num_layers'] = num_layers
kwargs['start_idx'] = start
kwargs['end_idx'] = end
kwargs['first_stage'] = start == 0
kwargs['last_stage'] = end == num_layers
logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
kwargs["num_layers"] = num_layers
kwargs["start_idx"] = start
kwargs["end_idx"] = end
kwargs["first_stage"] = start == 0
kwargs["last_stage"] = end == num_layers
logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers")
chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device)
if start == 0:
wrapper.register_module(chunk.embedding.word_embeddings)
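For orientation, an illustrative sketch only (not code from the ColossalAI repo): partition_uniform(num_layers, pipeline_size, num_chunks) is used above to hand each pipeline rank a list of (start, end) layer ranges, from which the first_stage/last_stage kwargs are derived. The helper below, toy_partition_uniform, is a hypothetical stand-in that assumes uniform contiguous splitting.

    # Illustrative only; assumes partition_uniform-style semantics, not the real ColossalAI helper.
    def toy_partition_uniform(num_items: int, pipeline_size: int, num_chunks: int):
        """Split num_items into pipeline_size * num_chunks contiguous (start, end) ranges, grouped per rank."""
        total_parts = pipeline_size * num_chunks
        base, remainder = divmod(num_items, total_parts)
        parts = [[] for _ in range(pipeline_size)]
        start = 0
        for chunk in range(num_chunks):
            for rank in range(pipeline_size):
                size = base + (1 if (chunk * pipeline_size + rank) < remainder else 0)
                parts[rank].append((start, start + size))
                start += size
        return parts

    parts = toy_partition_uniform(num_items=24, pipeline_size=4, num_chunks=1)
    for rank, ranges in enumerate(parts):
        for start, end in ranges:
            first_stage = start == 0
            last_stage = end == 24
            print(f"rank {rank}: layers {start}-{end}, first={first_stage}, last={last_stage}")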
@@ -1,4 +1,4 @@
from .embedding import VocabEmbedding, Embedding
from .bert_layer import BertLayer
from .embedding import Embedding, VocabEmbedding
from .head import BertDualHead
from .preprocess import PreProcessor

@@ -20,18 +20,20 @@ class BertLayer(nn.Module):
output of the same size.
"""
def __init__(self,
layer_number,
hidden_size,
num_attention_heads,
attention_dropout,
mlp_ratio,
hidden_dropout,
is_naive_fp16,
apply_residual_connection_post_layernorm=False,
fp32_residual_connection=False,
bias_dropout_fusion: bool = True,
convert_fp16_to_fp32_in_softmax: bool = False):
def __init__(
self,
layer_number,
hidden_size,
num_attention_heads,
attention_dropout,
mlp_ratio,
hidden_dropout,
is_naive_fp16,
apply_residual_connection_post_layernorm=False,
fp32_residual_connection=False,
bias_dropout_fusion: bool = True,
convert_fp16_to_fp32_in_softmax: bool = False,
):
super().__init__()
self.layer_number = layer_number

@@ -50,7 +52,8 @@ class BertLayer(nn.Module):
layer_number=layer_number,
apply_query_key_layer_scaling=True,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
fp16=is_naive_fp16)
fp16=is_naive_fp16,
)
self.hidden_dropout = hidden_dropout
self.bias_dropout_fusion = bias_dropout_fusion

@@ -90,8 +93,9 @@ class BertLayer(nn.Module):
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(attention_output, attention_bias.expand_as(residual), residual,
self.hidden_dropout)
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)

@@ -1,5 +1,6 @@
import torch
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)

@@ -10,4 +11,5 @@ def bias_dropout_add(x, bias, residual, prob, training):
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
return _bias_dropout_add
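For reference, an illustrative sketch (not taken from the diff): bias_dropout_add above applies dropout to x + bias; assuming the function then adds the residual, as in Megatron-style implementations, it is equivalent to the following.

    import torch
    import torch.nn.functional as F

    def reference_bias_dropout_add(x, bias, residual, prob, training):
        # Assumed equivalent computation: dropout(x + bias) + residual.
        return residual + F.dropout(x + bias, p=prob, training=training)

    x = torch.randn(4, 8)
    bias = torch.randn(8)
    residual = torch.randn(4, 8)
    out = reference_bias_dropout_add(x, bias, residual, prob=0.1, training=False)
    print(out.shape)  # torch.Size([4, 8])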
@@ -5,7 +5,6 @@ import torch.nn.init as init
class VocabEmbedding(torch.nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super(VocabEmbedding, self).__init__()
# Keep the input dimensions.

@@ -13,26 +12,29 @@ class VocabEmbedding(torch.nn.Module):
self.embedding_dim = embedding_dim
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
# Allocate weights and initialize.
self.weight = nn.Parameter(torch.empty(
self.num_embeddings, self.embedding_dim))
self.weight = nn.Parameter(torch.empty(self.num_embeddings, self.embedding_dim))
init.xavier_uniform_(self.weight)
def forward(self, hidden_state):
output = F.embedding(hidden_state, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
output = F.embedding(
hidden_state,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
return output
def __repr__(self):
return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \
f'embedding_dim={self.embedding_dim})'
return f"VocabEmbedding(num_embeddings={self.num_embeddings}, " f"embedding_dim={self.embedding_dim})"
class Embedding(nn.Module):

@@ -48,12 +50,7 @@ class Embedding(nn.Module):
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes):
def __init__(self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes):
super(Embedding, self).__init__()
self.hidden_size = hidden_size

@@ -62,16 +59,14 @@ class Embedding(nn.Module):
self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size)
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
else:
self.tokentype_embeddings = None
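For orientation only (an assumption, since the forward pass is not shown in this diff): Megatron-style BERT embeddings like the Embedding module above typically sum the word, position, and optional token-type lookups and then apply embedding dropout. A minimal sketch under that assumption, with a hypothetical combine_embeddings helper.

    import torch
    import torch.nn.functional as F

    def combine_embeddings(word_emb, pos_emb, tokentype_emb=None, dropout_prob=0.1, training=True):
        # Assumed combination: element-wise sum of the lookups, then embedding dropout.
        embeddings = word_emb + pos_emb
        if tokentype_emb is not None:
            embeddings = embeddings + tokentype_emb
        return F.dropout(embeddings, p=dropout_prob, training=training)

    # Toy shapes: batch of 2 sequences, 8 tokens, hidden size 16.
    word = torch.randn(2, 8, 16)
    pos = torch.randn(2, 8, 16)
    out = combine_embeddings(word, pos)
    print(out.shape)  # torch.Size([2, 8, 16])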
@@ -3,12 +3,10 @@ import torch.nn as nn
import torch.nn.functional as F
from loss_func.cross_entropy import vocab_cross_entropy
import colossalai
from colossalai.kernel import LayerNorm
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from .embedding import VocabEmbedding
from .linear import Linear
from .pooler import Pooler

@@ -26,7 +24,6 @@ class BertLMHead(nn.Module):
vocab_size,
hidden_size,
):
super(BertLMHead, self).__init__()
self.bias = torch.nn.Parameter(torch.zeros(vocab_size))

@@ -46,7 +43,6 @@ class BertLMHead(nn.Module):
class BertBinaryHead(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.pooler = Pooler(hidden_size)

@@ -62,7 +58,6 @@ class BertBinaryHead(nn.Module):
class BertDualHead(nn.Module):
def __init__(self, hidden_size, vocab_size, add_binary_head):
super().__init__()
self.lm_head = BertLMHead(vocab_size, hidden_size)
@@ -1,6 +1,8 @@
import torch
import math
import torch
def init_normal(tensor, sigma):
"""Init method based on N(0, sigma)."""
torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Parameter
class Linear(nn.Module):

@@ -24,11 +24,7 @@ class Linear(nn.Module):
adding bias but instead return it.
"""
def __init__(self,
input_size,
output_size,
bias=True,
skip_bias_add=False):
def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
super(Linear, self).__init__()
# Keep input parameters

@@ -36,9 +32,12 @@ class Linear(nn.Module):
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.weight = Parameter(torch.empty(self.output_size,
self.input_size,
))
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size,
)
)
init.normal_(self.weight)
if bias:
self.bias = Parameter(torch.empty(self.output_size))

@@ -46,7 +45,7 @@ class Linear(nn.Module):
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
def forward(self, input_):
# Matrix multiply.

@@ -59,5 +58,7 @@ class Linear(nn.Module):
return output
def __repr__(self):
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
return (
f"Linear(in_features={self.input_size}, out_features={self.output_size}, "
+ f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})"
)
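For reference, an illustrative sketch (not the repo's code): the docstring fragment above says this Linear can skip adding the bias and return it instead, which lets callers fuse the bias add with a later op such as GeLU. A simplified stand-in showing that contract.

    import torch
    import torch.nn.functional as F

    class SketchLinear(torch.nn.Module):
        # Simplified stand-in for the Linear above, illustrating skip_bias_add only.
        def __init__(self, input_size, output_size, skip_bias_add=False):
            super().__init__()
            self.skip_bias_add = skip_bias_add
            self.weight = torch.nn.Parameter(torch.randn(output_size, input_size))
            self.bias = torch.nn.Parameter(torch.zeros(output_size))

        def forward(self, x):
            output = F.linear(x, self.weight)  # matmul without bias
            if self.skip_bias_add:
                return output, self.bias       # caller fuses the bias add later
            return output + self.bias

    layer = SketchLinear(16, 32, skip_bias_add=True)
    out, bias = layer(torch.randn(4, 16))
    fused = F.gelu(out + bias)  # e.g. bias add fused with the activation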
@@ -1,10 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl
from .linear import Linear
class TransformerMLP(nn.Module):
"""MLP.

@@ -18,19 +18,13 @@ class TransformerMLP(nn.Module):
super(TransformerMLP, self).__init__()
# Project to 4h.
self.dense_h_to_4h = Linear(
hidden_size,
int(hidden_size*mlp_ratio),
skip_bias_add=True)
self.dense_h_to_4h = Linear(hidden_size, int(hidden_size * mlp_ratio), skip_bias_add=True)
self.bias_gelu_fusion = fuse_gelu
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = Linear(
int(hidden_size*mlp_ratio),
hidden_size,
skip_bias_add=True)
self.dense_4h_to_h = Linear(int(hidden_size * mlp_ratio), hidden_size, skip_bias_add=True)
def forward(self, hidden_states):
# hidden states should be in the shape of [s, b, h]

@@ -39,11 +33,9 @@ class TransformerMLP(nn.Module):
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = \
bias_gelu_impl(intermediate_parallel, bias_parallel)
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = \
self.activation_func(intermediate_parallel + bias_parallel)
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from .linear import Linear

@@ -6,7 +6,6 @@ from colossalai.legacy.core import global_context as gpc
class PreProcessor(nn.Module):
def __init__(self, sub_seq_length):
super().__init__()
self.sub_seq_length = sub_seq_length

@@ -15,10 +14,9 @@ class PreProcessor(nn.Module):
# Create position ids
seq_length = token_ids.size(1)
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
position_ids = torch.arange(seq_length * local_rank,
seq_length * (local_rank + 1),
dtype=torch.long,
device=token_ids.device)
position_ids = torch.arange(
seq_length * local_rank, seq_length * (local_rank + 1), dtype=torch.long, device=token_ids.device
)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids

@@ -42,7 +40,7 @@ class PreProcessor(nn.Module):
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = (extended_attention_mask < 0.5)
extended_attention_mask = extended_attention_mask < 0.5
return extended_attention_mask
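For context, an illustrative sketch (not part of the commit): under sequence parallelism each rank holds one contiguous sub-sequence, so its position ids start at sub_seq_length * local_rank, which is exactly what the torch.arange call above builds. A standalone sketch with toy values.

    import torch

    def local_position_ids(sub_seq_length, local_rank, batch_size=1):
        # Each sequence-parallel rank covers positions [sub_seq_length * rank, sub_seq_length * (rank + 1)).
        start = sub_seq_length * local_rank
        position_ids = torch.arange(start, start + sub_seq_length, dtype=torch.long)
        return position_ids.unsqueeze(0).expand(batch_size, sub_seq_length)

    # With sub_seq_length=4: rank 0 -> [0, 1, 2, 3], rank 1 -> [4, 5, 6, 7]
    print(local_position_ids(4, 0))
    print(local_position_ids(4, 1))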