Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-30 22:24:21 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,22 +1,19 @@
 import time
 from copy import deepcopy
 from functools import partial
 from typing import Callable, Tuple
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torchvision.models as tm
 from transformers import GPT2Config, GPT2LMHeadModel
 
 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace
 
 
-def bench(gm: torch.fx.GraphModule,
-          criterion: torch.nn.Module,
-          data_gen: Callable,
-          num_steps: int = 5) -> Tuple[int, int]:
+def bench(
+    gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5
+) -> Tuple[int, int]:
     """Benchmarking a given graph module
 
     Args:
         gm (torch.fx.GraphModule): The graph module to benchmark.
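The reflowed `bench` signature above takes a traced `torch.fx.GraphModule`, a loss criterion, and a data generator. A minimal usage sketch, assuming the `functools.partial` import at the top of the file is how `data_gen` gets bound (the exact call pattern is an assumption, not shown in this diff):

    from functools import partial

    import torch.nn as nn
    import torchvision.models as tm
    from torch.fx import symbolic_trace

    gm = symbolic_trace(tm.resnet50())                     # trace the model into a GraphModule
    criterion = nn.CrossEntropyLoss()
    data_gen = partial(data_gen_resnet, 8, (3, 224, 224))  # helper defined later in this file
    peak_mem, step_time = bench(gm, criterion, data_gen)   # step_time is in ms, per the * 1.0e3 below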
@@ -28,7 +25,7 @@ def bench(gm: torch.fx.GraphModule,
     """
     gm.train()
     gm.cuda()
-    step_time = float('inf')
+    step_time = float("inf")
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
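The three `torch.cuda` calls above are the standard setup for clean measurements: drain pending kernels, drop cached allocator blocks, and zero the memory high-water mark before the timed region. A self-contained sketch of that pattern (illustrative, not taken from the hidden lines of this file):

    import time

    import torch

    def timed_region(fn):
        torch.cuda.synchronize()              # drain previously queued kernels
        torch.cuda.empty_cache()              # release cached blocks so the peak is meaningful
        torch.cuda.reset_peak_memory_stats()  # zero the high-water mark
        start = time.time()
        fn()
        torch.cuda.synchronize()              # wait for async kernels before stopping the clock
        peak_mb = torch.cuda.max_memory_allocated() / 1024**2
        return peak_mb, (time.time() - start) * 1.0e3  # elapsed time in ms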
@@ -58,13 +55,15 @@ def bench(gm: torch.fx.GraphModule,
     return peak_mem, step_time * 1.0e3
 
 
-def bench_rotor(gm: torch.fx.GraphModule,
-                criterion: torch.nn.Module,
-                data_gen: Callable,
-                num_steps: int = 5,
-                sample_points: int = 20,
-                free_memory: int = torch.cuda.mem_get_info()[0],
-                start_factor: int = 4) -> Tuple[np.array, list, list]:
+def bench_rotor(
+    gm: torch.fx.GraphModule,
+    criterion: torch.nn.Module,
+    data_gen: Callable,
+    num_steps: int = 5,
+    sample_points: int = 20,
+    free_memory: int = torch.cuda.mem_get_info()[0],
+    start_factor: int = 4,
+) -> Tuple[np.array, list, list]:
     """Auto Checkpoint Rotor Algorithm benchmarking
     Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.
     Args:
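`free_memory` defaults to the bytes currently free on the device (`torch.cuda.mem_get_info()[0]`), while `start_factor` and `sample_points` bound the sweep. A plausible sketch of how these parameters turn into memory budgets (the actual loop is outside this hunk, so the exact form is an assumption):

    import numpy as np
    import torch

    free_memory = torch.cuda.mem_get_info()[0]  # free bytes on the current device
    start_factor, sample_points = 4, 20
    # sweep from a quarter of free memory up to all of it, in 20 evenly spaced budgets
    budgets = np.linspace(free_memory // start_factor, free_memory, sample_points)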
@@ -88,7 +87,7 @@ def bench_rotor(gm: torch.fx.GraphModule,
             gm.graph = solver.solve(verbose=False)
             peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
         except:
-            peak_memory, step_time = budget / 1024**2, float('inf')
+            peak_memory, step_time = budget / 1024**2, float("inf")
         peak_hist.append(peak_memory)
         step_hist.append(step_time)
         gm.graph = deepcopy(raw_graph)
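The `except:` fallback above records the budget (converted to MB) with an infinite step time whenever the solver or the benchmark run fails at that budget. A hedged reconstruction of the loop body this hunk sits in; the `CheckpointSolverRotor(gm.graph, free_memory=budget)` construction is an assumption, not visible in the diff:

    peak_hist, step_hist = [], []
    raw_graph = deepcopy(gm.graph)
    for budget in budgets:
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)  # assumed constructor
        try:
            gm.graph = solver.solve(verbose=False)
            peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
        except:
            peak_memory, step_time = budget / 1024**2, float("inf")
        peak_hist.append(peak_memory)
        step_hist.append(step_time)
        gm.graph = deepcopy(raw_graph)  # assigning .graph makes the GraphModule recompile itself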
@@ -100,22 +99,27 @@ class GPTLMModel(nn.Module):
     GPT Model
     """
 
-    def __init__(self,
-                 hidden_size=768,
-                 num_layers=12,
-                 num_attention_heads=12,
-                 max_seq_len=1024,
-                 vocab_size=50257,
-                 checkpoint=False):
+    def __init__(
+        self,
+        hidden_size=768,
+        num_layers=12,
+        num_attention_heads=12,
+        max_seq_len=1024,
+        vocab_size=50257,
+        checkpoint=False,
+    ):
         super().__init__()
         self.checkpoint = checkpoint
         self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
-                       n_layer=num_layers,
-                       n_head=num_attention_heads,
-                       n_positions=max_seq_len,
-                       n_ctx=max_seq_len,
-                       vocab_size=vocab_size))
+            GPT2Config(
+                n_embd=hidden_size,
+                n_layer=num_layers,
+                n_head=num_attention_heads,
+                n_positions=max_seq_len,
+                n_ctx=max_seq_len,
+                vocab_size=vocab_size,
+            )
+        )
         if checkpoint:
             self.model.gradient_checkpointing_enable()
 
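With its defaults (768 hidden units, 12 layers, 12 heads, 1024 positions, 50257-token vocabulary) the wrapper above builds a randomly initialized GPT-2-small-sized model. A quick sketch of driving the wrapped `GPT2LMHeadModel` directly; the wrapper's own `forward` is outside this hunk, so it is bypassed here:

    import torch

    model = GPTLMModel(checkpoint=True)  # checkpoint=True enables HF gradient checkpointing
    input_ids = torch.randint(0, 50257, (2, 128))
    attention_mask = torch.ones_like(input_ids)
    out = model.model(input_ids=input_ids, attention_mask=attention_mask)
    print(out.logits.shape)              # torch.Size([2, 128, 50257])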
@@ -152,7 +156,7 @@ def gpt2_6b(checkpoint=False):
     return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
 
 
-def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
+def data_gen_gpt2(batch_size, seq_len, vocab_size, device="cuda:0"):
     """
     Generate random data for gpt2 benchmarking
     """
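The body of `data_gen_gpt2` is elided by the diff, but the return statement visible as context in the next hunk pins down its shape. A sketch of what the hidden lines presumably contain (a reconstruction, not the committed code):

    import torch

    def data_gen_gpt2_sketch(batch_size, seq_len, vocab_size, device="cuda:0"):
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        attention_mask = torch.ones_like(input_ids)
        # the inputs tuple feeds the model; the mask doubles as a dummy target for the criterion
        return (input_ids, attention_mask), attention_mask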
@@ -161,7 +165,7 @@ def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
     return (input_ids, attention_mask), attention_mask
 
 
-def data_gen_resnet(batch_size, shape, device='cuda:0'):
+def data_gen_resnet(batch_size, shape, device="cuda:0"):
     """
     Generate random data for resnet benchmarking
     """
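`data_gen_resnet`'s body is likewise hidden. A sketch under the assumption that it mirrors `data_gen_gpt2`: random image tensors of the requested `shape` plus integer class labels (the 1000-class label range is an assumption based on ImageNet-sized ResNets):

    import torch

    def data_gen_resnet_sketch(batch_size, shape, device="cuda:0"):
        data = torch.randn(batch_size, *shape, device=device)
        labels = torch.randint(0, 1000, (batch_size,), device=device)  # assumed ImageNet classes
        return (data,), labels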