Merge branch 'main' into feature/shardformer

This commit is contained in:
Hongxin Liu
2023-09-04 23:43:13 +08:00
committed by GitHub
138 changed files with 4664 additions and 4219 deletions

View File

@@ -2,19 +2,21 @@ import itertools
from collections import OrderedDict
from contextlib import nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
from colossalai.interface import ModelWrapper
from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
@@ -30,14 +32,13 @@ except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
__all__ = [
'ZeroDDP',
'GeminiDDP',
]
class ZeroDDP(ColoDDP):
"""ZeRO DDP for ColoTensor.
Warning: Nested ZeroDDP is not supported now.
class GeminiDDP(ModelWrapper):
"""ZeRO DDP.
Warning: Nested GeminiDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
@@ -54,20 +55,54 @@ class ZeroDDP(ColoDDP):
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16) -> None:
def __init__(
self,
module: torch.nn.Module,
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device('cpu'),
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
search_range_m: int = 32, # chunk search options
hidden_dim: Optional[int] = None, # chunk search options
min_chunk_size_m: float = 32, # chunk search options
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
verbose: bool = False) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
else:
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
search_range_m = 32
self.chunk_manager = init_chunk_manager(model=module,
init_device=chunk_init_device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
verbose=verbose)
self.gemini_manager = GeminiManager(placement_policy,
self.chunk_manager,
memstats,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = list()
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -75,6 +110,7 @@ class ZeroDDP(ColoDDP):
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
self.dp_process_group = process_group or _get_default_group()
self._logger = get_dist_logger()
@@ -88,20 +124,67 @@ class ZeroDDP(ColoDDP):
for p in module.parameters():
param_order.append(p)
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
super().__init__(module, process_group=ColoProcessGroup())
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
super().__init__(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._cast_buffers()
# register grad hook
for p in module.parameters():
if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
def parameters(self, recurse: bool = True):
return self.module.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True):
return self.module.named_parameters(prefix, recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True):
return self.module.named_buffers(prefix, recurse)
def named_children(self):
return self.module.named_children()
def named_modules(self,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
return self.module.named_modules(memo, prefix, remove_duplicate)
@staticmethod
def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
"""Sets parameters to be ignored by DDP.
This method must be called before initializing ColoDDP.
Example:
>>> params_to_ignore = []
>>> for p in module.parameters():
>>> if should_ignore(p):
>>> params_to_ignore.append(p)
>>> ColoDDP.set_params_to_ignore(params_to_ignore)
>>> module = ColoDDP(module)
Args:
params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
"""
for p in params_to_ignore:
p._ddp_to_ignore = True
def unwrap(self):
# as save/load state dict is overwrited, only return self
return self
def _get_non_persistent_buffers_set(self,
module,
@@ -207,7 +290,7 @@ class ZeroDDP(ColoDDP):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
@@ -227,6 +310,7 @@ class ZeroDDP(ColoDDP):
self._post_backward()
def grad_handle(self, p, grad):
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
@@ -533,7 +617,7 @@ class ZeroDDP(ColoDDP):
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.optim_update()
chunk_16.payload.copy_(chunk_32.payload)
for name, buf in persistent_buffers.items():
if buf is not None:
@@ -557,17 +641,11 @@ class ZeroDDP(ColoDDP):
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
dp_world_size = dist.get_world_size(self.dp_process_group)
for p in param_order.generate():
self._preprocess_param(p)
assert type(p) is ColoParameter
# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
# ignore the parameters with no gradient
if not p.requires_grad:
self.set_params_to_ignore([p])
@@ -578,38 +656,37 @@ class ZeroDDP(ColoDDP):
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
fp32_p = p.data.float()
# create a fp16 parameter
p.data = p.data.to(self.mixed_precision)
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
# move master weights to corresponding device and setup paired chunks
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
if chunk_32.device_type != self.grads_device[p].type:
self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])
def _cast_buffers(self):
for buffer in self.module.buffers():
@@ -705,65 +782,3 @@ class ZeroDDP(ColoDDP):
yield sharder.current_block, sharder.current_block_size
class GeminiDDP(ZeroDDP):
def __init__(self,
module: torch.nn.Module,
device: torch.device,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
search_range_m: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
mixed_precision: torch.dtype = torch.float16,
verbose: bool = False) -> None:
"""
A torch.Module wrapper using ZeRO-DP and Gemini.
ZeRO is for parallel. Gemini is for memory management.
WARNING: The class will modify the module inline!
Example:
model is initialized under the context of ColoInitContext
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): the model to be wrapped.
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
"""
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
search_range_m = 32
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
verbose=verbose)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module,
gemini_manager,
pin_memory,
force_outputs_fp32,
strict_ddp_mode,
scatter_after_inference,
mixed_precision=mixed_precision)