This commit is contained in:
Tong Li 2024-08-14 07:19:34 +00:00
parent b841ded016
commit 409f4b5ab3

View File

@ -1332,7 +1332,7 @@ class HybridParallelPlugin(PipelinePluginBase):
or not torch.is_grad_enabled() or not torch.is_grad_enabled()
): ):
return outputs return outputs
print("Show torch status:", torch.is_grad_enabled())
# Synchronize the grads of shared parameters of the model. # Synchronize the grads of shared parameters of the model.
model.sync_shared_params() model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model. # Synchronize sequence parallelism gradients of the model.