[tutorial] modify hands-on tutorial for auto activation checkpointing (#1920)

* [sc] SC tutorial for auto checkpoint

* [sc] polish examples

* [sc] polish readme

* [sc] polish readme and help information

* [sc] polish readme and help information

* [sc] modify auto checkpoint benchmark

* [sc] remove imgs
This commit is contained in:
Boyuan Yao
2022-11-12 18:21:03 +08:00
committed by GitHub
parent ff16773ded
commit 24cbee0ebe
5 changed files with 135 additions and 192 deletions

View File

@@ -154,3 +154,21 @@ def gpt2_xl(checkpoint=False):
def gpt2_6b(checkpoint=False):
    """Build the GPT2-6B benchmark model (hidden 4096, 30 layers, 16 heads).

    Args:
        checkpoint: whether to enable activation checkpointing in the model.
    """
    config = dict(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
    return GPTLMModel(**config)
def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
    """Create a random batch for GPT2 benchmarking.

    Returns ((input_ids, attention_mask), attention_mask): token ids sampled
    uniformly from [0, vocab_size) with shape (batch_size, seq_len), and an
    all-ones mask of the same shape (also reused as the dummy label).
    """
    token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device)
    mask = torch.ones_like(token_ids, device=device)
    return (token_ids, mask), mask
def data_gen_resnet(batch_size, shape, device='cuda:0'):
    """Create a random batch for ResNet benchmarking.

    Returns ((data,), label): an uninitialized float tensor of shape
    (batch_size, *shape) — values are arbitrary, which is fine for timing —
    and integer class ids drawn uniformly from [0, 1000).
    """
    images = torch.empty((batch_size,) + tuple(shape), device=device)
    labels = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return (images,), labels