Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-11 06:13:59 +00:00)
[hotfix] fix auto tensor placement policy (#753)

commit 8f7ce94b8e, parent 84c6700b2a
@@ -53,10 +53,9 @@ class ShardedModelV2(nn.Module):
             If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
             If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
             If it's 'auto', they are moved dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
+            Note that 'auto' policy can only work well when no other processes use CUDA during your training.
             Defaults to 'cuda'.
-        offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
         gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
-        use_memory_tracer (bool, optional): Whether to use memory tracer. Defaults to False.
         reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
             Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
             In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
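To make the three tensor_placement_policy values in the docstring concrete, here is a minimal, self-contained sketch of their semantics. It is an illustration only, not ColossalAI's implementation; the helper place_params and its free-memory heuristic for 'auto' are hypothetical.

    import torch
    import torch.nn as nn

    def place_params(module: nn.Module, policy: str = 'cuda') -> None:
        # 'cpu': keep everything in host memory -> minimal CUDA memory usage.
        if policy == 'cpu':
            module.to('cpu')
        # 'cuda': keep everything on the GPU -> maximal CUDA memory usage.
        elif policy == 'cuda':
            module.to('cuda')
        # 'auto': decide dynamically from current memory pressure. This is a toy
        # heuristic; the real policy also uses collected memory statistics.
        elif policy == 'auto':
            free_bytes, _total = torch.cuda.mem_get_info()
            needed = sum(p.numel() * p.element_size() for p in module.parameters())
            module.to('cuda' if needed < free_bytes else 'cpu')
        else:
            raise ValueError(f"unknown tensor placement policy: {policy}")

    if torch.cuda.is_available():
        place_params(nn.Linear(1024, 1024), policy='auto')

As the docstring note added in this commit says, 'auto' works well only when no other process uses CUDA during training: another GPU consumer would skew the free-memory signal such a policy relies on.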
@@ -45,7 +45,8 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):

     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         super().__init__(None, mem_stats_collector=mem_stats_collector)
-        self._warmup_non_model_data_ratio: float = 0.2
+        # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
+        self._warmup_non_model_data_ratio: float = 0.8

     def evict_tensors(self,
                       hold_cuda_tensor_list: List[StatefulTensor],
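The fix raises _warmup_non_model_data_ratio from 0.2 to 0.8: per the new comment, during the warmup phase (before memory statistics have been collected) model data is budgeted only 1 - 0.8 = 20% of CUDA memory, with the rest reserved for non-model data. A rough illustration of that arithmetic follows; the fallback capacity and variable names are made up for the example, not taken from ColossalAI.

    import torch

    # New value introduced by this fix.
    warmup_non_model_data_ratio = 0.8

    # Use the real device capacity if a GPU is present, else assume 40 GB.
    cuda_capacity = (torch.cuda.get_device_properties(0).total_memory
                     if torch.cuda.is_available() else 40 * 1024**3)

    # During warmup, model data may occupy at most this much CUDA memory;
    # the remainder is reserved for non-model data (activations, buffers, etc.).
    model_data_budget = (1 - warmup_non_model_data_ratio) * cuda_capacity
    print(f"warmup model-data budget: {model_data_budget / 1024**3:.1f} GB")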