fix sync condition (#6000)

This commit is contained in:
Tong Li 2024-08-14 11:22:39 +08:00 committed by GitHub
parent ed97d3a5d3
commit ceb1e262e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1326,8 +1326,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             )
             # run with gradients accumulation
-            if model.require_grad_sync == False or (
-                isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+            if (
+                model.require_grad_sync == False
+                or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
+                or not torch.is_grad_enabled()
             ):
                 return outputs