[gemini] fix missing return (#5845)

This commit is contained in:
botbw
2024-06-21 11:38:40 +08:00
committed by GitHub
parent bd3e34fef6
commit 8a5c86439a

View File

@@ -450,6 +450,7 @@ class GeminiDDP(ModelWrapper):
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)