Mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] prefetch chunks
@@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
         chunk_init_device: torch.device = torch.device("cpu"),
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
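The new max_prefetch argument bounds how many chunks the parameter-op hook may fetch ahead of the one currently in use; the default of 0 preserves the previous on-demand behaviour. Below is a minimal usage sketch: the import path and the toy model are illustrative, the distributed launch that GeminiDDP normally requires is omitted, and only the keyword arguments visible in this diff are taken from the change.

import torch
import torch.nn as nn

from colossalai.zero import GeminiDDP  # import path assumed; adjust to your ColossalAI version

# Sketch only: GeminiDDP normally runs inside a distributed process group
# set up via colossalai.launch, which is omitted here.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
gemini_model = GeminiDDP(
    model,
    chunk_init_device=torch.device("cuda"),
    placement_policy="static",
    max_prefetch=2,  # new knob from this commit: fetch up to 2 chunks ahead; 0 disables prefetching
)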
@@ -132,7 +133,6 @@ class GeminiDDP(ModelWrapper):
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
         )
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -157,6 +157,8 @@ class GeminiDDP(ModelWrapper):
             for p in module.parameters():
                 param_order.append(p)
 
+        self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch)
+
         for name, param in module.named_parameters():
             self.param2name[param] = name
         for m_name, m_var in module.named_modules():
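The diff only touches the hook's call site; the hook body itself is not shown here. As a rough sketch of what the two new arguments enable, a pre-op hook that knows the recorded parameter order can start moving the chunks of upcoming parameters while the current operator runs, stopping after max_prefetch chunks. Everything below is illustrative: PrefetchingZeROHook, chunk_manager, get_chunk, access_chunk and the non_blocking flag are placeholder names, not the actual ColossalAI API.

from typing import List

import torch


class PrefetchingZeROHook:
    """Illustrative sketch of a prefetch-aware parameter-op hook (not the real GeminiZeROHook)."""

    def __init__(self, chunk_manager, param_order: List[torch.nn.Parameter], max_prefetch: int = 0):
        self.chunk_manager = chunk_manager                   # placeholder: maps params to chunks, moves chunks
        self.param_order = param_order                       # execution order recorded by GeminiDDP
        self.rank_of = {id(p): i for i, p in enumerate(param_order)}
        self.max_prefetch = max_prefetch                     # 0 keeps the old on-demand behaviour

    def pre_op(self, params: List[torch.nn.Parameter]) -> None:
        # Materialize the chunks needed by the current operator (placeholder calls).
        for p in params:
            self.chunk_manager.access_chunk(self.chunk_manager.get_chunk(p))

        if self.max_prefetch <= 0:
            return

        # Look past the current operator in the recorded order and start moving
        # the next few chunks early, stopping after max_prefetch distinct chunks.
        last = max(self.rank_of[id(p)] for p in params)
        prefetched = []
        for upcoming in self.param_order[last + 1:]:
            chunk = self.chunk_manager.get_chunk(upcoming)
            if chunk in prefetched:
                continue
            self.chunk_manager.access_chunk(chunk, non_blocking=True)  # hypothetical async move
            prefetched.append(chunk)
            if len(prefetched) >= self.max_prefetch:
                break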