[gemini] accelerate inference (#3641)
* [gemini] support don't scatter after inference
* [chat] update colossalai strategy
* [chat] fix opt benchmark
* [chat] update opt benchmark
* [gemini] optimize inference
* [test] add gemini inference test
* [chat] fix unit test ci
* [chat] fix ci
* [chat] fix ci
* [chat] skip checkpoint test
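The headline change adds a scatter_after_inference switch to ZeroDDP/GeminiDDP (diff below). A minimal usage sketch, assuming a distributed environment already initialized via colossalai.launch; the import path and the remaining GeminiDDP constructor arguments have moved across releases, so treat them as placeholders rather than the definitive API:

import torch
import torch.nn as nn
# Assumption: this was the import location around the time of this commit.
from colossalai.nn.parallel import GeminiDDP

# A toy module stands in for a real model.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).cuda()

# scatter_after_inference=False keeps chunks gathered between forward
# passes: more memory held, but consecutive inference calls skip the
# re-gather, which is what this commit means by "accelerate inference".
model = GeminiDDP(model, device=torch.device('cuda'), scatter_after_inference=False)

with torch.no_grad():
    out = model(torch.randn(2, 16, dtype=torch.half, device='cuda'))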
@@ -1,5 +1,6 @@
 import itertools
 from collections import OrderedDict
+from contextlib import nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Union
 
@@ -49,6 +50,7 @@ class ZeroDDP(ColoDDP):
             Defaults to False.
         strict_ddp_mode (bool): If set to True, there is no tensor sharding; each tensor is replicated.
             Defaults to False. Users can set it to True when they are certain they only need DDP.
+        scatter_after_inference (bool): If set to True, the model will be scattered after inference. This saves memory but slows down consecutive inference calls.
     """

     def __init__(self,
@@ -56,7 +58,8 @@ class ZeroDDP(ColoDDP):
                  gemini_manager: GeminiManager,
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
-                 strict_ddp_mode: bool = False) -> None:
+                 strict_ddp_mode: bool = False,
+                 scatter_after_inference: bool = True) -> None:
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -67,6 +70,7 @@ class ZeroDDP(ColoDDP):
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
+        self.scatter_after_inference = scatter_after_inference

         self._logger = get_dist_logger()

@@ -108,8 +112,6 @@ class ZeroDDP(ColoDDP):
                 first_param = next(iter(chunk.tensors_info))
                 self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
         assert self.chunk_manager.accessed_mem == 0
-        # reset all recorded attributes
-        self.gemini_manager.reset_attributes()

     def forward(self, *args, **kwargs):
         # check whether we are in an inference mode
@@ -120,17 +122,35 @@ class ZeroDDP(ColoDDP):

         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
-        self.gemini_manager.pre_iter(*args)
-        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
-            outputs = self.module(*args, **kwargs)
-        # scatter chunks in the inference mode
         if not grad_flag:
-            self._post_forward()
+            outputs = self._inference_forward(*args, **kwargs)
+        else:
+            self.gemini_manager.pre_iter(*args)
+            with ColoParamOpHookManager.use_hooks(self.param_op_hook):
+                outputs = self.module(*args, **kwargs)

         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
         return outputs

+    def _inference_forward(self, *args, **kwargs):
+        """This function is only triggered for inference.
+        """
+        fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
+        if not self.scatter_after_inference:
+            # gather all chunks
+            for chunk in self.chunk_manager.get_chunks(self.fp16_params):
+                self.chunk_manager.access_chunk(chunk)
+            fwd_ctx = nullcontext()
+        with fwd_ctx:
+            outputs = self.module(*args, **kwargs)
+        if self.scatter_after_inference:
+            # scatter chunks
+            self._post_forward()
+        # reset all recorded attributes
+        self.gemini_manager.reset_attributes()
+        return outputs
+
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
             if is_ddp_ignored(p):
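The new _inference_forward picks between two strategies. With scatter_after_inference=True (the old behavior), per-parameter op hooks gather each chunk on demand and _post_forward scatters everything again afterwards. With False, all chunks are gathered once up front and the hooks are disabled by swapping the hook context for contextlib.nullcontext, so the forward pass pays no per-chunk gather/scatter traffic. A stripped-down sketch of that conditional-context pattern; the names here are illustrative, not ColossalAI API:

from contextlib import nullcontext

def run_inference(module, inputs, hook_ctx, keep_gathered: bool):
    # When parameters are already gathered, replace the per-op hook
    # context with a no-op context manager instead of branching the
    # forward call itself.
    ctx = nullcontext() if keep_gathered else hook_ctx
    with ctx:
        return module(inputs)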
@@ -678,6 +698,7 @@ class GeminiDDP(ZeroDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
+                 scatter_after_inference: bool = True,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
@@ -722,4 +743,5 @@ class GeminiDDP(ZeroDDP):
                                      strict_ddp_flag=strict_ddp_mode,
                                      verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
+        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
+                         scatter_after_inference)
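Because chunks stay gathered when scatter_after_inference=False, the speedup shows up on repeated calls rather than the first one. A rough, hypothetical way to observe it, assuming a model built as in the first sketch; bench is not part of ColossalAI:

import time
import torch

def bench(model, x, iters: int = 10) -> float:
    # Average latency over several no-grad forward passes.
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        for _ in range(iters):
            model(x)
    torch.cuda.synchronize()
    return (time.time() - start) / iters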