[chore] refactor & sync

This commit is contained in:
hxwang
2024-05-16 07:22:10 +00:00
parent 4148ceed9f
commit 2e68eebdfe
7 changed files with 82 additions and 46 deletions

View File

@@ -131,9 +131,10 @@ class GeminiDDP(ModelWrapper):
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
max_prefetch=max_prefetch
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()