Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
[example] make gpt example directory more clear (#2353)
examples/language/gpt/gemini/commons/model_zoo.py (new file, 73 additions)
@@ -0,0 +1,73 @@
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel


## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.config = GPT2Config(n_embd=hidden_size,
                                 n_layer=num_layers,
                                 n_head=num_attention_heads,
                                 n_positions=max_seq_len,
                                 n_ctx=max_seq_len,
                                 vocab_size=vocab_size)
        self.model = GPT2LMHeadModel(self.config)
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_10b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_14b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_20b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_24b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)


def model_builder(model_size: str) -> callable:
    if model_size == "gpt2_medium":
        return gpt2_medium
    elif model_size == "gpt2_xl":
        return gpt2_xl
    elif model_size == "gpt2_10b":
        return gpt2_10b
    elif model_size == "gpt2_14b":
        return gpt2_14b
    elif model_size == "gpt2_20b":
        return gpt2_20b
    elif model_size == "gpt2_24b":
        return gpt2_24b
    else:
        raise TypeError(f"model_builder received an unknown model_size: {model_size}")


__all__ = ['model_builder']
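Taken on its own, the new model_zoo.py is easy to exercise. Below is a minimal sketch of how the builder could be used; it assumes the file is importable as model_zoo (for example, running from the commons/ directory), which is an assumption for illustration rather than how the example's Gemini training scripts wire it up.

    # Minimal sketch: build a GPT-2 medium model via the model_zoo helpers.
    # Import path is an assumption; run from commons/ or adjust sys.path.
    from model_zoo import model_builder

    model_fn = model_builder("gpt2_medium")   # returns the gpt2_medium constructor
    model = model_fn(checkpoint=True)         # GPTLMModel with gradient checkpointing enabled
    num_params = sum(p.numel() for p in model.parameters())
    print(f"gpt2_medium parameters: {num_params / 1e6:.1f}M")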
examples/language/gpt/gemini/commons/utils.py (new file, 12 additions)
@@ -0,0 +1,12 @@
import torch


# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def get_tflops(model_numel, batch_size, seq_len, step_time):
    # Rough estimate: ~8 FLOPs per parameter per token
    # (~2 forward + ~4 backward + ~2 recompute under activation checkpointing).
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
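These two helpers are meant to pair with a model from model_zoo.py. The sketch below times a single step on synthetic data and reports throughput; the placeholder mean-of-logits loss, the batch size, the plain Adam optimizer, and the import paths are assumptions for illustration, not what the example's training scripts do.

    # Minimal sketch: time one step on random data and report TFLOPS.
    # Assumes a CUDA device and that model_zoo.py / utils.py are importable from commons/.
    import time
    import torch
    from model_zoo import model_builder
    from utils import get_data, get_tflops

    model = model_builder("gpt2_medium")(checkpoint=True).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    numel = sum(p.numel() for p in model.parameters())

    input_ids, attention_mask = get_data(batch_size=2, seq_len=1024, vocab_size=50257)

    torch.cuda.synchronize()
    start = time.time()
    logits = model(input_ids, attention_mask)
    loss = logits.float().mean()      # placeholder scalar loss, not a real LM loss
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()
    step_time = time.time() - start

    print(f"step time {step_time:.3f}s, ~{get_tflops(numel, 2, 1024, step_time):.2f} TFLOPS")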