mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-01 05:05:08 +00:00
* [sc] SC tutorial for auto checkpoint * [sc] polish examples * [sc] polish readme * [sc] polish readme and help information * [sc] polish readme and help information
75 lines
2.9 KiB
Python
75 lines
2.9 KiB
Python
import time
|
|
from argparse import ArgumentParser
|
|
from copy import deepcopy
|
|
from functools import partial
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
import torch.multiprocessing as mp
|
|
import torchvision.models as tm
|
|
from bench_utils import bench
|
|
|
|
import colossalai
|
|
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
|
|
from colossalai.fx import metainfo_trace, symbolic_trace
|
|
from colossalai.utils import free_port
|
|
|
|
|
|
def data_gen(batch_size, shape, device='cuda'):
    """
    Build one synthetic batch for benchmarking.

    Args:
        batch_size: number of samples in the batch (leading dimension).
        shape: per-sample dimensions; unpacked after ``batch_size`` to form
            the full input shape.
        device: device on which both tensors are allocated.

    Returns:
        A pair ``((inputs,), targets)``: a one-element tuple holding the input
        tensor of shape ``(batch_size, *shape)``, and an integer label tensor
        of shape ``(batch_size,)`` with values drawn from ``[0, 1000)``.
    """
    # The input values are left uninitialised (torch.empty); only the labels
    # are filled in, with random class ids.
    inputs = torch.empty((batch_size, *shape), device=device)
    targets = torch.empty(batch_size, dtype=torch.long, device=device)
    targets.random_(1000)
    return (inputs,), targets
|
|
|
|
|
|
def _resnet152_benchmark(rank, world_size, port, num_steps):
    """Benchmark ResNet152 throughput under the activation checkpoint solver.

    For each batch size in [512, 1024, 2048] the traced model is annotated with
    meta information, handed to ``CheckpointSolverRotor`` with a memory budget
    of 95% of the GPU memory reported as free by ``torch.cuda.mem_get_info``,
    and then timed with ``bench``. A summary is printed and a throughput plot
    is written to ``resnet152_benchmark.png``.

    Args:
        rank: rank of this process, forwarded to ``colossalai.launch``.
        world_size: total number of processes, forwarded to ``colossalai.launch``.
        port: port on ``localhost`` used by ``colossalai.launch``.
        num_steps: number of steps passed to ``bench`` for every batch size.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    gm = symbolic_trace(tm.resnet152())
    # Keep an untouched copy of the traced graph so every batch size is solved
    # from the same starting point.
    pristine_graph = deepcopy(gm.graph)

    batch_sizes = [512, 1024, 2048]
    peak_mems = []
    through_puts = []

    for bs in batch_sizes:
        bs = int(bs)

        meta_input = torch.empty(bs, 3, 224, 224, device='meta')
        gm = metainfo_trace(gm, meta_input)

        # Element 0 of mem_get_info() is the free device memory in bytes.
        memory_budget = torch.cuda.mem_get_info()[0] * 0.95
        gm.graph = CheckpointSolverRotor(gm.graph, free_memory=memory_budget).solve()

        peak_mem, step_time = bench(
            gm,
            torch.nn.CrossEntropyLoss(),
            partial(data_gen, batch_size=bs, shape=(3, 224, 224)),
            num_steps=num_steps,
        )
        peak_mems.append(peak_mem)
        # NOTE(review): the 1.0e3 factor presumably converts a step time in
        # milliseconds to images per second -- confirm against bench().
        through_puts.append(bs / step_time * 1.0e3)

        # Discard the graph returned by the solver before the next round.
        gm.graph = deepcopy(pristine_graph)

    # print results
    print("===============test summary================")
    for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts):
        print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s')

    # Plot throughput against batch size and save the figure to disk.
    plt.plot(batch_sizes, through_puts)
    plt.xlabel("batch size")
    plt.ylabel("through put (images/s)")
    plt.title("Resnet152 benchmark")
    plt.savefig("resnet152_benchmark.png")
|
|
|
|
|
|
def resnet152_benchmark(num_steps):
    """Run the ResNet152 benchmark in a single spawned worker process.

    Args:
        num_steps: number of benchmark steps, forwarded to
            ``_resnet152_benchmark``.
    """
    world_size = 1
    # world_size, port and num_steps are bound here; the remaining positional
    # argument of the worker (rank) is left for mp.spawn to supply.
    worker = partial(
        _resnet152_benchmark,
        world_size=world_size,
        port=free_port(),
        num_steps=num_steps,
    )
    mp.spawn(worker, nprocs=world_size)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the number of benchmark steps and run
    # the ResNet152 benchmark with it.
    cli = ArgumentParser("ResNet152 Auto Activation Through Put Benchmark")
    cli.add_argument(
        "--num_steps",
        type=int,
        default=5,
        help="number of test steps for benchmark, default 5",
    )
    options = cli.parse_args()

    resnet152_benchmark(options.num_steps)
|