ColossalAI/examples/tutorial/auto_parallel/demo_resnet152.py
Boyuan Yao d5c5bc219e
[SC] add GPT example for auto checkpoint (#1889)
* [sc] SC tutorial for auto checkpoint

* [sc] polish examples

* [sc] polish readme

* [sc] polish readme and help information

* [sc] polish readme and help information
2022-11-11 23:17:25 +08:00

75 lines
2.9 KiB
Python

import time
from argparse import ArgumentParser
from copy import deepcopy
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
def data_gen(batch_size, shape, device='cuda'):
    """Build one synthetic (inputs, label) batch for benchmarking.

    Args:
        batch_size: number of samples in the batch.
        shape: per-sample input shape, e.g. ``(3, 224, 224)``.
        device: device on which both tensors are allocated.

    Returns:
        A pair ``((data,), label)``. ``data`` is an uninitialized tensor of the
        default dtype with shape ``(batch_size, *shape)``, wrapped in a 1-tuple
        of model inputs. ``label`` is a ``torch.long`` tensor of ``batch_size``
        class indices drawn uniformly from ``[0, 1000)``.
    """
    full_shape = (batch_size, *shape)
    # Contents are irrelevant for a throughput benchmark, so skip initialization.
    inputs = torch.empty(full_shape, device=device)
    targets = torch.empty(batch_size, dtype=torch.long, device=device)
    targets.random_(1000)
    return (inputs,), targets
def _resnet152_benchmark(rank, world_size, port, num_steps, batch_sizes=(512, 1024, 2048)):
    """Benchmark ResNet-152 throughput under the rotor activation-checkpoint solver.

    For each batch size, the traced model is re-traced with ``metainfo_trace`` on a
    meta-device input of shape ``(batch_size, 3, 224, 224)``. ``CheckpointSolverRotor``
    is then given a memory budget of 95% of the GPU memory currently reported *free*
    by ``torch.cuda.mem_get_info`` (not 95% of total device memory). The graph it
    returns replaces ``gm.graph``, and the result is benchmarked with ``bench``.
    Finally a summary is printed and a throughput-vs-batch-size plot is saved to
    ``resnet152_benchmark.png``.

    Args:
        rank: process rank, supplied positionally by ``mp.spawn``.
        world_size: total number of processes.
        port: TCP port used by ``colossalai.launch``.
        num_steps: number of steps passed to ``bench`` for every batch size.
        batch_sizes: batch sizes to benchmark, default ``(512, 1024, 2048)``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = tm.resnet152()
    gm = symbolic_trace(model)
    # Pristine copy so every batch size is solved starting from the un-checkpointed graph.
    raw_graph = deepcopy(gm.graph)
    batch_sizes = [int(bs) for bs in batch_sizes]
    peak_mems, through_puts = [], []
    for batch_size in batch_sizes:
        gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta'))
        # mem_get_info reports driver-level free memory, which excludes blocks still held by
        # PyTorch's caching allocator; release them so the budget is not shrunk by earlier runs.
        torch.cuda.empty_cache()
        # NOTE(review): the model is still on CPU at this point; confirm that the solver and/or
        # bench account for parameter memory once the model is moved to the GPU.
        free_memory = torch.cuda.mem_get_info()[0] * 0.95
        solver = CheckpointSolverRotor(gm.graph, free_memory=free_memory)
        gm.graph = solver.solve()
        peak_mem, step_time = bench(gm,
                                    torch.nn.CrossEntropyLoss(),
                                    partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)),
                                    num_steps=num_steps)
        peak_mems.append(peak_mem)
        # step_time is treated as milliseconds here (hence * 1e3 to report images/s).
        through_puts.append(batch_size / step_time * 1.0e3)
        # Restore the original graph before solving for the next batch size.
        gm.graph = deepcopy(raw_graph)

    # print results
    print("===============test summary================")
    for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts):
        print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s')

    # Use an explicit figure and close it afterwards so no pyplot state leaks.
    fig, ax = plt.subplots()
    ax.plot(batch_sizes, through_puts)
    ax.set_xlabel("batch size")
    ax.set_ylabel("through put (images/s)")
    ax.set_title("Resnet152 benchmark")
    fig.savefig("resnet152_benchmark.png")
    plt.close(fig)
def resnet152_benchmark(num_steps):
    """Launch the ResNet-152 benchmark worker in a single spawned process.

    Args:
        num_steps: number of benchmark steps forwarded to the worker.
    """
    nprocs = 1
    worker = partial(_resnet152_benchmark,
                     world_size=nprocs,
                     port=free_port(),
                     num_steps=num_steps)
    mp.spawn(worker, nprocs=nprocs)
if __name__ == "__main__":
    parser = ArgumentParser("ResNet152 Auto Activation Through Put Benchmark")
    parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
    args = parser.parse_args()
    # A benchmark needs at least one step; fail fast with a usage message instead of
    # spawning workers that cannot produce a meaningful measurement.
    if args.num_steps < 1:
        parser.error("--num_steps must be a positive integer")
    resnet152_benchmark(args.num_steps)