sync before creating empty grad

ver217 2022-03-16 13:40:19 +08:00
parent ea6905a898
commit fce9432f08
2 changed files with 7 additions and 4 deletions

@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module):
             else:
                 self._reduce_scatter_callback(param, new_grad)
             orig_grad_data.record_stream(self.comm_stream)
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         return empty_grad

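The added `torch.cuda.current_stream().wait_stream(self.comm_stream)` makes the default stream wait for the reduce-scatter work queued on the communication stream before the hook allocates (and immediately frees) the placeholder gradient. Below is a minimal standalone sketch of that stream-ordering pattern; it is illustrative only (the helper name and the `clone()` standing in for the real reduce-scatter are assumptions, not ColossalAI code).

# Minimal sketch of the stream-ordering pattern (illustrative; the helper name
# and the clone() standing in for the real reduce-scatter are assumptions).
import torch

def reduce_grad_on_comm_stream(grad: torch.Tensor) -> torch.Tensor:
    comm_stream = torch.cuda.Stream()
    # Let the comm stream see the kernels that produced grad.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        new_grad = grad.clone()            # stand-in for the reduce-scatter work
        # Keep grad's storage alive until the comm stream is done reading it.
        grad.record_stream(comm_stream)
    # The point of this commit: order the comm work before anything the default
    # stream does next, so freeing/reusing grad memory cannot race with it.
    torch.cuda.current_stream().wait_stream(comm_stream)
    empty_grad = torch.empty_like(grad)    # safe to create/free the placeholder now
    return empty_grad

if torch.cuda.is_available():
    g = torch.ones(1024, device='cuda')
    print(reduce_grad_on_comm_stream(g).shape)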

@@ -2,12 +2,14 @@
 # -*- encoding: utf-8 -*-
 import copy
+from asyncio.log import logger
 from functools import partial
 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -18,12 +20,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from common import CONFIG, check_grads_padding, run_fwd_bwd
-from colossalai.zero.sharded_model.utils import col_model_deepcopy

 def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    logger = get_dist_logger()
+    logger.set_level('DEBUG')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy()
     for model_name in test_models:
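
The test now builds a per-rank logger with `get_dist_logger()` and comments out the raw `print` calls further down. A hedged sketch of how those memory-stat lines could be emitted through the distributed logger is given below; the helper name is an assumption, not part of the diff, and note that the logger methods take a single formatted message string rather than print-style varargs.

# Hedged sketch (assumed helper, not part of the diff): emitting the memory
# stats through ColossalAI's distributed logger instead of print().
from colossalai.logging import get_dist_logger

def log_memstats(zero_model) -> None:
    # Assumes this runs inside a process set up by colossalai.launch, as in run_dist.
    logger = get_dist_logger()
    logger.set_level('DEBUG')  # mirrors the line added to run_dist above
    # Values are formatted into a single message string.
    logger.debug(f'overall cuda {zero_model._memstats_collector._overall_cuda}')
    logger.debug(f'model cuda {zero_model._memstats_collector._model_data_cuda}')
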
@@ -60,8 +62,8 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
         check_grads_padding(model, zero_model, loose=True)
-        print('overall cuda ', zero_model._memstats_collector._overall_cuda)
-        print('model cuda ', zero_model._memstats_collector._model_data_cuda)
+        # logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
+        # logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)

 @pytest.mark.dist