Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-23 07:39:31 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format

The hunks below are the mechanical result of that pre-commit run on the auto-parallel examples: black-style formatting (double quotes, trailing commas, exploded call sites) with no intended behavioral changes.
@@ -20,20 +20,22 @@ def _benchmark(rank, world_size, port):
     only result in minor performance drop. So at last we might be able to find better training batch size for our
     model (combine with large batch training optimizer such as LAMB).
     """
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     model = tm.resnet152()
     gm = symbolic_trace(model)
     raw_graph = deepcopy(gm.graph)
     peak_mems, through_puts, batch_sizes = [], [], [512, 1024, 2048]
     for batch_size in batch_sizes:
         batch_size = int(batch_size)
-        gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta'))
+        gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device="meta"))
         solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95)
         gm.graph = solver.solve()
-        peak_mem, step_time = bench(gm,
-                                    torch.nn.CrossEntropyLoss(),
-                                    partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)),
-                                    num_steps=5)
+        peak_mem, step_time = bench(
+            gm,
+            torch.nn.CrossEntropyLoss(),
+            partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)),
+            num_steps=5,
+        )
         peak_mems.append(peak_mem)
         through_puts.append(batch_size / step_time * 1.0e3)
         gm.graph = deepcopy(raw_graph)
@@ -41,7 +43,7 @@ def _benchmark(rank, world_size, port):
     # print results
     print("===============benchmark summary================")
     for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts):
-        print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s')
+        print(f"batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s")


 def auto_activation_checkpoint_batchsize_benchmark():
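Aside from the quote style, these hunks leave the example's logic untouched: trace the model, attach meta information for a given batch size, let the rotor solver insert activation checkpoints under a memory cap, then benchmark. A minimal standalone sketch of that pattern for one batch size (assuming a CUDA device; `bench` and `data_gen_resnet` are the `bench_utils` helpers shown later in this commit):

```python
from copy import deepcopy

import torch
import torchvision.models as tm

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace

model = tm.resnet152()
gm = symbolic_trace(model)      # torch.fx-style trace via ColossalAI
raw_graph = deepcopy(gm.graph)  # pristine copy, restored between runs
gm = metainfo_trace(gm, torch.empty(512, 3, 224, 224, device="meta"))
# cap the solver at 95% of the currently free GPU memory
solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95)
gm.graph = solver.solve()       # rewrite the graph with checkpoint annotations
```

Since `bench` reports the step time in milliseconds, the `batch_size / step_time * 1.0e3` expression in the loop converts it to images per second for the throughput summary.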
@@ -1,4 +1,3 @@
-import time
 from argparse import ArgumentParser
 from functools import partial

@@ -8,7 +7,6 @@ import torchvision.models as tm
 from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium

 import colossalai
 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace, symbolic_trace
 from colossalai.testing import spawn
@@ -19,37 +17,33 @@ def _benchmark(rank, world_size, port, args):
     The benchmark will sample in a range of memory budget for each model and output the benchmark summary and
     data visualization of peak memory vs. budget memory and relative step time vs. peak memory.
     """
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    if args.model == 'resnet50':
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    if args.model == "resnet50":
         model = tm.resnet50()
         data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224))
         gm = symbolic_trace(model)
-        gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta'))
+        gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device="meta"))
         loss = torch.nn.CrossEntropyLoss()
     else:
         model = gpt2_medium()
         data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257)
-        data, mask = data_gen(device='meta')[0]
-        gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
+        data, mask = data_gen(device="meta")[0]
+        gm = symbolic_trace(model, meta_args={"input_ids": data, "attention_mask": mask})
         gm = metainfo_trace(gm, data, mask)
         loss = GPTLMLoss()

-    free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2
-    start_factor = 4 if args.model == 'resnet50' else 10
+    free_memory = 11000 * 1024**2 if args.model == "resnet50" else 56000 * 1024**2
+    start_factor = 4 if args.model == "resnet50" else 10

     # trace and benchmark
-    budgets, peak_hist, step_hist = bench_rotor(gm,
-                                                loss,
-                                                data_gen,
-                                                num_steps=5,
-                                                sample_points=15,
-                                                free_memory=free_memory,
-                                                start_factor=start_factor)
+    budgets, peak_hist, step_hist = bench_rotor(
+        gm, loss, data_gen, num_steps=5, sample_points=15, free_memory=free_memory, start_factor=start_factor
+    )

     # print summary
     print("==============benchmark summary==============")
     for budget, peak, step in zip(budgets, peak_hist, step_hist):
-        print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS')
+        print(f"memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS")

     # plot valid results
     fig, axs = plt.subplots(1, 2, figsize=(16, 8))
@@ -57,14 +51,14 @@ def _benchmark(rank, world_size, port, args):

     # plot peak memory vs. budget memory
     axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
-    axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
+    axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle="--")
     axs[0].set_xlabel("Budget Memory (MB)")
     axs[0].set_ylabel("Peak Memory (MB)")
     axs[0].set_title("Peak Memory vs. Budget Memory")

     # plot relative step time vs. budget memory
     axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
-    axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
+    axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle="--")
     axs[1].set_xlabel("Peak Memory (MB)")
     axs[1].set_ylabel("Relative Step Time")
     axs[1].set_title("Step Time vs. Peak Memory")
@@ -81,7 +75,7 @@ def auto_activation_checkpoint_benchmark(args):

 if __name__ == "__main__":
     parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark")
-    parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50'])
+    parser.add_argument("--model", type=str, default="gpt2", choices=["gpt2", "resnet50"])
     args = parser.parse_args()

     auto_activation_checkpoint_benchmark(args)
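The body of `auto_activation_checkpoint_benchmark` is outside the hunks shown here. Given that `spawn` is imported from `colossalai.testing` and `_benchmark` takes `(rank, world_size, port, args)`, a plausible wrapper would look like the sketch below; this is a hypothetical reconstruction, and the exact `spawn` signature is an assumption:

```python
def auto_activation_checkpoint_benchmark(args):
    # Assumption: spawn starts the worker processes and fills in
    # rank, world_size, and a free port for the target function.
    spawn(_benchmark, 1, args=args)
```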
@@ -17,14 +17,14 @@ def synthesize_data():


 def main():
-    colossalai.launch_from_torch(config='./config.py')
+    colossalai.launch_from_torch(config="./config.py")

     logger = get_dist_logger()

     # trace the model with meta data
     model = resnet50(num_classes=10).cuda()

-    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
+    input_sample = {"x": torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to("meta")}
     device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)
     model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)

@@ -88,8 +88,9 @@ def main():

     logger.info(
         f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
-        ranks=[0])
+        ranks=[0],
+    )


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
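A note on running this script: `colossalai.launch_from_torch` reads rank and world size from the environment variables set by the PyTorch launcher, so it is meant to be started with something like `torchrun --nproc_per_node=4 train.py` (file name assumed). The `mesh_shape=[2, 2]` of the `DeviceMesh` accordingly requires exactly the four devices listed in `physical_mesh_id`.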
@@ -1,22 +1,19 @@
 import time
 from copy import deepcopy
 from functools import partial
 from typing import Callable, Tuple

 import numpy as np
 import torch
 import torch.nn as nn
 import torchvision.models as tm
 from transformers import GPT2Config, GPT2LMHeadModel

 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace


-def bench(gm: torch.fx.GraphModule,
-          criterion: torch.nn.Module,
-          data_gen: Callable,
-          num_steps: int = 5) -> Tuple[int, int]:
+def bench(
+    gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5
+) -> Tuple[int, int]:
     """Benchmarking a given graph module

     Args:
         gm (torch.fx.GraphModule): The graph module to benchmark.
@@ -28,7 +25,7 @@ def bench(gm: torch.fx.GraphModule,
     """
     gm.train()
     gm.cuda()
-    step_time = float('inf')
+    step_time = float("inf")
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
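The timed loop of `bench` sits between this hunk and the next, but the setup above and the `return` below pin down the idiom: reset the CUDA peak-memory counter, run a few synchronized steps, and keep the best time. A sketch under that assumption (the `run_step` closure is hypothetical and stands in for the real forward/backward):

```python
import time

import torch

def measure(run_step, num_steps=5):
    step_time = float("inf")
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    for _ in range(num_steps):
        start = time.time()
        run_step()                # forward + loss + backward
        torch.cuda.synchronize()  # drain queued kernels before stopping the clock
        step_time = min(step_time, time.time() - start)
    peak_mem = torch.cuda.max_memory_allocated() / 1024**2  # bytes -> MB
    return peak_mem, step_time * 1.0e3                      # seconds -> ms
```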
@@ -58,13 +55,15 @@ def bench(gm: torch.fx.GraphModule,
     return peak_mem, step_time * 1.0e3


-def bench_rotor(gm: torch.fx.GraphModule,
-                criterion: torch.nn.Module,
-                data_gen: Callable,
-                num_steps: int = 5,
-                sample_points: int = 20,
-                free_memory: int = torch.cuda.mem_get_info()[0],
-                start_factor: int = 4) -> Tuple[np.array, list, list]:
+def bench_rotor(
+    gm: torch.fx.GraphModule,
+    criterion: torch.nn.Module,
+    data_gen: Callable,
+    num_steps: int = 5,
+    sample_points: int = 20,
+    free_memory: int = torch.cuda.mem_get_info()[0],
+    start_factor: int = 4,
+) -> Tuple[np.array, list, list]:
     """Auto Checkpoint Rotor Algorithm benchmarking

     Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.

     Args:
@@ -88,7 +87,7 @@ def bench_rotor(gm: torch.fx.GraphModule,
             gm.graph = solver.solve(verbose=False)
             peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
         except:
-            peak_memory, step_time = budget / 1024**2, float('inf')
+            peak_memory, step_time = budget / 1024**2, float("inf")
         peak_hist.append(peak_memory)
         step_hist.append(step_time)
         gm.graph = deepcopy(raw_graph)
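Only the failure branch of `bench_rotor`'s sweep is visible in the hunk just above. Consistent with `sample_points`, `start_factor`, and the visible loop body, the surrounding code presumably samples budgets between `free_memory / start_factor` and `free_memory`, roughly as in this hypothetical reconstruction (names refer to `bench_rotor`'s parameters):

```python
from copy import deepcopy

import numpy as np

budgets = np.linspace(free_memory // start_factor, free_memory, sample_points)
raw_graph = deepcopy(gm.graph)  # restore point between budgets
peak_hist, step_hist = [], []
for budget in budgets:
    try:
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
        gm.graph = solver.solve(verbose=False)
        peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
    except Exception:
        # no feasible checkpoint schedule (or OOM) at this budget
        peak_memory, step_time = budget / 1024**2, float("inf")
    peak_hist.append(peak_memory)
    step_hist.append(step_time)
    gm.graph = deepcopy(raw_graph)
```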
@@ -100,22 +99,27 @@ class GPTLMModel(nn.Module):
     GPT Model
     """

-    def __init__(self,
-                 hidden_size=768,
-                 num_layers=12,
-                 num_attention_heads=12,
-                 max_seq_len=1024,
-                 vocab_size=50257,
-                 checkpoint=False):
+    def __init__(
+        self,
+        hidden_size=768,
+        num_layers=12,
+        num_attention_heads=12,
+        max_seq_len=1024,
+        vocab_size=50257,
+        checkpoint=False,
+    ):
         super().__init__()
         self.checkpoint = checkpoint
         self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
-                       n_layer=num_layers,
-                       n_head=num_attention_heads,
-                       n_positions=max_seq_len,
-                       n_ctx=max_seq_len,
-                       vocab_size=vocab_size))
+            GPT2Config(
+                n_embd=hidden_size,
+                n_layer=num_layers,
+                n_head=num_attention_heads,
+                n_positions=max_seq_len,
+                n_ctx=max_seq_len,
+                vocab_size=vocab_size,
+            )
+        )
         if checkpoint:
             self.model.gradient_checkpointing_enable()

@@ -152,7 +156,7 @@ def gpt2_6b(checkpoint=False):
     return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)


-def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
+def data_gen_gpt2(batch_size, seq_len, vocab_size, device="cuda:0"):
     """
     Generate random data for gpt2 benchmarking
     """
@@ -161,7 +165,7 @@ def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
     return (input_ids, attention_mask), attention_mask


-def data_gen_resnet(batch_size, shape, device='cuda:0'):
+def data_gen_resnet(batch_size, shape, device="cuda:0"):
     """
     Generate random data for resnet benchmarking
     """
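The generator bodies are elided by the diff. From the shown return statement of `data_gen_gpt2` and the `CrossEntropyLoss` target paired with `data_gen_resnet`, plausible implementations would be the following (hypothetical sketches; the class count of 1000 is an assumption):

```python
import torch

def data_gen_gpt2(batch_size, seq_len, vocab_size, device="cuda:0"):
    # random token ids with an all-ones attention mask
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)
    return (input_ids, attention_mask), attention_mask

def data_gen_resnet(batch_size, shape, device="cuda:0"):
    # random images with random integer class labels
    data = torch.randn(batch_size, *shape, device=device)
    label = torch.randint(0, 1000, (batch_size,), device=device)
    return (data,), label
```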
@@ -1,13 +1,13 @@
 from setuptools import find_packages, setup

 setup(
-    name='auto_parallel',
-    version='0.0.1',
-    description='',
+    name="auto_parallel",
+    version="0.0.1",
+    description="",
     packages=find_packages(),
     install_requires=[
-        'torch',
-        'numpy',
-        'tqdm',
+        "torch",
+        "numpy",
+        "tqdm",
     ],
 )
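This minimal `setup.py` exists so the example can be installed in place: from the package directory, `pip install -e .` makes the `auto_parallel` package importable while keeping the source editable.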