diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 3e61c80f2..9329dc052 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -348,12 +348,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                 "added even though not specified in the configuration",
                 ranks=[0])
     elif is_using_sequence():
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                 ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism',
                 ranks=[0])
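Note: passing `device_ids=[torch.cuda.current_device()]` makes each DDP replica's device explicit instead of leaving DDP to infer it from the module's parameters. Below is a minimal standalone sketch of the same pattern, not ColossalAI's own initialization path: the toy model is hypothetical, and it assumes a launch via `torchrun` with one GPU per process (so `LOCAL_RANK` is set in the environment).

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Hypothetical standalone setup: torchrun exports LOCAL_RANK per process.
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# Toy single-device module living on this process's GPU.
model = torch.nn.Linear(16, 16).cuda()

# device_ids pins the replica to the current process's GPU, mirroring the
# change in the diff above; without it, DDP falls back to inferring the
# device placement from the module's parameters.
model = DDP(model, device_ids=[torch.cuda.current_device()])
```

In the diff itself, ColossalAI's `gpc` already supplies the right process group via `gpc.get_group(...)`, so the change only needs to add the `device_ids` argument to the existing `DDP(...)` calls.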