From ceb1e262e765242c1f130aa72ab9d5e2289162be Mon Sep 17 00:00:00 2001 From: Tong Li Date: Wed, 14 Aug 2024 11:22:39 +0800 Subject: [PATCH] fix sync condition (#6000) --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index d2933a4af..e5acdb051 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1326,8 +1326,10 @@ class HybridParallelPlugin(PipelinePluginBase): ) # run with gradients accumulation - if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + if ( + model.require_grad_sync == False + or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) + or not torch.is_grad_enabled() ): return outputs