[gemini]remove registered gradients hooks (#5696)

* fix gemini

fix gemini

* fix

fix
This commit is contained in:
flybird11111
2024-05-09 10:29:49 +08:00
committed by GitHub
parent 22297789ab
commit d4c5ef441e
5 changed files with 93 additions and 46 deletions

View File

@@ -26,7 +26,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
if not model.reuse_fp16_chunk:
if not model.chunk_manager.reuse_fp16_chunk:
chunk_list = [chunk.grad_chunk for chunk in chunk_list]
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)