Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 01:55:12 +00:00)
[booster] add warning for torch fsdp plugin doc (#3833)
@@ -3,10 +3,10 @@ from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import warnings
from packaging import version
from torch.distributed import ProcessGroup

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
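For reference, a minimal standalone sketch (not part of this commit) of the same version gate, assuming the packaging package is installed; the plugin only imports FSDP when PyTorch is at least 1.12.0:

import torch
from packaging import version

# The plugin gates its FSDP imports on PyTorch >= 1.12.0; mirror that check
# and fail fast on older versions instead of importing a missing module.
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
else:
    raise RuntimeError(f'TorchFSDPPlugin requires PyTorch >= 1.12.0, found {torch.__version__}.')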
@@ -202,6 +202,11 @@ class TorchFSDPPlugin(DPPluginBase):

        # wrap the model with PyTorch FSDP
        fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)

        if len(optimizer.param_groups) > 1:
            warnings.warn(
                'TorchFSDPPlugin does not support optimizers that use multiple param groups. The results may not be as expected if used.'
            )
        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

        if not isinstance(optimizer, FSDPOptimizerWrapper):
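For context, a minimal usage sketch of TorchFSDPPlugin with the Booster API (hypothetical toy model and hyperparameters; assumes the distributed environment has already been initialized, e.g. with colossalai.launch_from_torch, and that a CUDA device is available). Keeping a single param group avoids the warning introduced by this commit.

import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

# Toy model and a single-param-group optimizer (hypothetical sizes).
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# boost() hands the model to the plugin, which wraps it with torch FSDP and
# re-creates the optimizer over the wrapped parameters.
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)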