[SC] add GPT example for auto checkpoint (#1889)

* [sc] SC tutorial for auto checkpoint

* [sc] polish examples

* [sc] polish readme

* [sc] polish readme and help information
Author: Boyuan Yao
Date: 2022-11-11 23:17:25 +08:00
Committed by: GitHub
Parent: 11ee8ae478
Commit: d5c5bc219e
9 changed files with 470 additions and 889 deletions


@@ -1,16 +1,33 @@
 import time
+from copy import deepcopy
 from functools import partial
-from typing import Callable
+from typing import Callable, Tuple

 import numpy as np
 import torch
 import torch.nn as nn
 import torchvision.models as tm
+from transformers import GPT2Config, GPT2LMHeadModel

 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace


-def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5):
+def bench(gm: torch.fx.GraphModule,
+          criterion: torch.nn.Module,
+          data_gen: Callable,
+          num_steps: int = 5) -> Tuple[int, int]:
+    """Benchmark a given graph module.
+
+    Args:
+        gm (torch.fx.GraphModule): The graph module to benchmark.
+        criterion (torch.nn.Module): Loss function.
+        data_gen (Callable): Data generator.
+        num_steps (int, optional): Number of test steps. Defaults to 5.
+
+    Returns:
+        Tuple[int, int]: Peak memory in MB and step time in ms.
+    """
     gm.train()
     gm.cuda()
     step_time = float('inf')
@@ -39,7 +56,8 @@ def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5):
         del args, label, output, loss
     gm.to("cpu")
     torch.cuda.empty_cache()
-    return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3
+    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
+    return peak_mem, step_time * 1.0e3
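A quick usage sketch for `bench` (not part of this commit): it assumes a hypothetical ResNet-style data generator returning `(args, label)`, matching how the benchmark loop appears to unpack it (`args, label = data_gen()` followed by `gm(*args)`).

```python
import torch
import torch.nn as nn
import torchvision.models as tm

# Hypothetical data generator: returns a tuple of forward args plus a label tensor.
def resnet_data_gen(batch_size=128, shape=(3, 224, 224)):
    data = torch.empty(batch_size, *shape, device='cuda')
    label = torch.randint(low=0, high=1000, size=(batch_size,), device='cuda')
    return (data,), label

gm = torch.fx.symbolic_trace(tm.resnet50())    # bench expects a torch.fx.GraphModule
peak_mem_mb, step_time_ms = bench(gm, nn.CrossEntropyLoss(), resnet_data_gen, num_steps=5)
print(f'peak memory: {peak_mem_mb:.1f} MB, step time: {step_time_ms:.2f} ms')
```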
@@ -47,19 +65,92 @@ def bench_rotor(gm: torch.fx.GraphModule,
                 criterion: torch.nn.Module,
                 data_gen: Callable,
                 num_steps: int = 5,
                 sample_points: int = 20,
-                free_memory: int = torch.cuda.mem_get_info()[0]):
+                free_memory: int = torch.cuda.mem_get_info()[0],
+                start_factor: int = 4) -> Tuple[np.array, list, list]:
+    """Auto Checkpoint Rotor Algorithm benchmarking
+
+    Benchmarks the auto-checkpoint rotor algorithm for a given graph module and data.
+
+    Args:
+        gm (torch.fx.GraphModule): The graph module to benchmark.
+        criterion (torch.nn.Module): Loss function.
+        data_gen (Callable): Data generator.
+        num_steps (int, optional): Number of test steps. Defaults to 5.
+        sample_points (int, optional): Number of sample points. Defaults to 20.
+        free_memory (int, optional): Maximum memory budget in bytes. Defaults to torch.cuda.mem_get_info()[0].
+        start_factor (int, optional): Start factor for the memory budget; the sweep starts at
+            free_memory / start_factor. Defaults to 4.
+
+    Returns:
+        Tuple[np.array, list, list]: Budget vector (MB), peak memory vector (MB), step time vector (ms).
+    """
     peak_hist, step_hist = [], []
-    for budget in np.linspace(free_memory // 5, free_memory, sample_points):
+    raw_graph = deepcopy(gm.graph)
+    for budget in np.linspace(free_memory // start_factor, free_memory, sample_points):
         gm = metainfo_trace(gm, *data_gen()[0])
         solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
         try:
-            gm.graph = solver.solve()
-            peak_memory, step_time = bench(gm,
-                                           criterion,
-                                           partial(data_gen, batch_size=2048, shape=(3, 224, 224)),
-                                           num_steps=num_steps)
+            gm.graph = solver.solve(verbose=False)
+            peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
         except:
             peak_memory, step_time = budget / 1024**2, float('inf')
         peak_hist.append(peak_memory)
         step_hist.append(step_time)
-    return peak_hist, step_hist
+        gm.graph = deepcopy(raw_graph)
+    return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist
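`bench_rotor` sweeps `sample_points` memory budgets from `free_memory / start_factor` up to `free_memory`, re-solving the checkpoint schedule and re-benchmarking at each point; budgets the solver cannot satisfy surface as `inf` step time. A hedged driver sketch follows, with a made-up GPT data generator (the real example's generator may differ):

```python
import torch

# Hypothetical generator matching GPTLMModel.forward(input_ids, attention_mask).
# Labels are just the input ids; GPTLMLoss shifts them by one position internally.
def gpt2_data_gen(batch_size=8, seq_len=1024, vocab_size=50257):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
    attention_mask = torch.ones_like(input_ids)
    return (input_ids, attention_mask), input_ids

# `gm` is assumed to be a GraphModule already traced from one of the GPT
# builders below (see the end-to-end sketch after them).
budgets_mb, peaks_mb, steps_ms = bench_rotor(gm, GPTLMLoss(), gpt2_data_gen,
                                             num_steps=5, sample_points=10, start_factor=4)
for b, p, t in zip(budgets_mb, peaks_mb, steps_ms):
    print(f'budget {b:8.0f} MB -> peak {p:8.0f} MB, step {t:8.1f} ms')
```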
+
+
+class GPTLMModel(nn.Module):
+    """
+    GPT Model
+    """
+
+    def __init__(self,
+                 hidden_size=768,
+                 num_layers=12,
+                 num_attention_heads=12,
+                 max_seq_len=1024,
+                 vocab_size=50257,
+                 checkpoint=False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.model = GPT2LMHeadModel(
+            GPT2Config(n_embd=hidden_size,
+                       n_layer=num_layers,
+                       n_head=num_attention_heads,
+                       n_positions=max_seq_len,
+                       n_ctx=max_seq_len,
+                       vocab_size=vocab_size))
+        if checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, attention_mask):
+        # Only return lm_logits
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+
+class GPTLMLoss(nn.Module):
+    """
+    GPT Loss
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, logits, labels):
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
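`GPTLMLoss` uses the standard causal-LM shift: logits at position t are scored against the token at position t + 1, so both tensors are shifted by one before flattening. A toy shape check (illustrative only, not part of the commit):

```python
import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 8, 50257
logits = torch.randn(batch, seq_len, vocab)           # model output
labels = torch.randint(0, vocab, (batch, seq_len))    # targets = input ids

shift_logits = logits[..., :-1, :].contiguous()       # (2, 7, vocab): positions 0..6
shift_labels = labels[..., 1:].contiguous()           # (2, 7): tokens 1..7
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
```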
+
+
+def gpt2_medium(checkpoint=False):
+    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_xl(checkpoint=False):
+    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
+
+
+def gpt2_6b(checkpoint=False):
+    return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
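An end-to-end sketch under stated assumptions: plain `torch.fx.symbolic_trace` generally cannot handle a HuggingFace GPT-2's data-dependent control flow, so this assumes ColossalAI's `ColoTracer` accepts meta tensors via a `meta_args` keyword (an assumption, not confirmed by this diff). It reuses the hypothetical `gpt2_data_gen` from the sketch above.

```python
import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer  # assumed import path

model = gpt2_medium()
# Assumption: the tracer accepts meta tensors for the forward inputs so no
# real memory is allocated while tracing.
meta_args = {
    'input_ids': torch.empty(8, 1024, dtype=torch.long, device='meta'),
    'attention_mask': torch.empty(8, 1024, dtype=torch.long, device='meta'),
}
graph = ColoTracer().trace(model, meta_args=meta_args)
gm = GraphModule(model, graph)

budgets_mb, peaks_mb, steps_ms = bench_rotor(gm, GPTLMLoss(), gpt2_data_gen)
```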