[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
This commit is contained in:
Hongxin Liu
2023-08-24 09:29:25 +08:00
committed by GitHub
parent 285fe7ba71
commit 27061426f7
82 changed files with 1008 additions and 4036 deletions

View File

@@ -1,4 +1,5 @@
import functools
import warnings
from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional, Tuple, Type
@@ -7,6 +8,7 @@ import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector
@@ -17,7 +19,8 @@ class PlacementPolicy(ABC):
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
**kwargs) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
@@ -25,57 +28,87 @@ class PlacementPolicy(ABC):
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
raise NotImplementedError
@staticmethod
def get_default_device() -> torch.device:
return torch.device('cpu')
@abstractmethod
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
raise NotImplementedError
class CPUPlacementPolicy(PlacementPolicy):
class StaticPlacementPolicy(PlacementPolicy):
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
shard_param_frac: float = 1.0,
offload_optim_frac: float = 0.0,
offload_param_frac: float = 0.0,
**kwargs) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
offload_param_frac = 0.0
self.shard_param_frac = shard_param_frac
self.offload_optim_frac = offload_optim_frac
self.offload_param_frac = offload_param_frac
# these should be initialized in setup_grads_device
self.keep_gathered_chunk_mem = 0.0
self.keep_cuda_chunk_mem = 0.0
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
volume = 0
start = time()
can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
can_offload_chunk_mem = can_shard_chunk_mem
for chunk in can_evict_chunks:
if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
break
self.chunk_manager.release_chunk(chunk)
# real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
can_shard_chunk_mem -= chunk.chunk_mem
for chunk in can_evict_chunks:
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
break
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
volume += chunk.chunk_mem
return volume, time() - start
# real saved mem is shard_mem, for simplicity we use chunk_mem
can_offload_chunk_mem -= chunk.chunk_mem
return 0, 0.0
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
class CUDAPlacementPolicy(PlacementPolicy):
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
return 0, 0
@staticmethod
def get_default_device() -> torch.device:
return get_current_device()
offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
offloaded_optim_chunk_mem = 0
chunks = set(self.chunk_manager.get_chunk(p) for p in params)
for chunk in chunks:
params = chunk.get_tensors()
# init offload optim settings
# keep gathered chunks are in CUDA
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
device = get_current_device()
else:
device = torch.device('cpu')
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
offloaded_optim_chunk_mem += chunk.chunk_mem
for p in params:
grads_device_map[p] = device
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
_warmup_non_model_data_ratio: float = 0.8
_steady_cuda_cap_ratio: float = 0.9
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
warmup_non_model_data_ratio: float = 0.8,
steady_cuda_cap_ratio: float = 0.9,
**kwargs) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
def evict_tensors(self,
can_evict_chunks: List[Chunk],
@@ -105,11 +138,11 @@ class AutoPlacementPolicy(PlacementPolicy):
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
cuda_capacity *= self._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
freed_cuda_model_data = 0
@@ -145,89 +178,22 @@ class AutoPlacementPolicy(PlacementPolicy):
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
@staticmethod
def set_warmup_non_model_data_ratio(ratio: float) -> None:
ratio = float(ratio)
assert 0.0 < ratio < 1.0
AutoPlacementPolicy._warmup_non_model_data_ratio = ratio
@staticmethod
def set_steady_cuda_cap_ratio(ratio: float) -> None:
ratio = float(ratio)
assert 0.0 < ratio < 1.0
AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
class ConstPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = False
_accessed_memory_boundary = 512 * 1024**2
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs) -> Tuple[int, float]:
"""
See the docstrings in the class `AutoPlacementPolicy`.
"""
start = time()
used_accessed_memory = self.chunk_manager.accessed_mem
avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
freed_accessed_memory = 0
if avail_accessed_memory < cuda_demand:
to_free_memory = cuda_demand - avail_accessed_memory
to_free_chunks = can_evict_chunks
if not warmup:
# sort all chunks
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
for chunk in to_free_chunks:
if freed_accessed_memory >= to_free_memory:
break
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
freed_accessed_memory += chunk.chunk_mem
if freed_accessed_memory < to_free_memory:
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_memory}, freed {freed_accessed_memory}")
return freed_accessed_memory, time() - start
@staticmethod
@functools.lru_cache(maxsize=None)
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
for i in range(len(compute_list) - 1, compute_idx, -1):
for chunk in compute_list[i]:
if chunk in next_compute_idx:
next_compute_idx[chunk] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
@staticmethod
def set_const_memory_boundary(cuda_memory_mb: int) -> None:
boundary = int(cuda_memory_mb * 1024**2)
assert boundary > 0
ConstPlacementPolicy._accessed_memory_boundary = boundary
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
for p in params:
chunk = self.chunk_manager.get_chunk(p)
# init offload optim settings
# keep gathered chunks are in CUDA
if chunk.keep_gathered:
grads_device_map[p] = get_current_device()
else:
grads_device_map[p] = torch.device('cpu')
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {
'cpu': CPUPlacementPolicy,
'cuda': CUDAPlacementPolicy,
'auto': AutoPlacementPolicy,
'const': ConstPlacementPolicy
'static': StaticPlacementPolicy,
}
@staticmethod
@@ -239,8 +205,3 @@ class PlacementPolicyFactory:
@staticmethod
def get_policy_names():
return tuple(PlacementPolicyFactory.policies.keys())
@staticmethod
def get_default_device(policy_name: str) -> torch.device:
policy_cls = PlacementPolicyFactory.create(policy_name)
return policy_cls.get_default_device()