[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -18,18 +18,21 @@ from colossalai.legacy.utils.activation_checkpoint import checkpoint
from colossalai.utils import checkpoint
__all__ = [
'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
"GPTMLP1D",
"GPTSelfAttention1D",
"GPTTransformerLayer1D",
"FusedGPTSelfAttention1D",
"FusedGPTTransformerLayer1D",
]
class GPTMLP1D(ParallelLayer):
def __init__(
self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
act_func: str = "gelu",
dropout_prob: float = 0.0,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
@@ -82,7 +85,6 @@ class GPTMLP1D(ParallelLayer):
class GenericGPTSelfAttention1D(ParallelLayer):
def __init__(
self,
hidden_size: int,
@@ -118,8 +120,10 @@ class GenericGPTSelfAttention1D(ParallelLayer):
def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
new_qkv_shape = query_key_value.shape[:-1] + (
self.num_attention_heads_per_partition,
3 * self.attention_head_size,
)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)
@@ -152,28 +156,32 @@ class GenericGPTSelfAttention1D(ParallelLayer):
class GPTSelfAttention1D(GenericGPTSelfAttention1D):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024):
super().__init__(hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings)
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024,
):
super().__init__(
hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
)
self.softmax = nn.Softmax(dim=-1)
max_positions = max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions),
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
@@ -181,7 +189,7 @@ class GPTSelfAttention1D(GenericGPTSelfAttention1D):
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# causal mask
query_length, key_length = query_layer.size(-2), key_layer.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
if attention_mask is not None:
# Apply the attention mask
@@ -191,50 +199,56 @@ class GPTSelfAttention1D(GenericGPTSelfAttention1D):
class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024):
super().__init__(hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings)
self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
input_in_bf16=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
mask_func=None,
softmax_in_fp32=True,
scale=math.sqrt(self.attention_head_size))
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024,
):
super().__init__(
hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
)
self.softmax = kernel.FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
mask_func=None,
softmax_in_fp32=True,
scale=math.sqrt(self.attention_head_size),
)
def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
return self.softmax(attention_scores, attention_mask)
class GenericGPTTransformerLayer1D(ParallelLayer):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
attention=None,
layer_norm=None):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
act_func: str = "gelu",
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.0,
hidden_dropout_prob: float = 0.0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
attention=None,
layer_norm=None,
):
super().__init__()
self.checkpoint = checkpoint
self.dtype = dtype
@@ -288,62 +302,68 @@ class GenericGPTTransformerLayer1D(ParallelLayer):
class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
act_func: str = "gelu",
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False,
):
attention = GPTSelfAttention1D
layer_norm = nn.LayerNorm
super().__init__(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm)
super().__init__(
hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm,
)
class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
act_func: str = "gelu",
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False,
):
attention = FusedGPTSelfAttention1D
layer_norm = kernel.LayerNorm
super().__init__(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm)
super().__init__(
hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm,
)