[fix] remove unnecessary dp_size assert (#5351)

* fix: remove unnecessary assert

* test: add more 3d plugin tests

* fix: add warning
This commit is contained in:
Wenhao Chen
2024-02-02 14:40:20 +08:00
committed by GitHub
parent ffffc32dc7
commit 1c790c0877
2 changed files with 21 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
import ctypes
import random
import warnings
from contextlib import contextmanager
from functools import partial
from types import MethodType
@@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase):
tp_process_group=self.tp_group,
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
if self.dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,