[booster] add warning for torch fsdp plugin doc (#3833)

wukong1992
2023-05-25 14:00:02 +08:00
committed by GitHub
parent 84500b7799
commit 3229f93e30
3 changed files with 12 additions and 1 deletion


@@ -3,10 +3,10 @@ from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import warnings
from packaging import version
from torch.distributed import ProcessGroup
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
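The guard above exists because these FSDP APIs are only available from PyTorch 1.12.0 onward; on older versions the import would fail at module load time. A minimal standalone sketch of the same pattern, with a hypothetical require_fsdp helper (not part of this commit) that turns the missing import into a clear runtime error:

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
else:
    FSDP = None  # defer the failure until FSDP is actually requested

def require_fsdp():
    # Hypothetical helper: raise a descriptive error instead of an ImportError.
    if FSDP is None:
        raise RuntimeError(
            f'This feature requires PyTorch >= 1.12.0, found {torch.__version__}')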
@@ -202,6 +202,11 @@ class TorchFSDPPlugin(DPPluginBase):
        # wrap the model with PyTorch FSDP
        fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
        if len(optimizer.param_groups) > 1:
            warnings.warn(
                'TorchFSDPPlugin does not support optimizers that use multiple param groups. '
                'The results may not be as expected if used.'
            )
        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
        if not isinstance(optimizer, FSDPOptimizerWrapper):
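For context on the new warning: the plugin re-creates the optimizer from optimizer.defaults after wrapping the model, and defaults stores only a single set of hyperparameters, so any per-group settings are silently discarded. A minimal sketch (standalone, not part of the commit) demonstrating the effect:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(
    [
        {'params': [model.weight], 'lr': 1e-2},  # per-group learning rates
        {'params': [model.bias], 'lr': 1e-4},
    ],
    lr=1e-3,  # default lr, recorded in optimizer.defaults
)
print(len(optimizer.param_groups))  # 2

# Mirror what TorchFSDPPlugin does after FSDP wrapping:
optimizer.__init__(model.parameters(), **optimizer.defaults)
print(len(optimizer.param_groups))  # 1
print(optimizer.param_groups[0]['lr'])  # 0.001 -- both custom lrs are gone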