mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[auto-parallel] add auto-offload feature (#3154)
* add auto-offload feature * polish code * fix syn offload runtime pass bug * add offload example * fix offload testing bug * fix example testing bug
This commit is contained in:
37
examples/language/gpt/experiments/auto_offload/README.md
Normal file
37
examples/language/gpt/experiments/auto_offload/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Auto-Offload Demo with GPT2
|
||||
|
||||
## Requirements
|
||||
|
||||
Before you can launch training, you need to install the following requirements.
|
||||
|
||||
### Install PyTorch
|
||||
|
||||
```bash
|
||||
#conda
|
||||
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
|
||||
#pip
|
||||
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
```
|
||||
|
||||
### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website
|
||||
|
||||
```bash
|
||||
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
|
||||
```
|
||||
|
||||
### Install transformers
|
||||
|
||||
```bash
|
||||
pip install transformers
|
||||
```
|
||||
|
||||
## Dataset
|
||||
|
||||
For simplicity, the input data is randonly generated here.
|
||||
|
||||
## Training
|
||||
|
||||
```bash
|
||||
#Run the auto offload on GPT with default setting and a dummy dataset.
|
||||
bash run.sh
|
||||
```
|
65
examples/language/gpt/experiments/auto_offload/model_zoo.py
Normal file
65
examples/language/gpt/experiments/auto_offload/model_zoo.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50257):
|
||||
super().__init__()
|
||||
self.model = GPT2LMHeadModel(
|
||||
GPT2Config(n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size))
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
def get_gpt2_components(model_type: str, batch_size: int):
|
||||
vocab_size = 1024
|
||||
seq_len = 8
|
||||
|
||||
def gpt2_model_builder():
|
||||
if model_type == "gpt2_medium":
|
||||
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16)
|
||||
elif model_type == "gpt2_xl":
|
||||
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32)
|
||||
elif model_type == "gpt2_10b":
|
||||
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16)
|
||||
elif model_type == "gpt2_14b":
|
||||
return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16)
|
||||
elif model_type == "gpt2_20b":
|
||||
return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16)
|
||||
elif model_type == "gpt2_24b":
|
||||
return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16)
|
||||
else:
|
||||
raise TypeError(f"model_builder {model_type}")
|
||||
|
||||
def gpt2_data_gen(device="cuda"):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
attention_mask = torch.ones_like(input_ids, device=device)
|
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return kwargs
|
||||
|
||||
return gpt2_model_builder, gpt2_data_gen
|
@@ -0,0 +1,2 @@
|
||||
colossalai >= 0.1.12
|
||||
torch >= 1.8.1
|
8
examples/language/gpt/experiments/auto_offload/run.sh
Normal file
8
examples/language/gpt/experiments/auto_offload/run.sh
Normal file
@@ -0,0 +1,8 @@
|
||||
export BATCH_SIZE=${BATCH_SIZE:-64}
|
||||
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
|
||||
export MEMORY_BUDGET=${MEMORY_BUDGET:-16}
|
||||
export SOLVER_TYPE=${SOLVER_TYPE:-"asyn"}
|
||||
|
||||
mkdir -p offload_logs
|
||||
|
||||
python train_gpt_offload.py --model_type=${MODEL_TYPE} --memory_budget=${MEMORY_BUDGET} --solver_type=${SOLVER_TYPE} --batch_size=${BATCH_SIZE} 2>&1 | tee ./offload_logs/${MODEL_TYPE}_bs_${BATCH_SIZE}_st_${SOLVER_TYPE}.log
|
@@ -0,0 +1,94 @@
|
||||
import time
|
||||
import pytest
|
||||
import argparse
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
|
||||
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
|
||||
from colossalai.auto_parallel.offload.solver import NOT_NVML
|
||||
from model_zoo import get_gpt2_components, GPTLMLoss
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_type', type=str, default="gpt2_medium")
|
||||
parser.add_argument('--batch_size', type=int, default=64)
|
||||
parser.add_argument('--solver_type', type=str, default='asyn')
|
||||
parser.add_argument('--memory_budget', type=float, default=16)
|
||||
return parser.parse_args()
|
||||
|
||||
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
|
||||
def train_gpt(args):
|
||||
memory_budget = args.memory_budget * 1024 * 1024 * 1024
|
||||
solver_type = args.solver_type
|
||||
model_type = args.model_type
|
||||
batch_size = args.batch_size
|
||||
|
||||
# build model
|
||||
model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
|
||||
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
|
||||
criterion = GPTLMLoss()
|
||||
|
||||
start_time = time.time()
|
||||
model = model_builder()
|
||||
model.train()
|
||||
param_size = parameter_size(model) / 1024 ** 2 / 2
|
||||
init_time = time.time() - start_time
|
||||
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
|
||||
|
||||
data_args = data_gen(device="cpu")
|
||||
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
|
||||
data_args = tree_map(wrap_fn, data_args)
|
||||
start_time = time.time()
|
||||
model = memory_optimize(model, data_args, memory_budget, solver_type)
|
||||
solver_time = time.time() - start_time
|
||||
print(f"solver_time={solver_time:.3f} s")
|
||||
|
||||
hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
|
||||
optim = AMPOptimizer(hybrid_optimizer, model)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
time_list = []
|
||||
data_args = data_gen(device="cuda")
|
||||
data_args = tree_map(wrap_fn, data_args)
|
||||
for step in range(10):
|
||||
optim.zero_grad()
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
loss = criterion(model(**data_args), label)
|
||||
optim.backward(loss)
|
||||
torch.cuda.synchronize()
|
||||
time_list.append(time.time() - start_time)
|
||||
optim.step()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
exec_time = sum(sorted(time_list)[:5]) / 5
|
||||
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
|
||||
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
|
||||
print(f'solver_type: {solver_type} | model_type: {model_type}')
|
||||
print(
|
||||
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
|
||||
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
|
||||
)
|
||||
print(time_list)
|
||||
|
||||
def run(rank, world_size, port, args):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
train_gpt(args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
run_func = partial(run, world_size=1, port=free_port(), args=args)
|
||||
mp.spawn(run_func, nprocs=1)
|
Reference in New Issue
Block a user