fix sync condition (#6000)

This commit is contained in:
Tong Li
2024-08-14 11:22:39 +08:00
committed by GitHub
parent ed97d3a5d3
commit ceb1e262e7


@@ -1326,8 +1326,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
+            or not torch.is_grad_enabled()
         ):
             return outputs
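
The added `not torch.is_grad_enabled()` clause extends the early return to runs where autograd is disabled (e.g. evaluation under `torch.no_grad()`), so the plugin no longer attempts a gradient sync when there are no gradients to synchronize. Below is a minimal sketch of the fixed condition; the helper name and the boolean arguments are hypothetical stand-ins for `model.require_grad_sync` and the ZeRO optimizer's flag, not part of the actual plugin API:

```python
import torch

def should_skip_grad_sync(model_require_grad_sync: bool,
                          optimizer_require_grad_sync: bool) -> bool:
    # Mirrors the fixed condition: skip synchronization when the model or the
    # ZeRO optimizer is still accumulating gradients, or when autograd is
    # disabled entirely (e.g. inside torch.no_grad() during evaluation).
    return (
        model_require_grad_sync is False
        or optimizer_require_grad_sync is False
        or not torch.is_grad_enabled()
    )

# Under torch.no_grad() the forward pass produces no gradients, so the early
# return now fires even when both sync flags are True.
with torch.no_grad():
    assert should_skip_grad_sync(True, True)
```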