[gemini] use compute_chunk to find next chunk

This commit is contained in:
hxwang
2024-05-16 13:17:26 +08:00
parent b2e9745888
commit 4148ceed9f
5 changed files with 52 additions and 79 deletions

View File

@@ -133,6 +133,7 @@ class GeminiDDP(ModelWrapper):
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -157,8 +158,6 @@ class GeminiDDP(ModelWrapper):
for p in module.parameters():
param_order.append(p)
self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch)
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():