From fce9432f08b80ab4b4bcdd75f221e4908b96dafc Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 16 Mar 2022 13:40:19 +0800
Subject: [PATCH] sync before creating empty grad

---
 colossalai/zero/sharded_model/sharded_model_v2.py    |  1 +
 tests/test_zero_data_parallel/test_shard_model_v2.py | 10 ++++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 594a14ec7..c33aa9599 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module):
         else:
             self._reduce_scatter_callback(param, new_grad)
         orig_grad_data.record_stream(self.comm_stream)
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         return empty_grad
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index a2cae3ee5..c6afb75a3 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -2,12 +2,14 @@
 # -*- encoding: utf-8 -*-
 
 import copy
+from asyncio.log import logger
 from functools import partial
 
 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -18,12 +20,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from common import CONFIG, check_grads_padding, run_fwd_bwd
-from colossalai.zero.sharded_model.utils import col_model_deepcopy
 
 
 def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
+    logger = get_dist_logger()
+    logger.set_level('DEBUG')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy()
     for model_name in test_models:
@@ -60,8 +62,8 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
 
         check_grads_padding(model, zero_model, loose=True)
 
-        print('overall cuda ', zero_model._memstats_collector._overall_cuda)
-        print('model cuda ', zero_model._memstats_collector._model_data_cuda)
+        # logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
+        # logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)
 
 
 @pytest.mark.dist
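
A minimal standalone sketch (not part of the patch) of the CUDA stream-synchronization pattern the added wait_stream() call relies on: the gradient's usage on the communication stream is recorded with record_stream(), and the default stream then waits for the communication stream before the gradient storage is freed or replaced. The helper name and the dummy reduction below are illustrative assumptions, not ColossalAI code.

import torch


def reduce_on_side_stream(grad: torch.Tensor, comm_stream: torch.cuda.Stream) -> torch.Tensor:
    """Hypothetical helper showing the record_stream/wait_stream pattern."""
    # The comm stream must not start before the gradient is fully produced.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        reduced = grad / 2.0  # stand-in for the real reduce-scatter
    # Tell the caching allocator that `grad` is still in use by comm_stream,
    # so its memory is not recycled while that stream may still read it.
    grad.record_stream(comm_stream)
    # What the patch adds: make the default stream wait for comm_stream before
    # the caller goes on to free or replace the gradient storage.
    torch.cuda.current_stream().wait_stream(comm_stream)
    return reduced


if __name__ == '__main__' and torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    g = torch.randn(1024, device='cuda')
    out = reduce_on_side_stream(g, side_stream)
    print(out.sum().item())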