Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-18 01:12:42 +00:00)
fixed ddp bug on torch 1.8 (#194)
parent 569357fea0
commit 765db512b5
@@ -348,12 +348,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                         "added even though not specified in the configuration",
                         ranks=[0])
     elif is_using_sequence():
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
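The change passes device_ids=[torch.cuda.current_device()] to DistributedDataParallel for both the sequence-parallel and data-parallel process groups, making the per-rank CUDA device explicit instead of letting torch 1.8 infer it. A minimal sketch of the same wrapping pattern is below; the helper name wrap_with_ddp and the step of moving the model to the current device are illustrative assumptions, not part of the commit.

    # Minimal sketch (not ColossalAI code) of wrapping a model with DDP on the
    # per-process CUDA device, mirroring the fix above.
    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_with_ddp(model: torch.nn.Module, process_group=None) -> DDP:
        # Assumes torch.distributed is already initialized and each rank has
        # set its own CUDA device. Passing device_ids makes the target device
        # explicit, which the commit indicates avoids a DDP issue on torch 1.8
        # when wrapping with a custom process group.
        device = torch.cuda.current_device()
        model = model.to(device)
        return DDP(model, process_group=process_group, device_ids=[device])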
|