mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 18:39:56 +00:00
[zero] add inference mode and its unit test (#2418)
This commit is contained in:
@@ -268,12 +268,35 @@ class ZeroDDP(ColoDDP):
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def _post_forward(self):
|
||||
"""This function is only triggered for inference.
|
||||
"""
|
||||
access_list = list(self.chunk_manager.accessed_chunks)
|
||||
# we need to scatter all accessed chunks and move them to their original places
|
||||
for chunk in access_list:
|
||||
assert chunk.can_release
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
first_param = next(iter(chunk.tensors_info))
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
|
||||
assert self.chunk_manager.accessed_mem == 0
|
||||
# reset all recorded attributes
|
||||
self.gemini_manager.reset_attributes()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# check whether we are in a inference mode
|
||||
grad_flag = torch.is_grad_enabled()
|
||||
if not grad_flag:
|
||||
assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"
|
||||
|
||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
self.gemini_manager.pre_iter(*args)
|
||||
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
outputs = self.module(*args, **kwargs)
|
||||
# scatter chunks in the inference mode
|
||||
if not grad_flag:
|
||||
self._post_forward()
|
||||
|
||||
if self.force_outputs_fp32:
|
||||
return _cast_float(outputs, torch.float)
|
||||
return outputs
|
||||
|
Reference in New Issue
Block a user