[gemini] accelerate inference (#3641)

* [gemini] support not scattering after inference

* [chat] update colossalai strategy

* [chat] fix opt benchmark

* [chat] update opt benchmark

* [gemini] optimize inference

* [test] add gemini inference test

* [chat] fix unit test ci

* [chat] fix ci

* [chat] fix ci

* [chat] skip checkpoint test
This commit is contained in:
Hongxin Liu
2023-04-26 16:32:40 +08:00
committed by GitHub
parent 4b3240cb59
commit 50793b35f4
13 changed files with 162 additions and 157 deletions

View File

@@ -9,11 +9,11 @@ from . import (
resnet,
simple_net,
)
from .utils import run_fwd_bwd
from .utils import run_fwd, run_fwd_bwd
from . import albert # isort:skip
__all__ = [
'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet',
'simple_net', 'run_fwd_bwd', 'albert', 'beit'
'simple_net', 'run_fwd_bwd', 'albert', 'beit', 'run_fwd'
]

View File

@@ -1,2 +1,2 @@
from .dummy_data_generator import DummyDataGenerator
from .executor import run_fwd_bwd
from .executor import run_fwd, run_fwd_bwd

View File

@@ -1,9 +1,9 @@
import torch
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
"""run_fwd_bwd
run fwd and bwd for the model
def run_fwd(model, data, label, criterion) -> torch.Tensor:
"""run_fwd
run fwd for the model
Args:
model (torch.nn.Module): a PyTorch model
@@ -22,6 +22,23 @@ def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
loss = model(data, label)
loss = loss.float()
return loss
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
"""run_fwd_bwd
run fwd and bwd for the model
Args:
model (torch.nn.Module): a PyTorch model
data (torch.Tensor): input data
label (torch.Tensor): label
criterion (Optional[Callable]): a function of criterion
Returns:
torch.Tensor: loss of fwd
"""
loss = run_fwd(model, data, label, criterion)
if optimizer:
optimizer.backward(loss)
else:

View File

@@ -12,7 +12,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test import run_fwd, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed
@@ -89,10 +89,65 @@ def exam_gpt_fwd_bwd(
check_grad(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('scatter_after_inference', [False, True])
def exam_gpt_inference(
    placement_policy,
    keep_gather,
    model_name: str,
    scatter_after_inference: bool = False,
):
    """Compare the inference loss of a ZeroDDP-wrapped model with an apex-AMP DDP baseline.

    Both models come from the same builder and seed, and the baseline's
    parameters are overwritten with the wrapped model's values. A single batch
    is then run forward in eval mode under ``torch.no_grad()`` and the two
    losses must be exactly equal (``torch.equal``).

    Args:
        placement_policy: placement policy handed to ``GeminiManager``.
        keep_gather: value stored under ``keep_gathered`` in the chunk config.
        model_name (str): key looked up in ``non_distributed_component_funcs``.
        scatter_after_inference (bool): forwarded unchanged to ``ZeroDDP``.
    """
    device = get_current_device()
    build_components = non_distributed_component_funcs.get_callable(model_name)
    # Only the builder, the training loader and the criterion are used below.
    model_builder, train_dataloader, _test_dataloader, _optimizer_class, criterion = build_components()

    # Model that will be wrapped by ZeroDDP.
    set_seed(42)
    with ColoInitContext(device=device):
        model = model_builder()

    # Baseline model: same seed, then parameter values copied over from `model`.
    set_seed(42)
    torch_model = model_builder().cuda()
    for baseline_param, param in zip(torch_model.parameters(), model.parameters()):
        baseline_param.data.copy_(param.data)

    # Take the searched chunk configuration and override the entry for this world size.
    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    world_config = config_dict[world_size]
    world_config['chunk_size'] = 5000
    world_config['keep_gathered'] = keep_gather
    gemini_manager = GeminiManager(placement_policy, ChunkManager(config_dict))
    model = ZeroDDP(
        model,
        gemini_manager,
        pin_memory=True,
        scatter_after_inference=scatter_after_inference,
    )

    # The optimizer exists only because the apex conversion takes a (model, optimizer) pair.
    pg = ProcessGroup()
    amp_config = {'opt_level': 'O2', 'keep_batchnorm_fp32': False, 'loss_scale': 1}
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    set_seed(pg.dp_local_rank())
    model.eval()
    torch_model.eval()
    for batch_idx, (input_ids, label) in enumerate(train_dataloader):
        # Only the first batch is checked; no backward pass is run in this test.
        if batch_idx >= 1:
            break
        with torch.no_grad():
            input_ids = input_ids.cuda()
            label = label.cuda()
            torch_loss = run_fwd(torch_model, input_ids, label, criterion)
            loss = run_fwd(model, input_ids, label, criterion)
        assert torch.equal(torch_loss, loss)
def run_dist(rank, world_size, port):
    """Launch colossalai for this process, then run both Gemini checks.

    Args:
        rank: rank passed to ``colossalai.launch``.
        world_size: world size passed to ``colossalai.launch``.
        port: port passed to ``colossalai.launch``; the host is ``localhost``.
    """
    launch_kwargs = {
        'config': {},
        'rank': rank,
        'world_size': world_size,
        'host': 'localhost',
        'port': port,
        'backend': 'nccl',
    }
    colossalai.launch(**launch_kwargs)
    # Training-path check first, then the inference-path check.
    exam_gpt_fwd_bwd()
    exam_gpt_inference()
@pytest.mark.dist