From e6212f56cd36cd5da91038c33932468d8fca5b89 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 13 Apr 2022 09:59:05 +0800 Subject: [PATCH] [hotfix] fix memory leak in backward of sharded model (#741) --- .../zero/sharded_model/sharded_model_v2.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 028d0854c..5c5c2c421 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -303,41 +303,38 @@ class ShardedModelV2(nn.Module): assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients' if not self._require_backward_grad_sync: return - + # used to cheat Pytorch, since we can't return None + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + # As torch didn't allow modifying grad in hook, we make a copy + grad = grad.clone() if param.colo_attr.is_replicated: self._reduce_scatter_handler(param, grad) else: self._save_grad(param, grad) - - # used to cheat Pytorch, since we can't return None - empty_grad = torch.empty_like(grad) - free_storage(empty_grad) return empty_grad def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): - new_grad = grad.clone() if self.fp32_reduce_scatter: - new_grad.data = new_grad.data.to(param.dtype) + grad.data = grad.data.to(param.dtype) if self.gradient_predivide_factor > 1.0: # Average grad by world_size for consistency with PyTorch DDP. - new_grad.data.div_(self.gradient_predivide_factor) - orig_grad_data = new_grad.data + grad.data.div_(self.gradient_predivide_factor) if self.world_size > 1: - grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size()) + grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size()) self.reducer.reduce_scatter_async(grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param)) else: - self._reduce_scatter_callback(param, new_grad) - orig_grad_data.record_stream(self.comm_stream) + self._reduce_scatter_callback(param, grad) torch.cuda.current_stream().wait_stream(self.comm_stream) def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: assert isinstance(reduced_grad, torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" - reduced_grad = reduced_grad.view(-1) + reduced_grad.data = reduced_grad.data.view(-1) if self.gradient_postdivide_factor > 1: # Average grad by world_size for consistency with PyTorch DDP. reduced_grad.data.div_(self.gradient_postdivide_factor)