mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-25 18:13:56 +00:00 
			
		
		
		
	* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
		
			
				
	
	
		
			175 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			175 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import time
from copy import deepcopy
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace
 | |
| 
 | |
| 
 | |
def bench(
    gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5
) -> Tuple[float, float]:
    """Benchmark forward + backward of a graph module on CUDA.

    The module is moved to the GPU, run for ``num_steps`` steps and moved back
    to the CPU afterwards. Benchmarking is best effort: if a step raises (for
    example because the device runs out of memory) the loop stops and whatever
    has been measured so far is returned.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function, called as ``criterion(output, label)``.
        data_gen (Callable): Zero-argument callable returning ``(args, label)``;
            ``args`` is unpacked into ``gm``.
        num_steps (int, optional): Number of test steps. Defaults to 5.

    Returns:
        Tuple[float, float]: peak memory in MB (relative to the peak recorded
        right after the stats reset) and the fastest step time in ms. The step
        time is ``inf`` when no step completed.
    """
    gm.train()
    gm.cuda()
    step_time = float("inf")
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Baseline taken after the reset so the model's own footprint is excluded.
    cached = torch.cuda.max_memory_allocated(device="cuda")

    # Bind the names up front so the cleanup below never touches an unbound local,
    # even if the very first ``data_gen()`` call fails.
    args = label = output = loss = None
    try:
        for _ in range(num_steps):
            args, label = data_gen()

            # Synchronize on both sides so only this step's GPU work is timed.
            torch.cuda.synchronize(device="cuda")
            start = time.perf_counter()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.perf_counter() - start)

            # Drop gradients of every parameter (including ones registered
            # directly on ``gm``) so they do not accumulate across steps.
            for param in gm.parameters():
                param.grad = None
            args = label = output = loss = None
    except Exception:
        # Deliberate best effort: a failing step ends the benchmark but the
        # partial measurements are still reported. KeyboardInterrupt and
        # SystemExit are not swallowed.
        pass
    finally:
        # Release tensor references and any partial gradients before the cache is emptied.
        args = label = output = loss = None
        for param in gm.parameters():
            param.grad = None
    gm.to("cpu")
    torch.cuda.empty_cache()
    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
    return peak_mem, step_time * 1.0e3
 | |
| 
 | |
| 
 | |
def bench_rotor(
    gm: torch.fx.GraphModule,
    criterion: torch.nn.Module,
    data_gen: Callable,
    num_steps: int = 5,
    sample_points: int = 20,
    free_memory: Optional[int] = None,
    start_factor: int = 4,
) -> Tuple[np.array, list, list]:
    """Auto Checkpoint Rotor Algorithm benchmarking

    Sweeps a range of memory budgets, solves a checkpoint schedule with
    ``CheckpointSolverRotor`` for each budget and benchmarks the resulting graph.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
        data_gen (Callable): Data generator returning ``(args, label)``.
        num_steps (int, optional): Number of test steps. Defaults to 5.
        sample_points (int, optional): Number of sample points. Defaults to 20.
        free_memory (int, optional): Max memory budget in Byte. Defaults to None,
            in which case ``torch.cuda.mem_get_info()[0]`` is queried when the
            function is called (not when the module is imported).
        start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget
            will be free_memory / start_factor. Defaults to 4.

    Returns:
        Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS).
        For a budget whose solve/benchmark raised, the budget itself (MB) is
        recorded as peak memory and ``inf`` as step time.
    """
    if free_memory is None:
        # Resolved lazily so importing this module does not require CUDA and
        # the budget reflects the memory available at call time.
        free_memory = torch.cuda.mem_get_info()[0]

    budgets = np.linspace(free_memory // start_factor, free_memory, sample_points)
    peak_hist, step_hist = [], []
    # Keep a pristine copy: the solver's graph replaces gm.graph on every iteration.
    raw_graph = deepcopy(gm.graph)
    for budget in budgets:
        # NOTE(review): metainfo_trace presumably annotates the graph with the
        # meta information the solver consumes; confirm against colossalai.fx.
        gm = metainfo_trace(gm, *data_gen()[0])
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
        try:
            gm.graph = solver.solve(verbose=False)
            peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
        except Exception:
            # This budget could not be solved or benchmarked; record it as failed and keep sweeping.
            peak_memory, step_time = budget / 1024**2, float("inf")
        peak_hist.append(peak_memory)
        step_hist.append(step_time)
        gm.graph = deepcopy(raw_graph)
    return budgets / 1024**2, peak_hist, step_hist
 | |
| 
 | |
| 
 | |
class GPTLMModel(nn.Module):
    """
    Thin wrapper around ``GPT2LMHeadModel`` whose forward pass returns only
    the first element of the wrapped model's output.
    """

    def __init__(
        self,
        hidden_size=768,
        num_layers=12,
        num_attention_heads=12,
        max_seq_len=1024,
        vocab_size=50257,
        checkpoint=False,
    ):
        super().__init__()
        self.checkpoint = checkpoint
        # Translate the wrapper's argument names into GPT2Config fields.
        config = GPT2Config(
            n_embd=hidden_size,
            n_layer=num_layers,
            n_head=num_attention_heads,
            n_positions=max_seq_len,
            n_ctx=max_seq_len,
            vocab_size=vocab_size,
        )
        self.model = GPT2LMHeadModel(config)
        if self.checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # The cache is requested only when gradient checkpointing is off.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=not self.checkpoint,
        )
        # Only return lm_logits
        return outputs[0]
 | |
| 
 | |
| 
 | |
class GPTLMLoss(nn.Module):
    """
    Cross-entropy loss between the logits at each position and the label
    at the following position.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        num_classes = logits.size(-1)
        # Drop the last position of the logits and the first position of the
        # labels so that position i is scored against label i + 1.
        predictions = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()
        # Flatten the tokens
        flat_predictions = predictions.view(-1, num_classes)
        flat_targets = targets.view(-1)
        return self.loss_fn(flat_predictions, flat_targets)
 | |
| 
 | |
| 
 | |
def gpt2_medium(checkpoint=False):
    """Build a ``GPTLMModel`` with hidden size 1024, 24 layers and 16 attention heads."""
    sizes = dict(hidden_size=1024, num_layers=24, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **sizes)
 | |
| 
 | |
| 
 | |
def gpt2_xl(checkpoint=False):
    """Build a ``GPTLMModel`` with hidden size 1600, 48 layers and 32 attention heads."""
    sizes = dict(hidden_size=1600, num_layers=48, num_attention_heads=32)
    return GPTLMModel(checkpoint=checkpoint, **sizes)
 | |
| 
 | |
| 
 | |
def gpt2_6b(checkpoint=False):
    """Build a ``GPTLMModel`` with hidden size 4096, 30 layers and 16 attention heads."""
    sizes = dict(hidden_size=4096, num_layers=30, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **sizes)
 | |
| 
 | |
| 
 | |
def data_gen_gpt2(batch_size, seq_len, vocab_size, device="cuda:0"):
    """
    Generate random data for gpt2 benchmarking

    Returns ``((input_ids, attention_mask), label)`` where ``input_ids`` holds
    random token ids in ``[0, vocab_size)`` of shape ``(batch_size, seq_len)``,
    ``attention_mask`` is all ones, and the label is the very same
    ``attention_mask`` tensor.
    """
    token_shape = (batch_size, seq_len)
    input_ids = torch.randint(0, vocab_size, token_shape, device=device)
    attention_mask = torch.ones_like(input_ids, device=device)
    model_inputs = (input_ids, attention_mask)
    # The mask doubles as the label tensor.
    return model_inputs, attention_mask
 | |
| 
 | |
| 
 | |
def data_gen_resnet(batch_size, shape, device="cuda:0", num_classes=1000):
    """
    Generate random data for resnet benchmarking

    Args:
        batch_size (int): Number of samples in the batch.
        shape (tuple): Shape of a single sample; the data tensor has shape
            ``(batch_size, *shape)``.
        device (str, optional): Device the tensors are created on. Defaults to "cuda:0".
        num_classes (int, optional): Labels are drawn from ``[0, num_classes)``.
            Defaults to 1000.

    Returns:
        ``((data,), label)``: ``data`` is uniform in ``[0, 1)``; ``label`` is an
        int64 tensor of shape ``(batch_size,)``.
    """
    # torch.rand rather than torch.empty: uninitialized memory may contain
    # NaN/Inf, which would propagate through the forward and backward pass.
    data = torch.rand(batch_size, *shape, device=device)
    label = torch.randint(0, num_classes, (batch_size,), dtype=torch.long, device=device)
    return (data,), label
 |