From bb4e9a311a7a32acb6370f39e6b1a3e4c250b885 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Wed, 11 Jan 2023 10:07:37 +0800
Subject: [PATCH] [zero] add inference mode and its unit test (#2418)

---
 colossalai/gemini/gemini_mgr.py            |  18 ++-
 colossalai/nn/parallel/data_parallel.py    |  23 ++++
 tests/test_gemini/update/test_inference.py | 122 +++++++++++++++++++++
 3 files changed, 157 insertions(+), 6 deletions(-)
 create mode 100644 tests/test_gemini/update/test_inference.py

diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py
index 08961b958..08fc0cf92 100644
--- a/colossalai/gemini/gemini_mgr.py
+++ b/colossalai/gemini/gemini_mgr.py
@@ -50,6 +50,17 @@ class GeminiManager:
         self._warmup = True
         self._comp_cuda_demand_time = 0
 
+    def reset_attributes(self):
+        self._compute_idx = -1
+        self._h2d_volume = 0
+        self._d2h_volume = 0
+        self._layout_time = 0
+        self._evict_time = 0
+        self._comp_cuda_demand_time = 0
+
+    def is_warmup(self):
+        return self._warmup
+
     def memstats(self):
         """memstats
 
@@ -73,12 +84,7 @@ class GeminiManager:
         if self._mem_stats_collector and self._warmup:
             self._mem_stats_collector.finish_collection()
         self._warmup = False
-        self._compute_idx = -1
-        self._h2d_volume = 0
-        self._d2h_volume = 0
-        self._layout_time = 0
-        self._evict_time = 0
-        self._comp_cuda_demand_time = 0
+        self.reset_attributes()
 
     def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
         """ Adjust the layout of stateful tensors according to the information provided
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index a7d79be16..5e547059a 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -268,12 +268,35 @@ class ZeroDDP(ColoDDP):
 
         self._logger = get_dist_logger()
 
+    def _post_forward(self):
+        """This function is only triggered for inference.
+ """ + access_list = list(self.chunk_manager.accessed_chunks) + # we need to scatter all accessed chunks and move them to their original places + for chunk in access_list: + assert chunk.can_release + self.chunk_manager.release_chunk(chunk) + first_param = next(iter(chunk.tensors_info)) + self.chunk_manager.move_chunk(chunk, self.grads_device[first_param]) + assert self.chunk_manager.accessed_mem == 0 + # reset all recorded attributes + self.gemini_manager.reset_attributes() + def forward(self, *args, **kwargs): + # check whether we are in a inference mode + grad_flag = torch.is_grad_enabled() + if not grad_flag: + assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter" + args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) self.module.zero_grad(set_to_none=True) self.gemini_manager.pre_iter(*args) with ColoParamOpHookManager.use_hooks(self.param_op_hook): outputs = self.module(*args, **kwargs) + # scatter chunks in the inference mode + if not grad_flag: + self._post_forward() + if self.force_outputs_fp32: return _cast_float(outputs, torch.float) return outputs diff --git a/tests/test_gemini/update/test_inference.py b/tests/test_gemini/update/test_inference.py new file mode 100644 index 000000000..aec945fc9 --- /dev/null +++ b/tests/test_gemini/update/test_inference.py @@ -0,0 +1,122 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', ['gpt2']) +def exam_inference(placement_policy, model_name: str): + set_seed(19360226) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + torch_optim = 
+    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
+    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
+
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
+        model = model_builder()
+
+    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
+        p.data.copy_(torch_p.data)
+
+    world_size = torch.distributed.get_world_size()
+    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict[world_size]['chunk_size'] = 5000
+    config_dict[world_size]['keep_gathered'] = False
+    if placement_policy != 'cuda':
+        init_device = torch.device('cpu')
+    else:
+        init_device = None
+    chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
+
+    model.eval()
+    torch_model.eval()
+
+    set_seed(dist.get_rank() * 3 + 128)
+    train_dataloader = iter(train_dataloader)
+
+    def train_iter():
+        input_ids, label = next(train_dataloader)
+        input_ids, label = input_ids.cuda(), label.cuda()
+        zero_optim.zero_grad()
+        torch_optim.zero_grad()
+        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
+        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
+        assert_close(torch_loss, loss)
+        zero_optim.step()
+        torch_optim.step()
+        check_param(model, torch_model)
+
+    def inference_iter():
+        input_ids, label = next(train_dataloader)
+        input_ids, label = input_ids.cuda(), label.cuda()
+        with torch.no_grad():
+            torch_output = torch_model(input_ids)
+            torch_loss = criterion(torch_output.float(), label)
+            zero_output = model(input_ids)
+            zero_loss = criterion(zero_output.float(), label)
+        assert_close(torch_loss, zero_loss)
+
+    train_iter()
+    inference_iter()
+    train_iter()
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_inference()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_inference(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_inference(1)
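
A minimal usage sketch of the inference mode this patch introduces, distilled from the test above. It assumes `model` is a ZeroDDP instance and `zero_optim` a ZeroOptimizer built as in tests/test_gemini/update/test_inference.py; `criterion`, `input_ids`, and `label` stand in for any loss function and data batch:

    import torch

    # Warmup first: Gemini collects memory statistics during one complete
    # training iteration, and forward() asserts that this warmup has
    # finished before it accepts a no-grad pass.
    zero_optim.zero_grad()
    loss = criterion(model(input_ids), label)
    zero_optim.backward(loss)
    zero_optim.step()

    # Inference: the no-grad context, not model.eval(), is what triggers
    # the new path. forward() sees torch.is_grad_enabled() is False and
    # calls _post_forward(), which releases every accessed chunk and moves
    # it back to the device recorded in grads_device, so no gathered
    # parameters linger in CUDA memory after the forward pass.
    model.eval()
    with torch.no_grad():
        output = model(input_ids)

Note that the unit test deliberately interleaves training and inference (train_iter, inference_iter, train_iter) to verify that a model whose chunks have been scattered by _post_forward() can still be trained afterwards.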