mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-24 17:33:39 +00:00
[tutorial] modify hands-on of auto activation checkpoint (#1920)
* [sc] SC tutorial for auto checkpoint
* [sc] polish examples
* [sc] polish readme
* [sc] polish readme and help information
* [sc] polish readme and help information
* [sc] modify auto checkpoint benchmark
* [sc] remove imgs
This commit is contained in:
@@ -154,3 +154,21 @@ def gpt2_xl(checkpoint=False):
|
||||
|
||||
def gpt2_6b(checkpoint=False):
    """Build the GPT-2 '6b' benchmark model (hidden 4096, 30 layers, 16 heads).

    Args:
        checkpoint: enable activation checkpointing when True.
    """
    config = dict(hidden_size=4096, num_layers=30, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **config)
|
||||
|
||||
|
||||
def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
    """Generate random input data for GPT-2 benchmarking.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: length of each sequence.
        vocab_size: exclusive upper bound for random token ids.
        device: device the tensors are allocated on (defaults to the first GPU).

    Returns:
        A ``((input_ids, attention_mask), label)`` pair where the label is the
        all-ones attention mask itself (only shapes matter for benchmarking).
    """
    shape = (batch_size, seq_len)
    input_ids = torch.randint(0, vocab_size, shape, device=device)
    attention_mask = torch.ones_like(input_ids, device=device)
    model_args = (input_ids, attention_mask)
    return model_args, attention_mask
|
||||
|
||||
|
||||
def data_gen_resnet(batch_size, shape, device='cuda:0', num_classes=1000):
    """Generate random input data for ResNet benchmarking.

    Args:
        batch_size: number of samples in the batch.
        shape: per-sample tensor shape, e.g. ``(3, 224, 224)``.
        device: device the tensors are allocated on (defaults to the first GPU).
        num_classes: exclusive upper bound for random labels; defaults to 1000
            (ImageNet) to stay backward-compatible with the previous hard-coded
            value.

    Returns:
        A ``((data,), label)`` pair: uninitialized float inputs (contents are
        irrelevant for benchmarking) and random integer class labels.
    """
    data = torch.empty(batch_size, *shape, device=device)
    # random_(n) fills in-place with integers drawn from [0, n).
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(num_classes)
    return (data,), label
|
||||
|
||||
Reference in New Issue
Block a user