Merge branch 'main' into feature/shardformer
@@ -2,8 +2,7 @@ from .gemini import (
|
||||
ColoInitContext,
|
||||
GeminiAdamOptimizer,
|
||||
GeminiDDP,
|
||||
ZeroDDP,
|
||||
ZeroOptimizer,
|
||||
GeminiOptimizer,
|
||||
get_static_torch_model,
|
||||
post_process_colo_init_ctx,
|
||||
)
|
||||
@@ -11,6 +10,6 @@ from .low_level import LowLevelZeroOptimizer
|
||||
from .wrapper import zero_model_wrapper, zero_optim_wrapper
|
||||
|
||||
__all__ = [
|
||||
'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
|
||||
'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
|
||||
'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
|
||||
]
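For reference, a minimal sketch of the imports after this rename (assuming this hunk is `colossalai/zero/__init__.py`; `ZeroDDP` and `ZeroOptimizer` drop out of the public API):

```python
from colossalai.zero import (
    GeminiDDP,              # replaces ZeroDDP
    GeminiOptimizer,        # replaces ZeroOptimizer
    GeminiAdamOptimizer,
    LowLevelZeroOptimizer,
    zero_model_wrapper,
    zero_optim_wrapper,
)
```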
@@ -1,11 +1,11 @@
|
||||
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
|
||||
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
|
||||
from .gemini_ddp import GeminiDDP, ZeroDDP
|
||||
from .gemini_ddp import GeminiDDP
|
||||
from .gemini_mgr import GeminiManager
|
||||
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
|
||||
from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
|
||||
from .utils import get_static_torch_model
|
||||
|
||||
__all__ = [
|
||||
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
|
||||
'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
|
||||
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
|
||||
'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
|
||||
]
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class Chunk:
|
||||
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
process_group: ColoProcessGroup,
|
||||
process_group: ProcessGroup,
|
||||
dtype: torch.dtype,
|
||||
init_device: Optional[torch.device] = None,
|
||||
cpu_shard_init: bool = False,
|
||||
@@ -69,7 +69,7 @@ class Chunk:
|
||||
|
||||
Args:
|
||||
chunk_size (int): the number of elements in the chunk
|
||||
process_group (ColoProcessGroup): the process group of this chunk
|
||||
process_group (ProcessGroup): the process group of this chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, the device where the tensor is stored during chunk construction.
The default value is None, which means the current GPU
|
||||
@@ -83,7 +83,7 @@ class Chunk:
|
||||
self.chunk_size = chunk_size
|
||||
self.utilized_size = 0
|
||||
|
||||
self.torch_pg = process_group.dp_process_group()
|
||||
self.torch_pg = process_group
|
||||
self.pg_size = dist.get_world_size(self.torch_pg)
|
||||
self.pg_rank = dist.get_rank(self.torch_pg)
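Illustrative only: `Chunk` now takes a plain `torch.distributed.ProcessGroup` instead of a `ColoProcessGroup`. A sketch, assuming the constructor arguments not shown above keep their defaults:

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.zero.gemini.chunk import Chunk

pg = _get_default_group()              # or dist.new_group([...])
chunk = Chunk(chunk_size=1024 * 1024,
              process_group=pg,        # plain torch ProcessGroup now
              dtype=torch.float16)
# world size / rank are read straight from the torch group
assert chunk.pg_size == dist.get_world_size(pg)
```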
@@ -218,7 +218,7 @@ class Chunk:
|
||||
return False
|
||||
else:
|
||||
return self.tensor_state_cnter[TensorState.HOLD] + \
|
||||
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
|
||||
@property
|
||||
def can_reduce(self):
|
||||
|
||||
@@ -2,8 +2,9 @@ from collections import deque
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .chunk import Chunk, ChunkFullError, TensorState
|
||||
@@ -27,16 +28,17 @@ class ChunkManager:
|
||||
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
|
||||
v['init_device'] = self.device
|
||||
|
||||
self.chunk_groups: Dict[str, Deque] = dict()
|
||||
self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
|
||||
self.accessed_chunks: Set[Chunk] = set()
|
||||
self.accessed_mem: int = 0
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
|
||||
def register_tensor(self,
|
||||
tensor: ColoTensor,
|
||||
tensor: torch.Tensor,
|
||||
group_type: str,
|
||||
config_key: int,
|
||||
process_group: ProcessGroup,
|
||||
cpu_offload: bool = False,
|
||||
pin_memory: bool = False) -> None:
|
||||
"""
|
||||
@@ -51,7 +53,7 @@ class ChunkManager:
|
||||
pin_memory: whether the chunk is pinned in the cpu memory
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
|
||||
assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager"
|
||||
assert config_key in self.dp_degree_chunk_size_dict
|
||||
|
||||
chunk_size = self.dp_degree_chunk_size_dict[config_key]
|
||||
@@ -73,12 +75,12 @@ class ChunkManager:
|
||||
|
||||
if tensor.numel() > chunk_size:
|
||||
chunk_size = tensor.numel()
|
||||
dp_size = tensor.get_dp_world_size()
|
||||
dp_size = dist.get_world_size(process_group)
|
||||
chunk_size = chunk_size + (-chunk_size % dp_size)
|
||||
|
||||
chunk = Chunk(
|
||||
chunk_size=chunk_size,
|
||||
process_group=tensor.process_group,
|
||||
process_group=process_group,
|
||||
dtype=tensor.dtype,
|
||||
cpu_shard_init=cpu_offload,
|
||||
pin_memory=pin_memory,
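A hedged sketch of the updated `register_tensor` call, which now accepts any `torch.Tensor` together with an explicit `torch.distributed.ProcessGroup` (`chunk_manager` and `some_param` below are placeholders):

```python
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

pg = _get_default_group()
dp_size = dist.get_world_size(pg)

# chunk_manager: an existing ChunkManager; some_param: any torch.Tensor to be chunked
chunk_manager.register_tensor(tensor=some_param,       # ColoTensor no longer required
                              group_type='fp16_param',
                              config_key=dp_size,       # chunk size is looked up per dp world size
                              process_group=pg,
                              cpu_offload=False,
                              pin_memory=False)
```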
@@ -220,7 +222,7 @@ class ChunkManager:
|
||||
msg.append(f'[{i}] {chunk}\n')
|
||||
return ''.join(msg)
|
||||
|
||||
def __get_chunk_group(self, group_name: str) -> Deque:
|
||||
def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
|
||||
"""Register a chunk group.
|
||||
"""
|
||||
if group_name not in self.chunk_groups:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
@@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
|
||||
return left + acc
|
||||
|
||||
|
||||
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
|
||||
def _tensor_numel(local_param: ColoParameter) -> int:
|
||||
"""_tensor_numel
|
||||
|
||||
Get the number of elements of a tensor.
|
||||
@@ -71,15 +72,12 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
|
||||
Returns:
|
||||
int: the number of elements.
|
||||
"""
|
||||
if strict_ddp_flag and type(local_param) is ColoParameter:
|
||||
return local_param.numel_global()
|
||||
else:
|
||||
# if local_param is not ColoParameter, we assume it's replicated
|
||||
return local_param.numel()
|
||||
# TODO(ver217): support dtensor here
|
||||
return local_param.numel()
|
||||
|
||||
|
||||
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
|
||||
strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
|
||||
process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
|
||||
"""classify_params_by_dp_degree
|
||||
|
||||
Classify the parameters by their dp degree
|
||||
@@ -97,13 +95,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
|
||||
# assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
|
||||
if is_ddp_ignored(param):
|
||||
continue
|
||||
|
||||
if strict_ddp_flag or type(param) is not ColoParameter:
|
||||
# if model is not initialized with ColoInitContext, we assume it's replicated
|
||||
# TODO(ver217): integrate DTensor
|
||||
param_key = dist.get_world_size()
|
||||
else:
|
||||
param_key = param.process_group.dp_world_size()
|
||||
param_key = dist.get_world_size(process_group)
|
||||
|
||||
if param_key not in params_dict:
|
||||
params_dict[param_key] = []
|
||||
@@ -119,6 +111,7 @@ def search_chunk_configuration(
|
||||
min_chunk_size_m: float = 32,
|
||||
filter_exlarge_params: bool = True,
|
||||
strict_ddp_flag: bool = False,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
|
||||
"""search_chunk_configuration
|
||||
|
||||
@@ -149,7 +142,7 @@ def search_chunk_configuration(
|
||||
min_chunk_size = round(min_chunk_size_m * 1024**2)
|
||||
assert search_range >= 0
|
||||
|
||||
params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
|
||||
params_dict = classify_params_by_dp_degree(param_order, process_group)
|
||||
size_lcm = np.lcm.reduce(list(params_dict.keys()))
|
||||
config_dict: Dict[int, Dict] = dict()
|
||||
total_param_size = 0
|
||||
@@ -157,7 +150,7 @@ def search_chunk_configuration(
|
||||
size_dict: Dict[int, List[int]] = dict()
|
||||
for dp_degree in params_dict:
|
||||
params_list = params_dict[dp_degree]
|
||||
size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
|
||||
size_list = [_tensor_numel(p) for p in params_list]
|
||||
group_acc_size = sum(size_list)
|
||||
total_param_size += group_acc_size
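Condensed sketch (not the full implementation) of what the simplified grouping above amounts to: with `strict_ddp_flag` removed, every non-ignored parameter is keyed by the world size of the explicit process group:

```python
import torch.distributed as dist

from colossalai.utils import is_ddp_ignored

def classify_params_by_dp_degree(param_order, process_group):
    params_dict = {}
    param_key = dist.get_world_size(process_group)    # single key for all surviving params
    for param in param_order.generate():
        if is_ddp_ignored(param):
            continue
        params_dict.setdefault(param_key, []).append(param)
    return params_dict
```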
@@ -2,19 +2,21 @@ import itertools
|
||||
from collections import OrderedDict
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
|
||||
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
|
||||
from colossalai.interface import ModelWrapper
|
||||
|
||||
from colossalai.checkpoint_io.utils import StateDictSharder
|
||||
from colossalai.lazy import LazyTensor
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor import ReplicaSpec
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device, is_ddp_ignored
|
||||
|
||||
@@ -30,14 +32,13 @@ except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
|
||||
__all__ = [
|
||||
'ZeroDDP',
|
||||
'GeminiDDP',
|
||||
]
|
||||
|
||||
|
||||
class ZeroDDP(ColoDDP):
|
||||
"""ZeRO DDP for ColoTensor.
|
||||
Warning: Nested ZeroDDP is not supported now.
|
||||
class GeminiDDP(ModelWrapper):
|
||||
"""ZeRO DDP.
|
||||
Warning: Nested GeminiDDP is not supported now.
|
||||
It is designed to be used with ChunkManager and GeminiManager.
|
||||
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
|
||||
|
||||
@@ -54,20 +55,54 @@ class ZeroDDP(ColoDDP):
|
||||
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True,
|
||||
mixed_precision: torch.dtype = torch.float16) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
chunk_config_dict: Optional[dict] = None,
|
||||
chunk_init_device: torch.device = torch.device('cpu'),
|
||||
placement_policy: str = "static",
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
|
||||
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
|
||||
search_range_m: int = 32, # chunk search options
|
||||
hidden_dim: Optional[int] = None, # chunk search options
|
||||
min_chunk_size_m: float = 32, # chunk search options
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True,
|
||||
mixed_precision: torch.dtype = torch.float16,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstats: Optional[MemStats] = None,  # gemini memory stats
|
||||
verbose: bool = False) -> None:
|
||||
assert mixed_precision in (torch.float16, torch.bfloat16)
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
if chunk_config_dict is not None:
|
||||
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
|
||||
else:
|
||||
# some ugly hotfix for the compatibility with Lightning
|
||||
if search_range_m is None:
|
||||
search_range_m = 32
|
||||
self.chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=chunk_init_device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_m=search_range_m,
|
||||
min_chunk_size_m=min_chunk_size_m,
|
||||
strict_ddp_flag=strict_ddp_mode,
|
||||
process_group=process_group,
|
||||
verbose=verbose)
|
||||
self.gemini_manager = GeminiManager(placement_policy,
|
||||
self.chunk_manager,
|
||||
memstats,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
offload_param_frac=offload_param_frac,
|
||||
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
|
||||
steady_cuda_cap_ratio=steady_cuda_cap_ratio)
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(gemini_manager)
|
||||
self.fp32_params: List[ColoTensor] = list()
|
||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
||||
self.fp32_params: List[torch.Tensor] = list()
|
||||
self.fp16_params: List[ColoParameter] = list()
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
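A hedged usage sketch of the new flat `GeminiDDP` constructor, which now builds its own `ChunkManager` and `GeminiManager` instead of taking a pre-built `gemini_manager` (`module` and the argument values are illustrative):

```python
import torch

model = GeminiDDP(
    module,                                   # your torch.nn.Module
    chunk_init_device=torch.device('cuda'),
    placement_policy="static",
    shard_param_frac=1.0,                     # static placement: fully shard parameters
    offload_optim_frac=0.0,
    pin_memory=True,
    mixed_precision=torch.float16,
    process_group=None,                       # None falls back to the default process group
)
```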
@@ -75,6 +110,7 @@ class ZeroDDP(ColoDDP):
|
||||
self.name2param: Dict[str, nn.Parameter] = dict()
|
||||
self.scatter_after_inference = scatter_after_inference
|
||||
self.mixed_precision = mixed_precision
|
||||
self.dp_process_group = process_group or _get_default_group()
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
@@ -88,20 +124,67 @@ class ZeroDDP(ColoDDP):
|
||||
for p in module.parameters():
|
||||
param_order.append(p)
|
||||
|
||||
self._init_chunks(param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != 'cuda',
|
||||
pin_memory=pin_memory)
|
||||
|
||||
for name, param in module.named_parameters():
|
||||
self.param2name[param] = name
|
||||
for m_name, m_var in module.named_modules():
|
||||
for p_name, p_var in m_var.named_parameters(recurse=False):
|
||||
param_name = m_name + '.' + p_name if m_name else p_name
|
||||
self.name2param[param_name] = p_var
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
|
||||
self._init_chunks(param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != 'cuda',
|
||||
pin_memory=pin_memory)
|
||||
super().__init__(module)
|
||||
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
|
||||
self._cast_buffers()
|
||||
# register grad hook
|
||||
for p in module.parameters():
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
if p.requires_grad:
|
||||
p.register_hook(partial(self.grad_handle, p))
|
||||
|
||||
def parameters(self, recurse: bool = True):
|
||||
return self.module.parameters(recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True):
|
||||
return self.module.named_parameters(prefix, recurse)
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True):
|
||||
return self.module.named_buffers(prefix, recurse)
|
||||
|
||||
def named_children(self):
|
||||
return self.module.named_children()
|
||||
|
||||
def named_modules(self,
|
||||
memo: Optional[Set[torch.nn.Module]] = None,
|
||||
prefix: str = '',
|
||||
remove_duplicate: bool = True):
|
||||
return self.module.named_modules(memo, prefix, remove_duplicate)
|
||||
|
||||
@staticmethod
|
||||
def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
|
||||
"""Sets parameters to be ignored by DDP.
|
||||
This method must be called before initializing ColoDDP.
|
||||
|
||||
Example:
|
||||
>>> params_to_ignore = []
|
||||
>>> for p in module.parameters():
|
||||
>>> if should_ignore(p):
|
||||
>>> params_to_ignore.append(p)
|
||||
>>> ColoDDP.set_params_to_ignore(params_to_ignore)
|
||||
>>> module = ColoDDP(module)
|
||||
|
||||
Args:
|
||||
params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
|
||||
"""
|
||||
for p in params_to_ignore:
|
||||
p._ddp_to_ignore = True
|
||||
|
||||
def unwrap(self):
|
||||
# save/load state dict are overridden, so just return self
|
||||
return self
|
||||
|
||||
def _get_non_persistent_buffers_set(self,
|
||||
module,
|
||||
@@ -207,7 +290,7 @@ class ZeroDDP(ColoDDP):
|
||||
error_params.append(self.param2name[param])
|
||||
error_str = "\n\t".join(error_params)
|
||||
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
||||
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
|
||||
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
|
||||
f"{error_str}")
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
@@ -227,6 +310,7 @@ class ZeroDDP(ColoDDP):
|
||||
self._post_backward()
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
setattr(p, "_gemini_reduced", True)
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
@@ -533,7 +617,7 @@ class ZeroDDP(ColoDDP):
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.optim_update()
|
||||
chunk_16.payload.copy_(chunk_32.payload)
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
@@ -557,17 +641,11 @@ class ZeroDDP(ColoDDP):
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
|
||||
ddp_pg = ColoProcessGroup()
|
||||
dp_world_size = dist.get_world_size(self.dp_process_group)
|
||||
for p in param_order.generate():
|
||||
self._preprocess_param(p)
|
||||
assert type(p) is ColoParameter
|
||||
|
||||
# gather sharded parameters in the strict ddp mode
|
||||
if strict_ddp_mode:
|
||||
if not p.is_replicate():
|
||||
p.set_dist_spec(ReplicaSpec())
|
||||
p.set_process_group(pg=ddp_pg)
|
||||
|
||||
# ignore the parameters with no gradient
|
||||
if not p.requires_grad:
|
||||
self.set_params_to_ignore([p])
|
||||
@@ -578,38 +656,37 @@ class ZeroDDP(ColoDDP):
|
||||
continue
|
||||
|
||||
# create a fp32 parameter
|
||||
fp32_data = p.data.float()
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
fp32_p = p.data.float()
|
||||
# create a fp16 parameter
|
||||
p.data = p.data.to(self.mixed_precision)
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
self.chunk_manager.register_tensor(tensor=p,
|
||||
group_type='fp16_param',
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
self.chunk_manager.register_tensor(tensor=fp32_p,
|
||||
group_type='fp32_param',
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
self.fp16_params.append(p)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
|
||||
# move master weights to corresponding device and setup paired chunks
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
chunk_32.init_pair(chunk_16)
|
||||
|
||||
# keep-gathered chunks stay in CUDA
|
||||
if chunk_16.keep_gathered:
|
||||
self.grads_device[p] = get_current_device()
|
||||
if chunk_32.device_type != self.grads_device[p].type:
|
||||
self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])
|
||||
|
||||
def _cast_buffers(self):
|
||||
for buffer in self.module.buffers():
|
||||
@@ -705,65 +782,3 @@ class ZeroDDP(ColoDDP):
|
||||
yield sharder.current_block, sharder.current_block_size
|
||||
|
||||
|
||||
class GeminiDDP(ZeroDDP):
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
device: torch.device,
|
||||
placement_policy: str = "cpu",
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True,
|
||||
search_range_m: int = 32,
|
||||
hidden_dim: Optional[int] = None,
|
||||
min_chunk_size_m: float = 32,
|
||||
memstats: Optional[MemStats] = None,
|
||||
mixed_precision: torch.dtype = torch.float16,
|
||||
verbose: bool = False) -> None:
|
||||
"""
|
||||
A torch.Module wrapper using ZeRO-DP and Gemini.
|
||||
ZeRO is for parallel. Gemini is for memory management.
|
||||
WARNING: The class will modify the module inline!
|
||||
|
||||
Example:
|
||||
model is initialized under the context of ColoInitContext
|
||||
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
|
||||
>>> logits = model(x)
|
||||
>>> loss = criterion(logits, labels)
|
||||
>>> model.backward(loss)
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): the model to be wrapped.
|
||||
device (torch.device): device to place the model.
|
||||
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
|
||||
search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
|
||||
hidden_dim (int, optional): the hidden dimension of DNN.
|
||||
Users can provide this argument to speed up searching.
|
||||
If users do not know this argument before training, it is ok. We will use a default value 1024.
|
||||
min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
|
||||
If the aggregate size of parameters is still smaller than the minimum chunk size,
|
||||
all parameters will be compacted into one small chunk.
|
||||
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
|
||||
"""
|
||||
# some ugly hotfix for the compatibility with Lightning
|
||||
if search_range_m is None:
|
||||
search_range_m = 32
|
||||
|
||||
chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_m=search_range_m,
|
||||
min_chunk_size_m=min_chunk_size_m,
|
||||
strict_ddp_flag=strict_ddp_mode,
|
||||
verbose=verbose)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||
super().__init__(module,
|
||||
gemini_manager,
|
||||
pin_memory,
|
||||
force_outputs_fp32,
|
||||
strict_ddp_mode,
|
||||
scatter_after_inference,
|
||||
mixed_precision=mixed_precision)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import functools
|
||||
from time import time
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,7 +26,11 @@ class GeminiManager:
|
||||
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
|
||||
"""
|
||||
|
||||
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||
def __init__(self,
|
||||
placement_policy: str,
|
||||
chunk_manager: ChunkManager,
|
||||
memstats: Optional[MemStats] = None,
|
||||
**placement_kwargs) -> None:
|
||||
|
||||
assert placement_policy in PlacementPolicyFactory.get_policy_names()
|
||||
self.policy_name = placement_policy
|
||||
@@ -37,7 +41,7 @@ class GeminiManager:
|
||||
self._memstats = memstats
|
||||
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
|
||||
self._memstats) if policy_cls.need_mem_stats else None
|
||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
|
||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
|
||||
self._compute_list: List[Tuple[Chunk, ...]] = []
|
||||
self._compute_idx: int = -1
|
||||
|
||||
@@ -133,10 +137,6 @@ class GeminiManager:
|
||||
if self._warmup and self._placement_policy.need_mem_stats:
|
||||
self._compute_list.append(chunks)
|
||||
|
||||
@property
|
||||
def default_device(self):
|
||||
return self._placement_policy.get_default_device()
|
||||
|
||||
def sample_overall_data(self):
|
||||
if self._mem_stats_collector:
|
||||
self._mem_stats_collector.sample_overall_data()
|
||||
@@ -159,6 +159,6 @@ class GeminiManager:
|
||||
def is_cuda_margin_mem_avail(self) -> bool:
|
||||
return self._placement_policy.need_mem_stats
|
||||
|
||||
@staticmethod
|
||||
def get_default_device(policy_name: str) -> torch.device:
|
||||
return PlacementPolicyFactory.get_default_device(policy_name)
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
self._placement_policy.setup_grads_device(params, grads_device_map)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
|
||||
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -10,16 +10,17 @@ from torch.nn import Parameter
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
|
||||
from colossalai.checkpoint_io.utils import StateDictSharder
|
||||
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
|
||||
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
|
||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .gemini_ddp import ZeroDDP
|
||||
from .gemini_ddp import GeminiDDP
|
||||
|
||||
__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
|
||||
__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer']
|
||||
|
||||
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
|
||||
|
||||
@@ -27,7 +28,7 @@ _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
|
||||
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
||||
|
||||
def __init__(self,
|
||||
module: ZeroDDP,
|
||||
module: GeminiDDP,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
@@ -46,11 +47,11 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
||||
self.module.overflow_counter = 0
|
||||
|
||||
|
||||
class ZeroOptimizer(ColossalaiOptimizer):
|
||||
"""A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
|
||||
class GeminiOptimizer(OptimizerWrapper):
|
||||
"""A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
|
||||
|
||||
Note:
|
||||
You must use ``ZeroDDP`` with ``ZeroOptimizer``.
|
||||
You must use ``GeminiDDP`` with ``GeminiOptimizer``.
|
||||
|
||||
Note:
|
||||
Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`,
|
||||
@@ -58,7 +59,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
|
||||
Args:
|
||||
optim (Optimizer): An Optimizer instance.
|
||||
module (ZeroDDP): A ``ZeroDDP`` instance.
|
||||
module (GeminiDDP): A ``GeminiDDP`` instance.
|
||||
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
||||
which will be used when using hybrid CPU optimizer.
|
||||
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
|
||||
@@ -70,15 +71,15 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000.
|
||||
hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2.
|
||||
max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32.
|
||||
clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
|
||||
max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
|
||||
norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
|
||||
is supported in ZeroOptimizer. Defaults to 2.0.
|
||||
is supported in GeminiOptimizer. Defaults to 2.0.
|
||||
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optim: Optimizer,
|
||||
module: ZeroDDP,
|
||||
module: GeminiDDP,
|
||||
gpu_margin_mem_ratio: float = 0.0,
|
||||
initial_scale: float = 2**32,
|
||||
min_scale: float = 1,
|
||||
@@ -87,12 +88,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
clipping_norm: float = 0.0,
|
||||
max_norm: float = 0.0,
|
||||
norm_type: float = 2.0,
|
||||
verbose: bool = False,
|
||||
**defaults: Any):
|
||||
super().__init__(optim)
|
||||
assert isinstance(module, ZeroDDP)
|
||||
assert isinstance(module, GeminiDDP)
|
||||
assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
|
||||
f"{_AVAIL_OPTIM_LIST}"
|
||||
self.module = module
|
||||
@@ -101,8 +102,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
|
||||
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
|
||||
self.chunk16_set: Set[Chunk] = set()
|
||||
self.clipping_flag = clipping_norm > 0.0
|
||||
self.max_norm = clipping_norm
|
||||
self.clipping_flag = max_norm > 0.0
|
||||
self.max_norm = max_norm
|
||||
self.verbose = verbose
|
||||
self.param_groups_backup = list()
|
||||
|
||||
@@ -111,7 +112,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
self.id_to_fake_params: Dict[int, Parameter] = dict()
|
||||
|
||||
if self.clipping_flag:
|
||||
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
|
||||
assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now"
|
||||
|
||||
ddp_param_list = []
|
||||
for name, param in module.named_parameters():
|
||||
@@ -703,8 +704,19 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
|
||||
yield sharder.current_block, sharder.current_block_size
|
||||
|
||||
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
|
||||
raise NotImplementedError('Gemini does not support clip_grad_by_value')
|
||||
|
||||
class GeminiAdamOptimizer(ZeroOptimizer):
|
||||
def clip_grad_by_norm(self,
|
||||
max_norm: Union[float, int],
|
||||
norm_type: Union[float, int] = 2,
|
||||
error_if_nonfinite: bool = False,
|
||||
*args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
|
||||
|
||||
|
||||
class GeminiAdamOptimizer(GeminiOptimizer):
|
||||
|
||||
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
|
||||
optimizer = HybridAdam(model.parameters(), **defaults)
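For completeness, a sketch of how the renamed wrappers pair up, assuming `model` is a `GeminiDDP`-wrapped module (the `lr` value is an assumption; `GeminiAdamOptimizer` builds a `HybridAdam` internally as shown above):

```python
optimizer = GeminiAdamOptimizer(model, lr=1e-3)

# equivalent explicit form, using the renamed clipping argument:
# optimizer = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, max_norm=1.0)
```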
@@ -9,7 +9,7 @@ class MemStats(object):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Store the non model data statistics used for Gemini and ZeroOptimizer.
|
||||
Store the non model data statistics used for Gemini and GeminiOptimizer.
|
||||
"""
|
||||
# (preop_step, List[param])
|
||||
self._step_param_dict = dict()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import functools
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from time import time
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
@@ -7,6 +8,7 @@ import torch
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.chunk import Chunk
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .memory_tracer import ChunkMemStatsCollector
|
||||
@@ -17,7 +19,8 @@ class PlacementPolicy(ABC):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
**kwargs) -> None:
|
||||
self.chunk_manager = chunk_manager
|
||||
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
|
||||
|
||||
@@ -25,57 +28,87 @@ class PlacementPolicy(ABC):
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_default_device() -> torch.device:
|
||||
return torch.device('cpu')
|
||||
@abstractmethod
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CPUPlacementPolicy(PlacementPolicy):
|
||||
class StaticPlacementPolicy(PlacementPolicy):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
shard_param_frac: float = 1.0,
|
||||
offload_optim_frac: float = 0.0,
|
||||
offload_param_frac: float = 0.0,
|
||||
**kwargs) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
|
||||
warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
|
||||
offload_param_frac = 0.0
|
||||
self.shard_param_frac = shard_param_frac
|
||||
self.offload_optim_frac = offload_optim_frac
|
||||
self.offload_param_frac = offload_param_frac
|
||||
# these should be initialized in setup_grads_device
|
||||
self.keep_gathered_chunk_mem = 0.0
|
||||
self.keep_cuda_chunk_mem = 0.0
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
volume = 0
|
||||
start = time()
|
||||
can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
|
||||
can_offload_chunk_mem = can_shard_chunk_mem
|
||||
for chunk in can_evict_chunks:
|
||||
if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
|
||||
break
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
# real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
|
||||
can_shard_chunk_mem -= chunk.chunk_mem
|
||||
for chunk in can_evict_chunks:
|
||||
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
|
||||
break
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
volume += chunk.chunk_mem
|
||||
return volume, time() - start
|
||||
# real saved mem is shard_mem, for simplicity we use chunk_mem
|
||||
can_offload_chunk_mem -= chunk.chunk_mem
|
||||
return 0, 0.0
|
||||
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
|
||||
|
||||
class CUDAPlacementPolicy(PlacementPolicy):
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
return 0, 0
|
||||
|
||||
@staticmethod
|
||||
def get_default_device() -> torch.device:
|
||||
return get_current_device()
|
||||
offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
|
||||
offloaded_optim_chunk_mem = 0
|
||||
chunks = set(self.chunk_manager.get_chunk(p) for p in params)
|
||||
for chunk in chunks:
|
||||
params = chunk.get_tensors()
|
||||
# init offload optim settings
|
||||
# keep-gathered chunks stay in CUDA
|
||||
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
|
||||
device = get_current_device()
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
|
||||
offloaded_optim_chunk_mem += chunk.chunk_mem
|
||||
for p in params:
|
||||
grads_device_map[p] = device
|
||||
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
||||
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
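Illustrative arithmetic for the static-placement fractions above (the sizes are made up):

```python
total_chunk_mem = 4 * 1024**3      # say 4 GiB of parameter chunks registered for this group
shard_param_frac = 0.75            # shard 75% of parameter memory, keep 25% gathered
offload_param_frac = 0.0           # keep all parameter chunks on CUDA

keep_gathered_chunk_mem = total_chunk_mem * (1 - shard_param_frac)   # 1 GiB stays gathered
keep_cuda_chunk_mem = total_chunk_mem * (1 - offload_param_frac)     # 4 GiB stays on CUDA
```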
class AutoPlacementPolicy(PlacementPolicy):
|
||||
|
||||
need_mem_stats: bool = True
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
_warmup_non_model_data_ratio: float = 0.8
|
||||
_steady_cuda_cap_ratio: float = 0.9
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
warmup_non_model_data_ratio: float = 0.8,
|
||||
steady_cuda_cap_ratio: float = 0.9,
|
||||
**kwargs) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
|
||||
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
|
||||
|
||||
def evict_tensors(self,
|
||||
can_evict_chunks: List[Chunk],
|
||||
@@ -105,11 +138,11 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
|
||||
if warmup:
|
||||
# We designate a part of CUDA memory for model data in warmup iterations.
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
|
||||
else:
|
||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
|
||||
cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
|
||||
cuda_capacity *= self._steady_cuda_cap_ratio
|
||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
freed_cuda_model_data = 0
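A worked example of the per-instance warmup budget above (capacity and ratio are illustrative):

```python
cuda_capacity = 40 * 1024**3                 # reported CUDA capacity
warmup_non_model_data_ratio = 0.8            # now an instance attribute, not a class-level global

# warmup phase: reserve 80% of capacity for non-model data,
# leaving 20% (8 GiB here) as the model-data budget
max_cuda_non_model_data_per_period = cuda_capacity * warmup_non_model_data_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
```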
@@ -145,89 +178,22 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
@staticmethod
|
||||
def set_warmup_non_model_data_ratio(ratio: float) -> None:
|
||||
ratio = float(ratio)
|
||||
assert 0.0 < ratio < 1.0
|
||||
AutoPlacementPolicy._warmup_non_model_data_ratio = ratio
|
||||
|
||||
@staticmethod
|
||||
def set_steady_cuda_cap_ratio(ratio: float) -> None:
|
||||
ratio = float(ratio)
|
||||
assert 0.0 < ratio < 1.0
|
||||
AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
|
||||
|
||||
|
||||
class ConstPlacementPolicy(PlacementPolicy):
|
||||
|
||||
need_mem_stats: bool = False
|
||||
_accessed_memory_boundary = 512 * 1024**2
|
||||
|
||||
def __init__(self,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self,
|
||||
can_evict_chunks: List[Chunk],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
|
||||
compute_idx: int = 0,
|
||||
**kwargs) -> Tuple[int, float]:
|
||||
"""
|
||||
See the docstrings in the class `AutoPlacementPolicy`.
|
||||
"""
|
||||
start = time()
|
||||
used_accessed_memory = self.chunk_manager.accessed_mem
|
||||
avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
|
||||
freed_accessed_memory = 0
|
||||
|
||||
if avail_accessed_memory < cuda_demand:
|
||||
to_free_memory = cuda_demand - avail_accessed_memory
|
||||
to_free_chunks = can_evict_chunks
|
||||
|
||||
if not warmup:
|
||||
# sort all chunks
|
||||
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
|
||||
|
||||
for chunk in to_free_chunks:
|
||||
if freed_accessed_memory >= to_free_memory:
|
||||
break
|
||||
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
freed_accessed_memory += chunk.chunk_mem
|
||||
|
||||
if freed_accessed_memory < to_free_memory:
|
||||
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
|
||||
f"Need {to_free_memory}, freed {freed_accessed_memory}")
|
||||
return freed_accessed_memory, time() - start
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
|
||||
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
|
||||
for i in range(len(compute_list) - 1, compute_idx, -1):
|
||||
for chunk in compute_list[i]:
|
||||
if chunk in next_compute_idx:
|
||||
next_compute_idx[chunk] = i
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
@staticmethod
|
||||
def set_const_memory_boundary(cuda_memory_mb: int) -> None:
|
||||
boundary = int(cuda_memory_mb * 1024**2)
|
||||
assert boundary > 0
|
||||
ConstPlacementPolicy._accessed_memory_boundary = boundary
|
||||
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
|
||||
torch.device]) -> None:
|
||||
for p in params:
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
# init offload optim settings
|
||||
# keep-gathered chunks stay in CUDA
|
||||
if chunk.keep_gathered:
|
||||
grads_device_map[p] = get_current_device()
|
||||
else:
|
||||
grads_device_map[p] = torch.device('cpu')
|
||||
|
||||
|
||||
class PlacementPolicyFactory:
|
||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||
'cpu': CPUPlacementPolicy,
|
||||
'cuda': CUDAPlacementPolicy,
|
||||
'auto': AutoPlacementPolicy,
|
||||
'const': ConstPlacementPolicy
|
||||
'static': StaticPlacementPolicy,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -239,8 +205,3 @@ class PlacementPolicyFactory:
|
||||
@staticmethod
|
||||
def get_policy_names():
|
||||
return tuple(PlacementPolicyFactory.policies.keys())
|
||||
|
||||
@staticmethod
|
||||
def get_default_device(policy_name: str) -> torch.device:
|
||||
policy_cls = PlacementPolicyFactory.create(policy_name)
|
||||
return policy_cls.get_default_device()
|
||||
|
||||
@@ -64,13 +64,13 @@ def get_static_torch_model(zero_ddp_model,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
only_rank_0=True) -> torch.nn.Module:
|
||||
"""Get a static torch.nn.Module model from the given ZeroDDP module.
|
||||
You should notice that the original ZeroDDP model is not modified.
|
||||
"""Get a static torch.nn.Module model from the given GeminiDDP module.
|
||||
You should notice that the original GeminiDDP model is not modified.
|
||||
Thus, you can use the original model in further training.
|
||||
But you should not use the returned torch model to train, this can cause unexpected errors.
|
||||
|
||||
Args:
|
||||
zero_ddp_model (ZeroDDP): a zero ddp model
|
||||
zero_ddp_model (GeminiDDP): a zero ddp model
|
||||
device (torch.device): the device of the final torch model
|
||||
dtype (torch.dtype): the dtype of the final torch model
|
||||
only_rank_0 (bool): if True, only rank0 has the converted torch model
|
||||
@@ -78,8 +78,8 @@ def get_static_torch_model(zero_ddp_model,
|
||||
Returns:
|
||||
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
|
||||
"""
|
||||
from colossalai.zero.gemini.gemini_ddp import ZeroDDP
|
||||
assert isinstance(zero_ddp_model, ZeroDDP)
|
||||
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
||||
assert isinstance(zero_ddp_model, GeminiDDP)
|
||||
|
||||
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
|
||||
colo_model = zero_ddp_model.module
|
||||
|
||||
@@ -57,8 +57,8 @@ class GradientStore(BaseStore):
|
||||
self._grads_of_params[group_id][param_id].append(grad)
|
||||
|
||||
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
|
||||
"""For old gradient accumulation, not in use now.
|
||||
Add a gradient slice on an existing slice of the parameter's gradient
|
||||
"""Add a gradient slice on an existing slice of the parameter's gradient
|
||||
Used when no_sync is not activated.
|
||||
|
||||
Args:
|
||||
grad (Tensor): The split gradient to append to list
|
||||
|
||||
@@ -80,9 +80,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
forced_dtype: Optional[torch.dtype] = None):
|
||||
|
||||
# TODO:
|
||||
# 1. state_dict for checkpoint IO
|
||||
|
||||
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
|
||||
self._dtype = self.optim.param_groups[0]['params'][0].dtype
|
||||
self._logger = get_dist_logger()
|
||||
@@ -277,7 +274,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
sync_tensor(flat_grads_per_rank[rank], grad_list)
|
||||
for grad in grad_list:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
|
||||
param_id)) < self._world_size:
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
|
||||
|
||||
else:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
|
||||
@@ -291,7 +292,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
|
||||
for grad in grad_in_bucket_current_rank:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
|
||||
|
||||
self._bucket_store.reset()
|
||||
|
||||
@@ -303,7 +307,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
# or got a grad of param from another group
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
|
||||
group_id != self._bucket_store.current_group_id:
|
||||
group_id != self._bucket_store.current_group_id:
|
||||
self._run_reduction()
|
||||
|
||||
padding_size = self._param_store.get_param_padding_size(param)
|
||||
@@ -315,7 +319,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
assert not(self._partition_grads and not self.require_grad_sync), \
|
||||
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
|
||||
"ZeRO2(partition_grads) and no_sync are not compatible"
|
||||
|
||||
if self.mixed_precision_mixin is not None:
|
||||
loss = self.mixed_precision_mixin.pre_backward(loss)
|
||||
|
||||
@@ -537,9 +542,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
working_param = self._param_store.master_to_working_param[id(param)]
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
|
||||
dist.all_gather(gather_tensor, v, group=self.dp_pg)
|
||||
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
|
||||
gather_tensor = [
|
||||
torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
|
||||
]
|
||||
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
|
||||
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
|
||||
working_param).cpu()
|
||||
zero_state[param][k] = param_state
|
||||
|
||||
states_dict = self._pack_state(zero_state)
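A hedged reading of why the gather above moves tensors to CUDA and back: the NCCL backend only gathers CUDA tensors, while offloaded optimizer states may live on the host. The helper name and arguments below are hypothetical:

```python
import torch
import torch.distributed as dist

def gather_state_to_cpu(v, working_param, dp_pg, world_size):
    # v may live on CPU (offloaded optimizer state); NCCL collectives need CUDA tensors,
    # so gather on the GPU and immediately move the assembled full state back to host memory
    buf = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(world_size)]
    dist.all_gather(buf, v.cuda(), group=dp_pg)
    return torch.stack(buf).view(-1)[:working_param.numel()].reshape_as(working_param).cpu()
```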
@@ -562,10 +570,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
v_list = v.split(v.numel() // self._world_size)
|
||||
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
|
||||
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
|
||||
|
||||
self.optim.load_state_dict(zero_state_dict)
|
||||
zero_state_dict = dict()
|
||||
|
||||
def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
|
||||
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
||||
@@ -594,9 +601,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
for k, v in states.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
|
||||
dist.all_gather(state_tensor, v, group=self.dp_pg)
|
||||
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
|
||||
state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
|
||||
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
|
||||
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
|
||||
working_param).cpu()
|
||||
current_block_size += state_tensor.numel()
|
||||
current_block[k] = state_tensor
|
||||
|
||||
|
||||
@@ -1,5 +1,41 @@
|
||||
# Low Level ZeRO
>Low Level ZeRO == ZeRO-DP stage 1 and 2; we denote it as ZeRO below.

## Examples of ZeRO and gradient accumulation

The code below shows only a typical gradient-accumulation loop; many details, such as how the loss is computed, are omitted.

```python
# example of ZeRO1 with gradient accumulation
...
outputs = model(input)
loss = SomeLoss(outputs)
if (idx + 1) % ACCUMULATE_STEP != 0:
    with booster.no_sync(model, optimizer):
        # under this context, gradients are not synchronized during backward,
        # leaving each rank with its own partial gradients.
        # This saves communication time in the backward pass.
        booster.backward(loss, optimizer)
        continue
else:
    # sync all the accumulated gradients before stepping
    booster.backward(loss, optimizer)
    optimizer.step()
...
```

```python
# example of ZeRO2 with gradient accumulation

...
outputs = model(input)
loss = SomeLoss(outputs)
# ZeRO2 partitions the gradients, so they can NOT be accumulated without syncing
# (no_sync is not supported).
booster.backward(loss, optimizer)
if (idx + 1) % ACCUMULATE_STEP == 0:
    optimizer.step()
...
```

## Design

### Notation
|
||||
@@ -25,11 +61,11 @@ The data structure looks like this:
|
||||
```
|
||||
After that, the gradients would be flattened by rank, and the data structure looks like this:
|
||||
```
|
||||
# g-0 means flatten([g-00, g-10])
|
||||
# g-X0 means flatten([g-00, g-10])
|
||||
{
|
||||
0: [g-0],
|
||||
1: [g-1],
|
||||
2: [g-2]
|
||||
0: [g-X0],
|
||||
1: [g-X1],
|
||||
2: [g-X2]
|
||||
}
|
||||
```
|
||||
For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.
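A minimal sketch of the two reduction paths described above, assuming `flat_grads_per_rank` is the `{rank: flattened_grads}` mapping from the diagram and `dp_group` is the data-parallel group (both names are placeholders):

```python
import torch
import torch.distributed as dist

def reduce_flat_grads(flat_grads_per_rank, dp_group, partition_grads):
    my_rank = dist.get_rank(dp_group)
    if not partition_grads:
        # ZeRO1: every rank keeps full gradients, so all-reduce each flat buffer
        for flat in flat_grads_per_rank.values():
            dist.all_reduce(flat, group=dp_group)
        return flat_grads_per_rank[my_rank]
    # ZeRO2: each rank keeps only its own shard, so reduce-scatter into the local shard
    out = torch.empty_like(flat_grads_per_rank[my_rank])
    dist.reduce_scatter(out, list(flat_grads_per_rank.values()), group=dp_group)
    return out
```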
@@ -109,6 +109,6 @@ def zero_optim_wrapper(model: nn.Module,
|
||||
config_dict['clip_grad_norm'] = max_norm
|
||||
return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
|
||||
else:
|
||||
from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
|
||||
from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer
|
||||
config_dict['clipping_norm'] = max_norm
|
||||
return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose)
|
||||
return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)