Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[chore] sync
@@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
         chunk_init_device: Optional[torch.device] = None,
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
             master_weights=master_weights,
+            max_prefetch=max_prefetch,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
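For context, below is a minimal sketch of how the new max_prefetch argument might be passed to GeminiPlugin alongside the existing static-placement options. The Booster wiring is assumed from ColossalAI's usual plugin API and is not part of this commit; a real run also needs a distributed launch (e.g. via torchrun plus colossalai's launch helpers).

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Sketch only: construct the plugin with the options touched by this diff.
# max_prefetch is the parameter added in this commit; the other arguments
# already existed and keep their previous meaning.
plugin = GeminiPlugin(
    chunk_init_device=torch.device("cuda"),
    placement_policy="static",
    enable_gradient_accumulation=False,
    max_prefetch=2,            # added by this commit (default 0 in the diff above)
    shard_param_frac=1.0,      # only for static placement
    offload_optim_frac=0.0,    # only for static placement
    offload_param_frac=0.0,    # only for static placement
)

# Under a proper distributed launch, the plugin is handed to a Booster:
# booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)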