From 765db512b5c7b1ba20d90f3aa4071f25c7afea7a Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 28 Jan 2022 15:14:04 +0800 Subject: [PATCH] fixed ddp bug on torch 1.8 (#194) --- colossalai/initialize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 3e61c80f2..9329dc052 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -348,12 +348,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]], "added even though not specified in the configuration", ranks=[0]) elif is_using_sequence(): - model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP)) + model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), device_ids=[torch.cuda.current_device()]) if verbose: logger.info( 'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0]) elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: - model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA)) + model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()]) if verbose: logger.info( 'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])