diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py index f479960c5..2de51e24a 100644 --- a/colossalai/nn/_ops/element_wise.py +++ b/colossalai/nn/_ops/element_wise.py @@ -34,17 +34,15 @@ def register_elementwise_op(op): dist_attr=input_tensor.dist_spec)) -@colo_op_impl(torch.relu_) -def elementwise_op(input_tensor): - torch.relu_(input_tensor.data) - return input_tensor - - -@colo_op_impl(Tensor.add_) -def elementwise_op(input_tensor: ColoTensor, *args, **kwargs): - input_tensor = input_tensor.data.add_(*args, **kwargs) - return input_tensor +# @colo_op_impl(torch.relu_) +# def elementwise_op(input_tensor): +# torch.relu_(input_tensor.data) +# return input_tensor +# @colo_op_impl(Tensor.add_) +# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs): +# input_tensor = input_tensor.data.add_(*args, **kwargs) +# return input_tensor # Tensor op register_elementwise_op(Tensor.abs) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index f47676908..78b6b499e 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -272,7 +272,7 @@ class ZeroDDP(ColoDDP): p.grad = None def _post_backward(self): - # assert self.chunk_manager.accessed_mem == 0 + assert self.chunk_manager.accessed_mem == 0 self._setup_grads_ptr() self._logger.debug( f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index aa2da5beb..b1a71502b 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -33,7 +33,7 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('keep_gather', [False, True]) -@parameterize('model_name', ['gpt2', 'bert', 'resnet18']) +@parameterize('model_name', ['gpt2', 'bert']) @parameterize('use_grad_checkpoint', [False, True]) def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): set_seed(42) @@ -78,7 +78,7 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch torch.max(torch.abs(loss - torch_loss)).item(), loss, torch_loss) # FIXME(1SAA) bert and resnet18 can not pass the check_grad - # check_grad(model, torch_model) + check_grad(model, torch_model) def run_dist(rank, world_size, port):