mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-28 20:30:42 +00:00)
[SC] add GPT example for auto checkpoint (#1889)
* [sc] SC tutorial for auto checkpoint
* [sc] polish examples
* [sc] polish readme
* [sc] polish readme and help information
* [sc] polish readme and help information
@@ -1,16 +1,33 @@
 import time
+from copy import deepcopy
 from functools import partial
 from typing import Callable, Tuple

 import numpy as np
 import torch
 import torch.nn as nn
 import torchvision.models as tm
+from transformers import GPT2Config, GPT2LMHeadModel

 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace


-def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5):
+def bench(gm: torch.fx.GraphModule,
+          criterion: torch.nn.Module,
+          data_gen: Callable,
+          num_steps: int = 5) -> Tuple[int, int]:
+    """Benchmarking a given graph module
+
+    Args:
+        gm (torch.fx.GraphModule): The graph module to benchmark.
+        criterion (torch.nn.Module): Loss function.
+        data_gen (Callable): Data generator.
+        num_steps (int, optional): Number of test steps. Defaults to 5.
+
+    Returns:
+        Tuple[int, int]: peak memory in MB and step time in ms.
+    """
     gm.train()
     gm.cuda()
     step_time = float('inf')
@@ -39,7 +56,8 @@ def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callab
         del args, label, output, loss
     gm.to("cpu")
     torch.cuda.empty_cache()
-    return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3
+    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
+    return peak_mem, step_time * 1.0e3


 def bench_rotor(gm: torch.fx.GraphModule,
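For orientation, here is a minimal sketch of how the revised bench() could be driven on the GPT example. The gpt2_data_gen helper is hypothetical (the commit's actual generator is not shown in this diff); it returns ((input_ids, attention_mask), labels), matching how bench() and bench_rotor() unpack the result of data_gen():

    import torch

    def gpt2_data_gen(batch_size=8, seq_len=1024, vocab_size=50257):
        # Random token ids with a full attention mask; labels are the inputs
        # themselves, since GPTLMLoss shifts logits/labels internally.
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
        attention_mask = torch.ones_like(input_ids)
        return (input_ids, attention_mask), input_ids.clone()

    # Given a traced GraphModule `gm` of a GPTLMModel:
    # peak_mb, step_ms = bench(gm, GPTLMLoss(), gpt2_data_gen, num_steps=5)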
@@ -47,19 +65,92 @@ def bench_rotor(gm: torch.fx.GraphModule,
                 data_gen: Callable,
                 num_steps: int = 5,
                 sample_points: int = 20,
-                free_memory: int = torch.cuda.mem_get_info()[0]):
+                free_memory: int = torch.cuda.mem_get_info()[0],
+                start_factor: int = 4) -> Tuple[np.array, list, list]:
+    """Auto Checkpoint Rotor Algorithm benchmarking
+    Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.
+
+    Args:
+        gm (torch.fx.GraphModule): The graph module to benchmark.
+        criterion (torch.nn.Module): Loss function.
+        data_gen (Callable): Data generator.
+        num_steps (int, optional): Number of test steps. Defaults to 5.
+        sample_points (int, optional): Number of sample points. Defaults to 20.
+        free_memory (int, optional): Max memory budget in bytes. Defaults to torch.cuda.mem_get_info()[0].
+        start_factor (int, optional): Start memory budget factor for the benchmark; the first memory
+            budget will be free_memory / start_factor. Defaults to 4.
+
+    Returns:
+        Tuple[np.array, list, list]: budgets vector (MB), peak memory vector (MB), step time vector (ms).
+    """
     peak_hist, step_hist = [], []
-    for budget in np.linspace(free_memory // 5, free_memory, sample_points):
+    raw_graph = deepcopy(gm.graph)
+    for budget in np.linspace(free_memory // start_factor, free_memory, sample_points):
         gm = metainfo_trace(gm, *data_gen()[0])
         solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
         try:
-            gm.graph = solver.solve()
-            peak_memory, step_time = bench(gm,
-                                           criterion,
-                                           partial(data_gen, batch_size=2048, shape=(3, 224, 224)),
-                                           num_steps=num_steps)
+            gm.graph = solver.solve(verbose=False)
+            peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
         except:
             peak_memory, step_time = budget / 1024**2, float('inf')
         peak_hist.append(peak_memory)
         step_hist.append(step_time)
-    return peak_hist, step_hist
+        gm.graph = deepcopy(raw_graph)
+    return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist
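A hedged sketch of exercising the new bench_rotor() signature end to end; the traced module gm and the gpt2_data_gen generator are assumptions carried over from the sketch above, not part of this diff:

    # Sweep budgets from free_memory / start_factor up to free_memory and
    # collect one (budget, peak memory, step time) triple per sample point.
    budgets, peaks, times = bench_rotor(gm,
                                        GPTLMLoss(),
                                        gpt2_data_gen,    # hypothetical generator sketched above
                                        num_steps=5,
                                        sample_points=20,
                                        start_factor=4)
    for b, p, t in zip(budgets, peaks, times):
        print(f'budget {b:.0f} MB -> peak {p:.0f} MB, step {t:.1f} ms')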
+
+
+class GPTLMModel(nn.Module):
+    """
+    GPT Model
+    """
+
+    def __init__(self,
+                 hidden_size=768,
+                 num_layers=12,
+                 num_attention_heads=12,
+                 max_seq_len=1024,
+                 vocab_size=50257,
+                 checkpoint=False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.model = GPT2LMHeadModel(
+            GPT2Config(n_embd=hidden_size,
+                       n_layer=num_layers,
+                       n_head=num_attention_heads,
+                       n_positions=max_seq_len,
+                       n_ctx=max_seq_len,
+                       vocab_size=vocab_size))
+        if checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, attention_mask):
+        # Only return lm_logits
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
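One detail worth noting in GPTLMModel.forward: Hugging Face transformers disables key/value caching when gradient checkpointing is enabled (cached states break recomputation), so use_cache is toggled off whenever checkpoint=True. A small shape check, with a batch size of 2 assumed purely for illustration:

    model = GPTLMModel()                       # defaults roughly match the GPT-2 small config
    input_ids = torch.randint(0, 50257, (2, 1024))
    attention_mask = torch.ones_like(input_ids)
    logits = model(input_ids, attention_mask)
    print(logits.shape)                        # torch.Size([2, 1024, 50257])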
+
+
+class GPTLMLoss(nn.Module):
+    """
+    GPT Loss
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, logits, labels):
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+
+def gpt2_medium(checkpoint=False):
+    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_xl(checkpoint=False):
+    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
+
+
+def gpt2_6b(checkpoint=False):
+    return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
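For scale, a back-of-envelope parameter estimate for the three presets, using the standard 12·L·H² transformer approximation plus the token embedding matrix; these figures are rough estimates, not taken from this commit:

    def approx_params(hidden, layers, vocab=50257):
        # 12*H^2 per transformer block (attention + MLP) plus the embedding
        # matrix; biases, LayerNorms and position embeddings are ignored.
        return 12 * layers * hidden**2 + vocab * hidden

    presets = {'gpt2_medium': (1024, 24), 'gpt2_xl': (1600, 48), 'gpt2_6b': (4096, 30)}
    for name, (h, l) in presets.items():
        print(f'{name}: ~{approx_params(h, l) / 1e9:.1f}B parameters')
    # gpt2_medium: ~0.4B, gpt2_xl: ~1.6B, gpt2_6b: ~6.2B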