[auto-parallel] add auto-offload feature (#3154)

* add auto-offload feature

* polish code

* fix syn offload runtime pass bug

* add offload example

* fix offload testing bug

* fix example testing bug
This commit is contained in:
Zihao
2023-03-21 14:17:41 +08:00
committed by GitHub
parent 258b43317c
commit 18dbe76cae
18 changed files with 2833 additions and 0 deletions

View File

@@ -0,0 +1,37 @@
# Auto-Offload Demo with GPT2
## Requirements
Before you can launch training, you need to install the following requirements.
### Install PyTorch
```bash
#conda
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
#pip
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website
```bash
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```
### Install transformers
```bash
pip install transformers
```
## Dataset
For simplicity, the input data is randonly generated here.
## Training
```bash
#Run the auto offload on GPT with default setting and a dummy dataset.
bash run.sh
```

View File

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
class GPTLMModel(nn.Module):
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50257):
super().__init__()
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
def get_gpt2_components(model_type: str, batch_size: int):
vocab_size = 1024
seq_len = 8
def gpt2_model_builder():
if model_type == "gpt2_medium":
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16)
elif model_type == "gpt2_xl":
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32)
elif model_type == "gpt2_10b":
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16)
elif model_type == "gpt2_14b":
return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16)
elif model_type == "gpt2_20b":
return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16)
elif model_type == "gpt2_24b":
return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16)
else:
raise TypeError(f"model_builder {model_type}")
def gpt2_data_gen(device="cuda"):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
return gpt2_model_builder, gpt2_data_gen

View File

@@ -0,0 +1,2 @@
colossalai >= 0.1.12
torch >= 1.8.1

View File

@@ -0,0 +1,8 @@
export BATCH_SIZE=${BATCH_SIZE:-64}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export MEMORY_BUDGET=${MEMORY_BUDGET:-16}
export SOLVER_TYPE=${SOLVER_TYPE:-"asyn"}
mkdir -p offload_logs
python train_gpt_offload.py --model_type=${MODEL_TYPE} --memory_budget=${MEMORY_BUDGET} --solver_type=${SOLVER_TYPE} --batch_size=${BATCH_SIZE} 2>&1 | tee ./offload_logs/${MODEL_TYPE}_bs_${BATCH_SIZE}_st_${SOLVER_TYPE}.log

View File

@@ -0,0 +1,94 @@
import time
import pytest
import argparse
from functools import partial
import torch
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port, get_current_device
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from model_zoo import get_gpt2_components, GPTLMLoss
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', type=str, default="gpt2_medium")
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--solver_type', type=str, default='asyn')
parser.add_argument('--memory_budget', type=float, default=16)
return parser.parse_args()
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def train_gpt(args):
memory_budget = args.memory_budget * 1024 * 1024 * 1024
solver_type = args.solver_type
model_type = args.model_type
batch_size = args.batch_size
# build model
model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
criterion = GPTLMLoss()
start_time = time.time()
model = model_builder()
model.train()
param_size = parameter_size(model) / 1024 ** 2 / 2
init_time = time.time() - start_time
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
data_args = data_gen(device="cpu")
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
data_args = tree_map(wrap_fn, data_args)
start_time = time.time()
model = memory_optimize(model, data_args, memory_budget, solver_type)
solver_time = time.time() - start_time
print(f"solver_time={solver_time:.3f} s")
hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
optim = AMPOptimizer(hybrid_optimizer, model)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
time_list = []
data_args = data_gen(device="cuda")
data_args = tree_map(wrap_fn, data_args)
for step in range(10):
optim.zero_grad()
torch.cuda.synchronize()
start_time = time.time()
loss = criterion(model(**data_args), label)
optim.backward(loss)
torch.cuda.synchronize()
time_list.append(time.time() - start_time)
optim.step()
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
print(f'solver_type: {solver_type} | model_type: {model_type}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(time_list)
def run(rank, world_size, port, args):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
train_gpt(args)
if __name__ == '__main__':
args = parse_args()
run_func = partial(run, world_size=1, port=free_port(), args=args)
mp.spawn(run_func, nprocs=1)