[gemini] accelerate inference (#3641)

* [gemini] support not scattering after inference

* [chat] update colossalai strategy

* [chat] fix opt benchmark

* [chat] update opt benchmark

* [gemini] optimize inference

* [test] add gemini inference test

* [chat] fix unit test ci

* [chat] fix ci

* [chat] fix ci

* [chat] skip checkpoint test

Author: Hongxin Liu
Date: 2023-04-26 16:32:40 +08:00
Committed by: GitHub
Parent: 4b3240cb59
Commit: 50793b35f4
13 changed files with 162 additions and 157 deletions
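
The net effect of the change: in inference mode, ZeroDDP can gather all parameter chunks once up front and keep them gathered across consecutive forward passes, instead of scattering them after every call. A minimal usage sketch, assuming the GeminiDDP wrapper from this diff; the toy model, the device/placement constructor arguments, and the import path are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn
from colossalai.nn.parallel import GeminiDDP  # import path is an assumption; it differs across releases

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
# scatter_after_inference=False trades memory for speed: chunks stay gathered
# between inference calls instead of being scattered after each one.
model = GeminiDDP(model,
                  device=torch.cuda.current_device(),   # assumed constructor args
                  placement_policy='cpu',
                  scatter_after_inference=False)

with torch.no_grad():              # grad is disabled, so the inference path is taken
    for _ in range(8):             # consecutive calls skip repeated gather/scatter
        out = model(torch.rand(4, 1024, device='cuda'))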


@@ -1,5 +1,6 @@
 import itertools
 from collections import OrderedDict
+from contextlib import nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Union
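
The only import added is contextlib.nullcontext, which enables the conditional-context-manager idiom used later in this diff: pick the real parameter-op hook context or a no-op, then enter whichever was chosen. A self-contained sketch of the idiom (all names below are illustrative stand-ins):

from contextlib import contextmanager, nullcontext

@contextmanager
def param_op_hooks():                  # stand-in for ColoParamOpHookManager.use_hooks(...)
    print('hooks on: gather chunks on access')
    yield
    print('hooks off')

def forward(use_hooks: bool):
    # choose the real hook context or a no-op, then run under whichever was picked
    ctx = param_op_hooks() if use_hooks else nullcontext()
    with ctx:
        print('running forward')

forward(True)    # hooks on / running forward / hooks off
forward(False)   # running forward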
@@ -49,6 +50,7 @@ class ZeroDDP(ColoDDP):
             Defaults to False.
         strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
             Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
+        scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
     """

     def __init__(self,
@@ -56,7 +58,8 @@ class ZeroDDP(ColoDDP):
                  gemini_manager: GeminiManager,
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
-                 strict_ddp_mode: bool = False) -> None:
+                 strict_ddp_mode: bool = False,
+                 scatter_after_inference: bool = True) -> None:
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -67,6 +70,7 @@ class ZeroDDP(ColoDDP):
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
+        self.scatter_after_inference = scatter_after_inference

         self._logger = get_dist_logger()
@@ -108,8 +112,6 @@ class ZeroDDP(ColoDDP):
             first_param = next(iter(chunk.tensors_info))
             self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
         assert self.chunk_manager.accessed_mem == 0
-        # reset all recorded attributes
-        self.gemini_manager.reset_attributes()

     def forward(self, *args, **kwargs):
         # check whether we are in a inference mode
@@ -120,17 +122,35 @@ class ZeroDDP(ColoDDP):
             args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
-        self.gemini_manager.pre_iter(*args)
-        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
-            outputs = self.module(*args, **kwargs)
-        # scatter chunks in the inference mode
         if not grad_flag:
-            self._post_forward()
+            outputs = self._inference_forward(*args, **kwargs)
+        else:
+            self.gemini_manager.pre_iter(*args)
+            with ColoParamOpHookManager.use_hooks(self.param_op_hook):
+                outputs = self.module(*args, **kwargs)

         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
         return outputs

+    def _inference_forward(self, *args, **kwargs):
+        """This function is only triggered for inference.
+        """
+        fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
+        if not self.scatter_after_inference:
+            # gather all chunks
+            for chunk in self.chunk_manager.get_chunks(self.fp16_params):
+                self.chunk_manager.access_chunk(chunk)
+            fwd_ctx = nullcontext()
+        with fwd_ctx:
+            outputs = self.module(*args, **kwargs)
+        if self.scatter_after_inference:
+            # scatter chunks
+            self._post_forward()
+        # reset all recorded attributes
+        self.gemini_manager.reset_attributes()
+        return outputs
+
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
             if is_ddp_ignored(p):
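
The new _inference_forward above is the core of the speedup. With scatter_after_inference=False, every chunk is gathered once before the forward and the on-demand hooks are replaced by nullcontext, so later inference calls pay no gather cost; with True, the pre-commit behavior is kept: gather on access, scatter afterwards via _post_forward. A simplified, runnable sketch of that control flow (the chunk manager and hooks below are stubs, not the real Gemini API):

from contextlib import contextmanager, nullcontext

class StubChunkManager:                    # stand-in for Gemini's ChunkManager
    def __init__(self, n_chunks):
        self.chunks = list(range(n_chunks))
    def access_chunk(self, c):             # gather one chunk onto the compute device
        print(f'gathered chunk {c}')
    def scatter_all(self):                 # release chunks back to their home devices
        print('scattered all chunks')

@contextmanager
def on_demand_hooks():                     # stand-in for ColoParamOpHookManager.use_hooks
    print('hooks on')
    yield
    print('hooks off')

def inference_forward(mgr, module_fn, scatter_after_inference):
    fwd_ctx = on_demand_hooks()
    if not scatter_after_inference:
        for c in mgr.chunks:               # gather everything once, up front
            mgr.access_chunk(c)
        fwd_ctx = nullcontext()            # hooks unnecessary: nothing left to gather
    with fwd_ctx:
        out = module_fn()
    if scatter_after_inference:
        mgr.scatter_all()                  # free memory, but the next call re-gathers
    return out

inference_forward(StubChunkManager(3), lambda: 'logits', scatter_after_inference=False)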
@@ -678,6 +698,7 @@ class GeminiDDP(ZeroDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
+                 scatter_after_inference: bool = True,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
@@ -722,4 +743,5 @@ class GeminiDDP(ZeroDDP):
                                                strict_ddp_flag=strict_ddp_mode,
                                                verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
+        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
+                         scatter_after_inference)
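
The GeminiDDP changes are pure plumbing: the new keyword defaults to True, so existing call sites keep the old scatter-always behavior, and the value is forwarded unchanged to ZeroDDP. A hedged sketch of the two settings; apart from scatter_after_inference, the constructor arguments shown are assumptions:

# Default: same behavior as before this commit; chunks are scattered after
# each inference forward, minimizing resident memory.
model = GeminiDDP(module, device='cuda', pin_memory=True)

# Opt in to faster consecutive inference at a memory cost: chunks stay gathered.
fast_model = GeminiDDP(module, device='cuda', pin_memory=True,
                       scatter_after_inference=False)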