Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-11-03 23:48:41 +00:00)
moved env variables to global variables (#215)

added branch context; added vocab parallel layers; moved split_batch from load_batch into the tensor parallel embedding layers; updated the GPT model; updated unit test cases; fixed a few collective communicator bugs
@@ -3,12 +3,20 @@ from typing import Callable
 import torch
 from colossalai import nn as col_nn
-from colossalai.nn.layer.utils import CheckpointModule
-from colossalai.registry import LAYERS, MODELS, LOSSES
+from colossalai.builder.pipeline import partition_uniform
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.nn.layer.utils import CheckpointModule, divide
+from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.registry import LAYERS, LOSSES, MODELS
 from colossalai.utils import get_current_device
 from torch import dtype, nn
 
-__all__ = ['GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt3']
+__all__ = [
+    'GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt2_8B', 'gpt2_xl_pipeline',
+    'gpt2_8B_pipeline', 'gpt3', 'gpt3_pipeline'
+]
 
 
 @LAYERS.register_module
@@ -18,7 +26,7 @@ class GPTEmbedding(nn.Module):
                  vocab_size: int,
                  max_position_embeddings: int,
                  num_tokentypes: int = 0,
-                 padding_idx: int = 0,
+                 padding_idx: int = None,
                  dropout: float = 0.,
                  dtype: dtype = None) -> None:
         super().__init__()
@@ -34,7 +42,7 @@ class GPTEmbedding(nn.Module):
     def word_embedding_weight(self):
         return self.word_embeddings.weight
 
-    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
+    def forward(self, input_ids, attention_mask=None, position_ids=None, tokentype_ids=None):
         seq_length = input_ids.size(1)
         if position_ids is None:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
@@ -42,7 +50,20 @@ class GPTEmbedding(nn.Module):
         if self.tokentype_embeddings is not None and tokentype_ids is not None:
             x = x + self.tokentype_embeddings(tokentype_ids)
         x = self.dropout(x)
-        return x
+
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # Adapted from huggingface
+        if attention_mask is not None:
+            batch_size = input_ids.shape[0]
+            attention_mask = attention_mask.view(batch_size, -1)
+            attention_mask = col_nn.partition_batch(attention_mask)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.to(dtype=x.dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
+        return x, attention_mask
 
 
 @LAYERS.register_module
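Aside: the mask handling added above turns a 2D padding mask into a 4D additive bias that broadcasts over heads and query positions. A minimal standalone sketch in plain PyTorch (toy shapes and values; col_nn.partition_batch, ColossalAI's tensor-parallel batch splitting, is omitted here):

import torch

batch_size, seq_len, num_heads = 2, 5, 12
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])            # 1 = real token, 0 = padding

bias = attention_mask.view(batch_size, -1).float()          # [batch, seq]
bias = bias.unsqueeze(1).unsqueeze(2)                       # [batch, 1, 1, seq]
bias = (1.0 - bias) * -10000.0                              # 0 for kept tokens, -1e4 for padding

scores = torch.randn(batch_size, num_heads, seq_len, seq_len)   # [batch, heads, q_len, k_len]
probs = torch.softmax(scores + bias, dim=-1)                # padded keys receive ~0 attention weight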
@@ -53,20 +74,32 @@ class GPTSelfAttention(nn.Module):
                  attention_dropout: float,
                  dropout: float,
                  bias: bool = True,
+                 fuse_scale_mask_softmax: bool = False,
                  dtype: dtype = None) -> None:
         super().__init__()
-
-        self.attention_head_size = dim // num_heads
+        self.fuse_scale_mask_softmax = fuse_scale_mask_softmax
+        self.attention_head_size = divide(dim, num_heads)
         self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
+        if fuse_scale_mask_softmax:
+            from colossalai.kernel import FusedScaleMaskSoftmax
+            from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+            self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True,
+                                                 input_in_bf16=False,
+                                                 attn_mask_type=AttnMaskType.causal,
+                                                 scaled_masked_softmax_fusion=True,
+                                                 mask_func=None,
+                                                 softmax_in_fp32=True,
+                                                 scale=math.sqrt(self.attention_head_size))
+        else:
+            self.softmax = nn.Softmax(dim=-1)
         self.attention_dropout = col_nn.Dropout(attention_dropout)
         self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True)
         self.dropout = col_nn.Dropout(dropout)
-        self.softmax = nn.Softmax(dim=-1)
 
     def forward(self, x, attention_mask=None):
         qkv = self.query_key_value(x)
         all_head_size = qkv.shape[-1] // 3
-        num_attention_heads = all_head_size // self.attention_head_size
+        num_attention_heads = divide(all_head_size, self.attention_head_size)
         new_qkv_shape = qkv.shape[:-1] + \
             (num_attention_heads, 3 * self.attention_head_size)
         qkv = qkv.view(new_qkv_shape)
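Aside: a sketch of what the divide helper used above is assumed to do: integer division that fails loudly when the dimensions are not evenly divisible (the real helper is imported from colossalai.nn.layer.utils and may differ in details):

def divide(numerator: int, denominator: int) -> int:
    # assumed behavior: exact integer division with an explicit divisibility check
    assert numerator % denominator == 0, f'{numerator} is not divisible by {denominator}'
    return numerator // denominator

head_size = divide(768, 12)   # 64
# divide(768, 10) would raise, instead of silently truncating like 768 // 10 does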
@@ -74,17 +107,20 @@ class GPTSelfAttention(nn.Module):
         q, k, v = torch.chunk(qkv, 3, dim=-1)
 
         x = torch.matmul(q, k.transpose(-1, -2))
-        x = x / math.sqrt(self.attention_head_size)
-
-        # causal mask
-        q_len, k_len = q.size(-2), k.size(-2)
-        causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
-                                            device=get_current_device())).view(1, 1, q_len, k_len).bool()
-        x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))
+        if self.fuse_scale_mask_softmax:
+            x = self.softmax(x, attention_mask)
+        else:
+            x = x / math.sqrt(self.attention_head_size)
+            # causal mask
+            q_len, k_len = q.size(-2), k.size(-2)
+            causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
+                                                device=get_current_device())).view(1, 1, q_len, k_len).bool()
+            x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))
+            if attention_mask is not None:
+                x = x + attention_mask
+            x = self.softmax(x)
 
-        if attention_mask is not None:
-            x = x + attention_mask
-        x = self.softmax(x)
         x = self.attention_dropout(x)
 
         x = torch.matmul(x, v)
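Aside: a minimal standalone sketch (plain PyTorch, toy shapes; not the fused-kernel path) of the unfused branch above: scale the scores, apply a causal mask, optionally add the padding bias, then softmax:

import math
import torch

batch, heads, seq, head_size = 1, 12, 8, 64
q = torch.randn(batch, heads, seq, head_size)
k = torch.randn(batch, heads, seq, head_size)
v = torch.randn(batch, heads, seq, head_size)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)
q_len, k_len = q.size(-2), k.size(-2)
causal = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool)).view(1, 1, q_len, k_len)
scores = scores.masked_fill(~causal, -1e4)     # same effect as the torch.where call above
probs = torch.softmax(scores, dim=-1)          # add the additive padding bias before this step if present
out = torch.matmul(probs, v)                   # [batch, heads, q_len, head_size]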
@@ -102,15 +138,16 @@ class GPTSelfAttention(nn.Module):
 class GPTMLP(nn.Module):
     def __init__(self,
                  dim: int,
-                 mlp_ratio: int,
+                 mlp_ratio: float,
                  activation: Callable,
                  dropout: float,
                  dtype: dtype = None,
                  bias: bool = True):
         super().__init__()
-        self.dense_1 = col_nn.Linear(dim, mlp_ratio * dim, dtype=dtype, bias=bias)
+        intermediate_dim = int(dim * mlp_ratio)
+        self.dense_1 = col_nn.Linear(dim, intermediate_dim, dtype=dtype, bias=bias)
         self.activation = activation
-        self.dense_2 = col_nn.Linear(mlp_ratio * dim, dim, dtype=dtype, bias=bias)
+        self.dense_2 = col_nn.Linear(intermediate_dim, dim, dtype=dtype, bias=bias)
         self.dropout = col_nn.Dropout(dropout)
 
     def forward(self, x):
@@ -126,27 +163,44 @@ class GPTBlock(CheckpointModule):
     def __init__(self,
                  dim: int,
                  num_heads: int,
-                 mlp_ratio: int,
+                 mlp_ratio: float,
                  activation: Callable,
                  attention_dropout: float = 0.,
                  dropout: float = 0.,
+                 layernorm_epsilon: float = 1e-5,
                  dtype: dtype = None,
                  bias: bool = True,
+                 apply_post_layernorm: bool = False,
+                 fuse_scale_mask_softmax: bool = False,
                  checkpoint: bool = False):
-        super().__init__(checkpoint=checkpoint)
-        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
+        super().__init__(checkpoint)
+        self.apply_post_layernorm = apply_post_layernorm
+        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
         self.attn = GPTSelfAttention(dim=dim,
                                      num_heads=num_heads,
                                      attention_dropout=attention_dropout,
                                      dropout=dropout,
+                                     bias=bias,
+                                     fuse_scale_mask_softmax=fuse_scale_mask_softmax,
                                      dtype=dtype)
-        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
+        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
         self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias)
 
     def _forward(self, x, attention_mask=None):
-        x = x + self.attn(self.norm1(x), attention_mask)
-        x = x + self.mlp(self.norm2(x))
+        if not self.apply_post_layernorm:
+            residual = x
+        x = self.norm1(x)
+        if self.apply_post_layernorm:
+            residual = x
+        x = residual + self.attn(x, attention_mask)
+
+        if not self.apply_post_layernorm:
+            residual = x
+        x = self.norm2(x)
+        if self.apply_post_layernorm:
+            residual = x
+        x = residual + self.mlp(x)
+
         return x, attention_mask
 
 
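Aside: the residual bookkeeping in the new _forward can be hard to read in diff form. A standalone sketch with a toy sublayer (illustrative module names, not the block above): when apply_post_layernorm is False (the default) the residual branches off before the LayerNorm, and when True it branches off after it:

import torch
from torch import nn

norm = nn.LayerNorm(16)
sublayer = nn.Linear(16, 16)      # stand-in for the attention or MLP sublayer

def residual_step(x, apply_post_layernorm: bool):
    if not apply_post_layernorm:
        residual = x              # branch off the un-normalized input
    x = norm(x)
    if apply_post_layernorm:
        residual = x              # branch off the normalized activation instead
    return residual + sublayer(x)

x = torch.randn(4, 16)
y_pre = residual_step(x, apply_post_layernorm=False)
y_post = residual_step(x, apply_post_layernorm=True)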
@@ -161,6 +215,10 @@ class GPTLMHead(nn.Module):
         super().__init__()
         self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype)
 
+    @property
+    def weight(self):
+        return self.dense.weight
+
     def forward(self, x):
         x = self.dense(x)
         return x
@@ -187,18 +245,19 @@ class GPT(nn.Module):
                  dim: int = 768,
                  num_heads: int = 12,
                  depth: int = 12,
-                 mlp_ratio: int = 4,
+                 mlp_ratio: float = 4.0,
                  dropout: float = 0.1,
                  embedding_dropout: float = 0.1,
                  attention_dropout: float = 0.1,
+                 layernorm_epsilon: float = 1e-5,
                  activation: Callable = nn.functional.gelu,
-                 checkpoint: bool = False,
+                 padding_idx: int = None,
                  dtype: dtype = None,
                  bias: bool = True,
-                 padding_idx: int = 0) -> None:
+                 apply_post_layernorm: bool = False,
+                 fuse_scale_mask_softmax: bool = False,
+                 checkpoint: bool = False) -> None:
         super().__init__()
-        self.dtype = dtype
         self.embed = GPTEmbedding(embedding_dim=dim,
                                   vocab_size=vocab_size,
                                   max_position_embeddings=max_position_embeddings,
@@ -213,8 +272,11 @@ class GPT(nn.Module):
                 activation=activation,
                 attention_dropout=attention_dropout,
                 dropout=dropout,
+                layernorm_epsilon=layernorm_epsilon,
                 dtype=dtype,
                 bias=bias,
+                apply_post_layernorm=apply_post_layernorm,
+                fuse_scale_mask_softmax=fuse_scale_mask_softmax,
                 checkpoint=checkpoint,
             ) for _ in range(depth)
         ])
@@ -224,22 +286,10 @@ class GPT(nn.Module):
         self.head = GPTLMHead(dim=dim,
                               vocab_size=vocab_size,
                               word_embeeding_weight=self.embed.word_embedding_weight,
                               bias=bias,
                               dtype=dtype)
 
     def forward(self, input_ids, attention_mask=None):
-        # We create a 3D attention mask from a 2D tensor mask.
-        # Sizes are [batch_size, 1, 1, to_seq_length]
-        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-        # Adapted from huggingface
-        if attention_mask is not None:
-            batch_size = input_ids.shape[0]
-            attention_mask = attention_mask.view(batch_size, -1)
-            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * -10000.0
-
-        x = self.embed(input_ids)
+        x, attention_mask = self.embed(input_ids, attention_mask)
 
         for block in self.blocks:
             x, attention_mask = block(x, attention_mask)
@@ -249,11 +299,103 @@ class GPT(nn.Module):
         return x
 
 
+class PipelineGPT(nn.Module):
+    def __init__(self,
+                 vocab_size: int = 50304,
+                 max_position_embeddings: int = 1024,
+                 dim: int = 768,
+                 num_heads: int = 12,
+                 depth: int = 12,
+                 mlp_ratio: float = 4.0,
+                 dropout: float = 0.1,
+                 embedding_dropout: float = 0.1,
+                 attention_dropout: float = 0.1,
+                 layernorm_epsilon: float = 1e-5,
+                 activation: Callable = nn.functional.gelu,
+                 padding_idx: int = None,
+                 dtype: dtype = None,
+                 bias: bool = True,
+                 apply_post_layernorm: bool = False,
+                 fuse_scale_mask_softmax: bool = False,
+                 checkpoint: bool = False,
+                 first: bool = False,
+                 last: bool = False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.first = first
+        self.last = last
+        if first:
+            self.embed = GPTEmbedding(embedding_dim=dim,
+                                      vocab_size=vocab_size,
+                                      max_position_embeddings=max_position_embeddings,
+                                      padding_idx=padding_idx,
+                                      dropout=embedding_dropout,
+                                      dtype=dtype)
+        self.blocks = nn.ModuleList([
+            GPTBlock(
+                dim=dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                activation=activation,
+                attention_dropout=attention_dropout,
+                dropout=dropout,
+                layernorm_epsilon=layernorm_epsilon,
+                dtype=dtype,
+                bias=bias,
+                apply_post_layernorm=apply_post_layernorm,
+                fuse_scale_mask_softmax=fuse_scale_mask_softmax,
+                checkpoint=checkpoint,
+            ) for _ in range(depth)
+        ])
+        if self.last:
+            self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
+            self.head = GPTLMHead(dim=dim, vocab_size=vocab_size, dtype=dtype)
+
+    def forward(self, x=None, input_ids=None, attention_mask=None):
+        if self.first:
+            x, attention_mask = self.embed(input_ids, attention_mask)
+
+        for block in self.blocks:
+            x, attention_mask = block(x, attention_mask)
+
+        if self.last:
+            x = self.head(self.norm(x))
+
+        return x
+
+
 def _create_gpt_model(**model_kwargs):
     model = GPT(**model_kwargs)
     return model
 
 
+def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs):
+    logger = get_dist_logger()
+    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    rank = gpc.get_global_rank()
+    wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
+    parts = partition_uniform(depth, pipeline_size,
+                              num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions
+    models = []
+    for start, end in parts:
+        model_kwargs['first'] = start == 0
+        model_kwargs['last'] = end == depth
+        model_kwargs['depth'] = end - start
+        chunk = PipelineGPT(**model_kwargs).to(get_current_device())
+        if start == 0:
+            wrapper.register_parameter(chunk.embed.word_embedding_weight)
+        elif end == depth:
+            wrapper.register_parameter(chunk.head.weight)
+        models.append(chunk)
+        logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}')
+    if len(models) == 1:
+        model = models[0]
+    else:
+        model = nn.ModuleList(models)
+    return model
+
+
 @MODELS.register_module
 def gpt2_small(**kwargs):
     model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs)
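Aside: _create_gpt_pipeline_model relies on partition_uniform to split the transformer blocks evenly across pipeline stages. A hypothetical single-chunk sketch of that idea (illustrative helper name; the real partition_uniform also supports num_chunks and returns per-rank ranges):

def uniform_parts(depth: int, pipeline_size: int):
    # spread `depth` layers over `pipeline_size` stages as evenly as possible
    base, rem = divmod(depth, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        size = base + (1 if rank < rem else 0)
        parts.append((start, start + size))
        start += size
    return parts

# e.g. the 48 blocks of gpt2_xl_pipeline over 4 stages
assert uniform_parts(48, 4) == [(0, 12), (12, 24), (24, 36), (36, 48)]

Each stage then builds a PipelineGPT chunk with the matching first/last flags, and PipelineSharedModuleWrapper ties the first stage's word embedding weight to the last stage's head weight.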
@@ -262,23 +404,47 @@ def gpt2_small(**kwargs):
 
 
 @MODELS.register_module
 def gpt2_medium(**kwargs):
-    model_kwargs = dict(dim=1024, depth=24, num_heads=16, **kwargs)
+    model_kwargs = dict(dim=1024, depth=24, num_heads=8, **kwargs)
     return _create_gpt_model(**model_kwargs)
 
 
 @MODELS.register_module
 def gpt2_large(**kwargs):
-    model_kwargs = dict(dim=1280, depth=36, num_heads=20, **kwargs)
+    model_kwargs = dict(dim=1536, depth=36, num_heads=12, **kwargs)
     return _create_gpt_model(**model_kwargs)
 
 
 @MODELS.register_module
 def gpt2_xl(**kwargs):
-    model_kwargs = dict(dim=1600, depth=48, num_heads=25, **kwargs)
+    model_kwargs = dict(dim=1600, depth=48, num_heads=16, **kwargs)
     return _create_gpt_model(**model_kwargs)
 
 
 @MODELS.register_module
-def gpt3(**kwargs):
-    model_kwargs = dict(dim=12288, max_position_embeddings=2048, depth=96, num_heads=96, **kwargs)
+def gpt2_8B(**kwargs):
+    model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs)
     return _create_gpt_model(**model_kwargs)
+
+
+@MODELS.register_module
+def gpt2_xl_pipeline(**kwargs):
+    model_kwargs = dict(dim=1600, depth=48, num_heads=20, **kwargs)
+    return _create_gpt_pipeline_model(**model_kwargs)
+
+
+@MODELS.register_module
+def gpt2_8B_pipeline(**kwargs):
+    model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs)
+    return _create_gpt_pipeline_model(**model_kwargs)
+
+
+@MODELS.register_module
+def gpt3(**kwargs):
+    model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
+    return _create_gpt_model(**model_kwargs)
+
+
+@MODELS.register_module
+def gpt3_pipeline(**kwargs):
+    model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
+    return _create_gpt_pipeline_model(**model_kwargs)
 