Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)

Merge branch 'main' into feature/shardformer
@@ -2,19 +2,21 @@ import itertools
from collections import OrderedDict
from contextlib import nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
from colossalai.interface import ModelWrapper

from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored

@@ -30,14 +32,13 @@ except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'

__all__ = [
    'ZeroDDP',
    'GeminiDDP',
]


class ZeroDDP(ColoDDP):
    """ZeRO DDP for ColoTensor.
    Warning: Nested ZeroDDP is not supported now.
class GeminiDDP(ModelWrapper):
    """ZeRO DDP.
    Warning: Nested GeminiDDP is not supported now.
    It is designed to be used with ChunkManager and GeminiManager.
    For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.

@@ -54,20 +55,54 @@ class ZeroDDP(ColoDDP):
        mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 gemini_manager: GeminiManager,
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 strict_ddp_mode: bool = False,
                 scatter_after_inference: bool = True,
                 mixed_precision: torch.dtype = torch.float16) -> None:
    def __init__(
            self,
            module: torch.nn.Module,
            chunk_config_dict: Optional[dict] = None,
            chunk_init_device: torch.device = torch.device('cpu'),
            placement_policy: str = "static",
            shard_param_frac: float = 1.0,    # only for static placement
            offload_optim_frac: float = 0.0,    # only for static placement
            offload_param_frac: float = 0.0,    # only for static placement
            warmup_non_model_data_ratio: float = 0.8,    # only for auto placement
            steady_cuda_cap_ratio: float = 0.9,    # only for auto placement
            search_range_m: int = 32,    # chunk search options
            hidden_dim: Optional[int] = None,    # chunk search options
            min_chunk_size_m: float = 32,    # chunk search options
            pin_memory: bool = False,
            force_outputs_fp32: bool = False,
            strict_ddp_mode: bool = False,
            scatter_after_inference: bool = True,
            mixed_precision: torch.dtype = torch.float16,
            process_group: Optional[ProcessGroup] = None,
            memstats: Optional[MemStats] = None,    # gemini memory stats
            verbose: bool = False) -> None:
        assert mixed_precision in (torch.float16, torch.bfloat16)
        self.gemini_manager = gemini_manager
        self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
        if chunk_config_dict is not None:
            self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
        else:
            # some ugly hotfix for the compatibility with Lightning
            if search_range_m is None:
                search_range_m = 32
            self.chunk_manager = init_chunk_manager(model=module,
                                                    init_device=chunk_init_device,
                                                    hidden_dim=hidden_dim,
                                                    search_range_m=search_range_m,
                                                    min_chunk_size_m=min_chunk_size_m,
                                                    strict_ddp_flag=strict_ddp_mode,
                                                    process_group=process_group,
                                                    verbose=verbose)
        self.gemini_manager = GeminiManager(placement_policy,
                                            self.chunk_manager,
                                            memstats,
                                            shard_param_frac=shard_param_frac,
                                            offload_optim_frac=offload_optim_frac,
                                            offload_param_frac=offload_param_frac,
                                            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
                                            steady_cuda_cap_ratio=steady_cuda_cap_ratio)
        self.force_outputs_fp32 = force_outputs_fp32
        self.param_op_hook = GeminiZeROHook(gemini_manager)
        self.fp32_params: List[ColoTensor] = list()
        self.param_op_hook = GeminiZeROHook(self.gemini_manager)
        self.fp32_params: List[torch.Tensor] = list()
        self.fp16_params: List[ColoParameter] = list()
        self.overflow_counter = 0
        self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -75,6 +110,7 @@ class ZeroDDP(ColoDDP):
        self.name2param: Dict[str, nn.Parameter] = dict()
        self.scatter_after_inference = scatter_after_inference
        self.mixed_precision = mixed_precision
        self.dp_process_group = process_group or _get_default_group()

        self._logger = get_dist_logger()

@@ -88,20 +124,67 @@ class ZeroDDP(ColoDDP):
        for p in module.parameters():
            param_order.append(p)

        self._init_chunks(param_order=param_order,
                          strict_ddp_mode=strict_ddp_mode,
                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
                          pin_memory=pin_memory)

        for name, param in module.named_parameters():
            self.param2name[param] = name
        for m_name, m_var in module.named_modules():
            for p_name, p_var in m_var.named_parameters(recurse=False):
                param_name = m_name + '.' + p_name if m_name else p_name
                self.name2param[param_name] = p_var
        super().__init__(module, process_group=ColoProcessGroup())

        self._init_chunks(param_order=param_order,
                          strict_ddp_mode=strict_ddp_mode,
                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
                          pin_memory=pin_memory)
        super().__init__(module)
        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
        self._cast_buffers()
        # register grad hook
        for p in module.parameters():
            if is_ddp_ignored(p):
                continue
            if p.requires_grad:
                p.register_hook(partial(self.grad_handle, p))

    def parameters(self, recurse: bool = True):
        return self.module.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        return self.module.named_parameters(prefix, recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True):
        return self.module.named_buffers(prefix, recurse)

    def named_children(self):
        return self.module.named_children()

    def named_modules(self,
                      memo: Optional[Set[torch.nn.Module]] = None,
                      prefix: str = '',
                      remove_duplicate: bool = True):
        return self.module.named_modules(memo, prefix, remove_duplicate)

    @staticmethod
    def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
        """Sets parameters to be ignored by DDP.
        This method must be called before initializing ColoDDP.

        Example:
            >>> params_to_ignore = []
            >>> for p in module.parameters():
            >>>     if should_ignore(p):
            >>>         params_to_ignore.append(p)
            >>> ColoDDP.set_params_to_ignore(params_to_ignore)
            >>> module = ColoDDP(module)

        Args:
            params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
        """
        for p in params_to_ignore:
            p._ddp_to_ignore = True

    def unwrap(self):
        # as save/load state dict is overwritten, only return self
        return self

    def _get_non_persistent_buffers_set(self,
                                        module,
@@ -207,7 +290,7 @@ class ZeroDDP(ColoDDP):
                error_params.append(self.param2name[param])
            error_str = "\n\t".join(error_params)
            raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
                               "The most possible reason is that the model is not compatible with ZeroDDP.\n",
                               "The most possible reason is that the model is not compatible with GeminiDDP.\n",
                               f"{error_str}")
        self._setup_grads_ptr()
        self._logger.debug(
@@ -227,6 +310,7 @@ class ZeroDDP(ColoDDP):
        self._post_backward()

    def grad_handle(self, p, grad):
        setattr(p, "_gemini_reduced", True)
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        with torch._C.DisableTorchFunction():
@@ -533,7 +617,7 @@ class ZeroDDP(ColoDDP):
        for chunk_32 in chunk_list:
            chunk_16 = chunk_32.paired_chunk
            assert chunk_16 is not None
            chunk_16.optim_update()
            chunk_16.payload.copy_(chunk_32.payload)

        for name, buf in persistent_buffers.items():
            if buf is not None:
@@ -557,17 +641,11 @@ class ZeroDDP(ColoDDP):
                unexpected_keys.append(key)

    def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
        ddp_pg = ColoProcessGroup()
        dp_world_size = dist.get_world_size(self.dp_process_group)
        for p in param_order.generate():
            self._preprocess_param(p)
            assert type(p) is ColoParameter

            # gather sharded parameters in the strict ddp mode
            if strict_ddp_mode:
                if not p.is_replicate():
                    p.set_dist_spec(ReplicaSpec())
                p.set_process_group(pg=ddp_pg)

            # ignore the parameters with no gradient
            if not p.requires_grad:
                self.set_params_to_ignore([p])
@@ -578,38 +656,37 @@ class ZeroDDP(ColoDDP):
                continue

            # create a fp32 parameter
            fp32_data = p.data.float()
            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
            fp32_p = p.data.float()
            # create a fp16 parameter
            p.data = p.data.to(self.mixed_precision)

            # register the fp16 parameter and fp32 parameter in the chunk manager
            dp_world_size = p.process_group.dp_world_size()
            self.chunk_manager.register_tensor(tensor=p,
                                               group_type='fp16_param',
                                               config_key=dp_world_size,
                                               process_group=self.dp_process_group,
                                               cpu_offload=cpu_offload,
                                               pin_memory=pin_memory)
            self.chunk_manager.register_tensor(tensor=fp32_p,
                                               group_type='fp32_param',
                                               config_key=dp_world_size,
                                               process_group=self.dp_process_group,
                                               cpu_offload=cpu_offload,
                                               pin_memory=pin_memory)

            self.fp16_params.append(p)
            self.fp32_params.append(fp32_p)
            self.grads_device[p] = self.gemini_manager.default_device

        self.chunk_manager.close_all_groups()

        self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
        # move master weights to corresponding device and setup paired chunks
        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
            chunk_16 = self.chunk_manager.get_chunk(p)
            chunk_32 = self.chunk_manager.get_chunk(fp32_p)
            chunk_32.init_pair(chunk_16)

            # keep gathered chunks in CUDA
            if chunk_16.keep_gathered:
                self.grads_device[p] = get_current_device()
            if chunk_32.device_type != self.grads_device[p].type:
                self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])

    def _cast_buffers(self):
        for buffer in self.module.buffers():
@@ -705,65 +782,3 @@ class ZeroDDP(ColoDDP):
            yield sharder.current_block, sharder.current_block_size


class GeminiDDP(ZeroDDP):

    def __init__(self,
                 module: torch.nn.Module,
                 device: torch.device,
                 placement_policy: str = "cpu",
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 strict_ddp_mode: bool = False,
                 scatter_after_inference: bool = True,
                 search_range_m: int = 32,
                 hidden_dim: Optional[int] = None,
                 min_chunk_size_m: float = 32,
                 memstats: Optional[MemStats] = None,
                 mixed_precision: torch.dtype = torch.float16,
                 verbose: bool = False) -> None:
"""
|
||||
A torch.Module wrapper using ZeRO-DP and Gemini.
|
||||
ZeRO is for parallel. Gemini is for memory management.
|
||||
WARNING: The class will modify the module inline!
|
||||
|
||||
Example:
|
||||
model is initialized under the context of ColoInitContext
|
||||
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
|
||||
>>> logits = model(x)
|
||||
>>> loss = criterion(logits, labels)
|
||||
>>> model.backward(loss)
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): the model to be wrapped.
|
||||
device (torch.device): device to place the model.
|
||||
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
|
||||
search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
|
||||
hidden_dim (int, optional): the hidden dimension of DNN.
|
||||
Users can provide this argument to speed up searching.
|
||||
If users do not know this argument before training, it is ok. We will use a default value 1024.
|
||||
min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
|
||||
If the aggregate size of parameters is still smaller than the minimum chunk size,
|
||||
all parameters will be compacted into one small chunk.
|
||||
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
|
||||
"""
|
||||
        # some ugly hotfix for the compatibility with Lightning
        if search_range_m is None:
            search_range_m = 32

        chunk_manager = init_chunk_manager(model=module,
                                           init_device=device,
                                           hidden_dim=hidden_dim,
                                           search_range_m=search_range_m,
                                           min_chunk_size_m=min_chunk_size_m,
                                           strict_ddp_flag=strict_ddp_mode,
                                           verbose=verbose)
        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
        super().__init__(module,
                         gemini_manager,
                         pin_memory,
                         force_outputs_fp32,
                         strict_ddp_mode,
                         scatter_after_inference,
                         mixed_precision=mixed_precision)
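
For readers skimming the diff, below is a minimal usage sketch of the refactored constructor: GeminiDDP now builds its own ChunkManager and GeminiManager from keyword arguments instead of taking a pre-built gemini_manager. The import path, the launch call, and the toy model are illustrative assumptions for this version, not taken from this commit.

import torch
import torch.nn as nn

import colossalai
from colossalai.zero import GeminiDDP    # assumed import path for this release line

# assumes launch via torchrun; Gemini needs an initialized default process group
colossalai.launch_from_torch(config={})

# toy model for illustration only
model = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 8))

model = GeminiDDP(
    model,
    chunk_init_device=torch.device('cuda'),
    placement_policy='static',
    shard_param_frac=1.0,      # fully shard parameters (static placement only)
    offload_optim_frac=0.0,    # keep optimizer states on CUDA
    pin_memory=True,
    mixed_precision=torch.float16,
)

logits = model(torch.randn(4, 32, device='cuda'))
loss = logits.float().sum()
model.backward(loss)    # as in the old ZeroDDP API, backward goes through the wrapper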