mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import functools
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from time import time
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
@@ -7,6 +8,7 @@ import torch
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.chunk import Chunk
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .memory_tracer import ChunkMemStatsCollector
|
||||
@@ -17,7 +19,8 @@ class PlacementPolicy(ABC):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
**kwargs) -> None:
|
||||
self.chunk_manager = chunk_manager
|
||||
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
|
||||
|
||||
@@ -25,57 +28,87 @@ class PlacementPolicy(ABC):
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_default_device() -> torch.device:
|
||||
return torch.device('cpu')
|
||||
@abstractmethod
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CPUPlacementPolicy(PlacementPolicy):
|
||||
class StaticPlacementPolicy(PlacementPolicy):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
shard_param_frac: float = 1.0,
|
||||
offload_optim_frac: float = 0.0,
|
||||
offload_param_frac: float = 0.0,
|
||||
**kwargs) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
|
||||
warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
|
||||
offload_param_frac = 0.0
|
||||
self.shard_param_frac = shard_param_frac
|
||||
self.offload_optim_frac = offload_optim_frac
|
||||
self.offload_param_frac = offload_param_frac
|
||||
# these should be initialized in setup_grads_device
|
||||
self.keep_gathered_chunk_mem = 0.0
|
||||
self.keep_cuda_chunk_mem = 0.0
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
volume = 0
|
||||
start = time()
|
||||
can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
|
||||
can_offload_chunk_mem = can_shard_chunk_mem
|
||||
for chunk in can_evict_chunks:
|
||||
if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
|
||||
break
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
# real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
|
||||
can_shard_chunk_mem -= chunk.chunk_mem
|
||||
for chunk in can_evict_chunks:
|
||||
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
|
||||
break
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
volume += chunk.chunk_mem
|
||||
return volume, time() - start
|
||||
# real saved mem is shard_mem, for simplicity we use chunk_mem
|
||||
can_offload_chunk_mem -= chunk.chunk_mem
|
||||
return 0, 0.0
|
||||
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
|
||||
|
||||
class CUDAPlacementPolicy(PlacementPolicy):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
return 0, 0
|
||||
|
||||
@staticmethod
|
||||
def get_default_device() -> torch.device:
|
||||
return get_current_device()
|
||||
offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
|
||||
offloaded_optim_chunk_mem = 0
|
||||
chunks = set(self.chunk_manager.get_chunk(p) for p in params)
|
||||
for chunk in chunks:
|
||||
params = chunk.get_tensors()
|
||||
# init offload optim settings
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
|
||||
device = get_current_device()
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
|
||||
offloaded_optim_chunk_mem += chunk.chunk_mem
|
||||
for p in params:
|
||||
grads_device_map[p] = device
|
||||
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
||||
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
||||
|
||||
|
||||
class AutoPlacementPolicy(PlacementPolicy):
|
||||
|
||||
need_mem_stats: bool = True
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
_warmup_non_model_data_ratio: float = 0.8
|
||||
_steady_cuda_cap_ratio: float = 0.9
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
warmup_non_model_data_ratio: float = 0.8,
|
||||
steady_cuda_cap_ratio: float = 0.9,
|
||||
**kwargs) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
|
||||
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
|
||||
|
||||
def evict_tensors(self,
|
||||
can_evict_chunks: List[Chunk],
|
||||
@@ -105,11 +138,11 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
|
||||
if warmup:
|
||||
# We designate a part of CUDA memory for model data in warmup iterations.
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
|
||||
else:
|
||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
|
||||
cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
|
||||
cuda_capacity *= self._steady_cuda_cap_ratio
|
||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
freed_cuda_model_data = 0
|
||||
@@ -145,89 +178,22 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
@staticmethod
|
||||
def set_warmup_non_model_data_ratio(ratio: float) -> None:
|
||||
ratio = float(ratio)
|
||||
assert 0.0 < ratio < 1.0
|
||||
AutoPlacementPolicy._warmup_non_model_data_ratio = ratio
|
||||
|
||||
@staticmethod
|
||||
def set_steady_cuda_cap_ratio(ratio: float) -> None:
|
||||
ratio = float(ratio)
|
||||
assert 0.0 < ratio < 1.0
|
||||
AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
|
||||
|
||||
|
||||
class ConstPlacementPolicy(PlacementPolicy):
|
||||
|
||||
need_mem_stats: bool = False
|
||||
_accessed_memory_boundary = 512 * 1024**2
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self,
|
||||
can_evict_chunks: List[Chunk],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
|
||||
compute_idx: int = 0,
|
||||
**kwargs) -> Tuple[int, float]:
|
||||
"""
|
||||
See the docstrings in the class `AutoPlacementPolicy`.
|
||||
"""
|
||||
start = time()
|
||||
used_accessed_memory = self.chunk_manager.accessed_mem
|
||||
avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
|
||||
freed_accessed_memory = 0
|
||||
|
||||
if avail_accessed_memory < cuda_demand:
|
||||
to_free_memory = cuda_demand - avail_accessed_memory
|
||||
to_free_chunks = can_evict_chunks
|
||||
|
||||
if not warmup:
|
||||
# sort all chunks
|
||||
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
|
||||
|
||||
for chunk in to_free_chunks:
|
||||
if freed_accessed_memory >= to_free_memory:
|
||||
break
|
||||
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
freed_accessed_memory += chunk.chunk_mem
|
||||
|
||||
if freed_accessed_memory < to_free_memory:
|
||||
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
|
||||
f"Need {to_free_memory}, freed {freed_accessed_memory}")
|
||||
return freed_accessed_memory, time() - start
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
|
||||
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
|
||||
for i in range(len(compute_list) - 1, compute_idx, -1):
|
||||
for chunk in compute_list[i]:
|
||||
if chunk in next_compute_idx:
|
||||
next_compute_idx[chunk] = i
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
@staticmethod
|
||||
def set_const_memory_boundary(cuda_memory_mb: int) -> None:
|
||||
boundary = int(cuda_memory_mb * 1024**2)
|
||||
assert boundary > 0
|
||||
ConstPlacementPolicy._accessed_memory_boundary = boundary
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
for p in params:
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
# init offload optim settings
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk.keep_gathered:
|
||||
grads_device_map[p] = get_current_device()
|
||||
else:
|
||||
grads_device_map[p] = torch.device('cpu')
|
||||
|
||||
|
||||
class PlacementPolicyFactory:
|
||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||
'cpu': CPUPlacementPolicy,
|
||||
'cuda': CUDAPlacementPolicy,
|
||||
'auto': AutoPlacementPolicy,
|
||||
'const': ConstPlacementPolicy
|
||||
'static': StaticPlacementPolicy,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -239,8 +205,3 @@ class PlacementPolicyFactory:
|
||||
@staticmethod
|
||||
def get_policy_names():
|
||||
return tuple(PlacementPolicyFactory.policies.keys())
|
||||
|
||||
@staticmethod
|
||||
def get_default_device(policy_name: str) -> torch.device:
|
||||
policy_cls = PlacementPolicyFactory.create(policy_name)
|
||||
return policy_cls.get_default_device()
|
||||
|
Reference in New Issue
Block a user