[auto-parallel] add auto-offload feature (#3154)

* add auto-offload feature

* polish code

* fix syn offload runtime pass bug

* add offload example

* fix offload testing bug

* fix example testing bug
This commit is contained in:
Zihao
2023-03-21 14:17:41 +08:00
committed by GitHub
parent 258b43317c
commit 18dbe76cae
18 changed files with 2833 additions and 0 deletions

View File

@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import BertConfig, BertLMHeadModel
from tests.components_to_test.registry import non_distributed_component_funcs
class GPTLMModel(nn.Module):
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50257):
super().__init__()
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
class LMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
class BertLMModel(nn.Module):
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522):
super().__init__()
self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size,
num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size,
vocab_size=vocab_size))
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
@non_distributed_component_funcs.register(name='bert_')
def get_bert_components():
vocab_size = 1024
seq_len = 64
batchSize = 64
def bert_model_builder():
model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size)
return model
def bert_data_gen(device="meta"):
input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
return bert_model_builder, bert_data_gen
@non_distributed_component_funcs.register(name='gpt2_')
def get_gpt2_components():
vocab_size = 1024
seq_len = 8
batchSize = 64
def gpt2_model_builder():
model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size)
return model
def gpt2_data_gen(device="meta"):
input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
return gpt2_model_builder, gpt2_data_gen

View File

@@ -0,0 +1,150 @@
import time
import pytest
from functools import partial
import torch
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import free_port, get_current_device
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.testing import parameterize
from tests.test_tensor.common_utils import set_seed
from tests.test_auto_parallel.test_offload.model_utils import *
@parameterize('model_name', ['gpt2_'])
@parameterize('memory_budget', [5000])
@parameterize('solver_name', ['asyn'])
def exam_fwd_bwd(
model_name: str,
memory_budget: float,
solver_name: str
):
# build model
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen = get_components_func()
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
criterion = LMLoss()
set_seed(42)
start_time = time.time()
model = model_builder()
model.train()
param_size = parameter_size(model) / 1024 ** 2 / 2
init_time = time.time() - start_time
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
data_args = data_gen(device="cpu")
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
data_args = tree_map(wrap_fn, data_args)
start_time = time.time()
model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name)
solver_time = time.time() - start_time
print(f"solver_time={solver_time:.3f} s")
hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
optim = AMPOptimizer(hybrid_optimizer, model)
with ColoInitContext(device=torch.device('cpu')):
gemini_model = model_builder()
gemini_model.train()
hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
gemini_config = dict(strict_ddp_mode=False,
device=torch.device('cpu'),
placement_policy='cpu',
pin_memory=True,
hidden_dim=8192,
search_range_mb=128)
gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
# test gemini
time_list = []
set_seed(42)
data_args = data_gen(device="cuda")
for step in range(10):
gemini_optim.zero_grad()
torch.cuda.synchronize()
start_time = time.time()
gemini_out = gemini_model(**data_args)
gemini_loss = criterion(gemini_out, label)
gemini_optim.backward(gemini_loss)
torch.cuda.synchronize()
time_list.append(time.time() - start_time)
gemini_optim.step()
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
print(f'gemini | model_name: {model_name}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(time_list)
del data_args
del gemini_model
del gemini_optim
del gemini_out
del gemini_loss
# test asyn offload
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
time_list = []
set_seed(42)
data_args = data_gen(device="cuda")
data_args = tree_map(wrap_fn, data_args)
for step in range(10):
optim.zero_grad()
torch.cuda.synchronize()
start_time = time.time()
loss = criterion(model(**data_args), label)
optim.backward(loss)
torch.cuda.synchronize()
time_list.append(time.time() - start_time)
optim.step()
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
print(f'solver_name: {solver_name} | model_name: {model_name}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(time_list)
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def test_perf(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_fwd_bwd()
if __name__ == '__main__':
run_func = partial(test_perf, world_size=1, port=free_port())
mp.spawn(run_func, nprocs=1)

View File

@@ -0,0 +1,62 @@
import pytest
import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
from colossalai.testing import parameterize
from tests.test_auto_parallel.test_offload.model_utils import *
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@parameterize('model_name', ['gpt2_', 'bert_'])
@parameterize('memory_budget', [4000])
@parameterize('solver_name', ['syn', 'asyn'])
def solver_test(model_name: str,
memory_budget: float,
solver_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen = get_components_func()
data_args = data_gen(device="cpu")
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
data_args = tree_map(wrap_fn, data_args)
model = model_builder()
model.train()
model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
meta_args = tree_map(wrap_fn, data_args)
graph = tracer.trace(model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
interp = MetaInfoProp(gm)
interp.propagate(*meta_args.values())
region_manager = RegionManager(graph, solver_name=solver_name)
region_manager._pre_process()
region_list = region_manager.region_list
solver_cls = SolverFactory.create(solver_name)
memory_budget = memory_budget * 1024 * 1024
solver = solver_cls(region_list, memory_budget)
solver._call_solver()
assert solver.best_ts.peak_mem < memory_budget
print("****************** execution plan *******************")
for region in region_list:
need_offload = region.need_offload
to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
for region in region_list.__reversed__():
need_offload = region.need_offload
to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
if __name__ == '__main__':
solver_test()