[elixir] add elixir plugin and its unit test (#3865)

Haichen Huang 2023-05-31 12:10:44 +08:00 committed by GitHub
parent 206280408a
commit dbb9659099
10 changed files with 386 additions and 96 deletions

View File

@@ -1,9 +1,10 @@
from .elixir_plugin import ElixirPlugin
from .gemini_plugin import GeminiPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'ElixirPlugin']
import torch
from packaging import version

View File

@@ -0,0 +1,243 @@
import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import load_state_dict, save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.elixir import ElixirModule, ElixirOptimizer
from colossalai.elixir.cuda import set_memory_fraction
from colossalai.elixir.search import minimum_waste_search, optimal_search
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from .dp_plugin_base import DPPluginBase
__all__ = ['ElixirPlugin']
class ElixirCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
def load_unsharded_model(self, model: ElixirModule, checkpoint: str):
"""
Load available model states from checkpoint.
"""
if self.coordinator.is_master():
checkpoint = load_state_dict(checkpoint)
else:
checkpoint = None
model.load_state_dict(checkpoint, only_rank_0=True)
def save_unsharded_model(self, model: ElixirModule, checkpoint: str, use_safetensors: bool = False):
"""
Save model states to checkpoint but only on master process.
"""
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO: optimizer state dict is sharded
warnings.warn('ElixirPlugin does not support saving a full optimizer checkpoint yet; the optimizer state is saved on every process instead.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
warnings.warn(
'ElixirPlugin can only load an optimizer checkpoint saved by itself with the same number of processes.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().load_optimizer(optimizer, checkpoint)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save the LR scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
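A minimal sketch of how this checkpoint IO behaves, assuming `plugin`, `elixir_model` and `optimizer` come from a standard Booster setup (see the ElixirPlugin docstring below); the paths are illustrative only:

ckpt_io = plugin.get_checkpoint_io()    # an ElixirCheckpointIO instance
# model states: a single file, written only by the master process
ckpt_io.save_unsharded_model(elixir_model, 'model.pt')
# optimizer states stay sharded: every rank writes its own 'optim.pt.rank<i>' file
ckpt_io.save_unsharded_optimizer(optimizer, 'optim.pt', gather_dtensor=False)
# loading mirrors saving: only the master reads the model file (only_rank_0=True),
# and the optimizer checkpoint must be loaded with the same number of processes
ckpt_io.load_unsharded_model(elixir_model, 'model.pt')
ckpt_io.load_optimizer(optimizer, 'optim.pt')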
class ELXModel(ModelWrapper):
def __init__(self, module: nn.Module, search_func: Callable, search_config: Dict, module_config: Dict) -> None:
super().__init__(module)
sr = search_func(module, **search_config)
self.module = ElixirModule(module, sr, **module_config)
def unwrap(self):
# just return the ElixirModule instance
return self.module
class ELXOptimizer(OptimizerWrapper):
def __init__(self, module: ElixirModule, optimizer: Optimizer, optimizer_config: dict) -> None:
optimizer = ElixirOptimizer(module, optimizer, **optimizer_config, init_step=True)
super().__init__(optimizer)
def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn('Elixir handles gradient clipping by itself, so set max_norm through optimizer_config before training.')
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Elixir does not support clip_grad_by_value')
class ElixirPlugin(DPPluginBase):
"""
Plugin for Elixir.
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import ElixirPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = ElixirPlugin()
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
search_type (str): The search algorithm used for chunk initialization, either 'mini_waste' or 'optimal'.
dtype (torch.dtype): The data type used in computations, torch.float or torch.float16.
If torch.float16 is used, AMP is enabled automatically.
prefetch (bool): Whether to prefetch chunks for overlapping.
Users should provide example_input and example_step_fn if prefetch is True.
cpu_offload (bool): Whether to offload optimizer states (OS).
Only available when the search_type is 'mini_waste'.
pin_memory (bool): Whether to store OS in the pinned cpu memory.
Only available when cpu_offload is enabled.
reduce_always_fp32 (bool): Whether to reduce gradients in fp32.
outputs_always_fp32 (bool): Whether to cast outputs to fp32.
example_input (Dict): An example input for the model.
example_step_fn (Callable): A callable function that takes the model and the example input as input, and does a training step.
optimizer_type (str): The type of optimizer, 'Adam' or 'SGD'.
Only used when the search type is 'optimal'.
optimizer_config (Dict): The config of the optimizer.
This config is commonly used in AMP.
See the class `ElixirOptimizer` for more details.
cuda_memory_fraction (float): The fraction of GPU memory that Elixir is allowed to use.
verbose (bool): Whether to print verbose information during the chunk search.
"""
def __init__(self,
search_type: str = 'mini_waste',
dtype: torch.dtype = torch.float32,
prefetch: bool = False,
cpu_offload: bool = False,
pin_memory: bool = False,
reduce_always_fp32: bool = False,
outputs_always_fp32: bool = False,
example_input: Optional[Dict] = None,
example_step_fn: Optional[Callable] = None,
optimizer_type: str = 'Adam',
optimizer_config: Optional[Dict] = None,
cuda_memory_fraction: float = 1.0,
verbose: bool = False) -> None:
super().__init__()
assert search_type in {'mini_waste', 'optimal'}
assert dtype in {torch.float, torch.float16}
self.dtype = dtype
self.verbose = verbose
self.world_size = dist.get_world_size()
self.world_group = dist.group.WORLD
set_memory_fraction(fraction=cuda_memory_fraction)
if search_type == 'mini_waste':
self.search_func = minimum_waste_search
self.search_config = dict(group_size=self.world_size,
unified_dtype=self.dtype,
prefetch=prefetch,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
inp=example_input,
step_fn=example_step_fn,
verbose=self.verbose)
elif search_type == 'optimal':
self.search_func = optimal_search
self.search_config = dict(group_size=self.world_size,
unified_dtype=self.dtype,
optimizer_type=optimizer_type,
overlap=prefetch,
inp=example_input,
step_fn=example_step_fn,
verbose=self.verbose)
else:
raise NotImplementedError
self.module_config = dict(process_group=self.world_group,
prefetch=prefetch,
dtype=self.dtype,
reduce_always_fp32=reduce_always_fp32,
output_fp32=outputs_always_fp32)
if optimizer_config is None:
optimizer_config = dict()
self.optimizer_config = optimizer_config
def support_no_sync(self) -> bool:
return False
def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
return ['fp16', 'fp32']
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
if not isinstance(model, ModelWrapper):
model = ELXModel(module=model,
search_func=self.search_func,
search_config=self.search_config,
module_config=self.module_config)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = ELXOptimizer(module=model.unwrap(), optimizer=optimizer, optimizer_config=self.optimizer_config)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True
def get_checkpoint_io(self) -> CheckpointIO:
return ElixirCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
raise NotImplementedError
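Putting this file together, here is a hedged end-to-end sketch of the plugin; the toy model, data and hyper-parameters are placeholders, the optimizer_config values mirror the unit test added later in this commit, and the distributed environment is assumed to be initialized already (e.g. via colossalai.launch):

import torch
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import ElixirPlugin
from colossalai.nn.optimizer import HybridAdam

# assumes colossalai.launch(...) has already set up the process group
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1))    # placeholder model
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda out: out.mean()

# max_norm must be set here, because ELXOptimizer.clip_grad_by_norm is a no-op;
# passing dtype=torch.float16 instead would enable AMP automatically
plugin = ElixirPlugin(optimizer_config=dict(initial_scale=64, max_norm=1.0))
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

data = torch.randn(8, 32, device='cuda')
loss = criterion(model(data))
booster.backward(loss, optimizer)
optimizer.step()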

View File

@@ -211,6 +211,9 @@ class Chunk:
def reduce_check(self):
return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors
def enable_l2_norm_flag(self) -> None:
self.l2_norm_flag = True
def set_overflow_flag(self, valid_tensor: torch.Tensor) -> None:
assert not self.overflow
self.overflow = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
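The same overflow predicate, pulled out of the Chunk class as a standalone sketch for clarity:

import torch

def has_overflow(valid_tensor: torch.Tensor) -> bool:
    # mirrors Chunk.set_overflow_flag: any inf or nan marks the chunk as overflowed
    return bool(torch.isinf(valid_tensor).any().item() or torch.isnan(valid_tensor).any().item())

print(has_overflow(torch.tensor([1.0, 2.0])))             # False
print(has_overflow(torch.tensor([1.0, float('nan')])))    # True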

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Set
from colossalai.elixir.chunk.core import Chunk
@@ -12,7 +12,7 @@ class ChunkScheduler(ABC):
def __init__(self) -> None:
super().__init__()
self.releasable_set: Optional[set] = None
self.releasable_set: Optional[Set[Chunk]] = None
self.current_step = -1
@abstractmethod

View File

@@ -12,6 +12,17 @@ from .functions import postfwd_prebwd_function, prefwd_postbwd_function
from .storage import BufferStore
def always_skip(func, args, kwargs) -> bool:
if is_no_hook_op(func):
return True
if func is torch.Tensor.reshape_as:
# reshape_as(self, other) only reads the shape of `other`, so the hook is needed
# only when `self` is a HookParam whose chunk must be fetched first
if isinstance(args[0], HookParam):
return False
else:
return True
return False
class HookParam(OutplaceTensor, nn.Parameter):
"""HookParam is a special type of tensor that is used to triggered hooks on parameters.
HookParam adds chunk fetching before torch functions.
@@ -43,7 +54,7 @@ class HookParam(OutplaceTensor, nn.Parameter):
if kwargs is None:
kwargs = {}
if is_no_hook_op(func):
if always_skip(func, args, kwargs):
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
return ret

View File

@@ -18,6 +18,11 @@ from colossalai.elixir.tensor import OutplaceTensor
from colossalai.utils.model.experimental import LazyTensor
def is_leaf_module(m: nn.Module):
special_modules = [nn.MultiheadAttention]
return type(m) in special_modules
def get_param_optim_data(param_data: torch.Tensor, param_dtype: torch.dtype):
param_data = param_data.to(gpu_device())
optim_data = param_data.clone() if param_data.dtype == torch.float else param_data.float()
@@ -71,6 +76,7 @@ class ElixirModule(nn.Module):
assert name in self.no_grad_state_dict
continue
assert name in self.grad_state_dict
# param.debug_name = name
param.register_hook(partial(self._gradient_handler, param=param))
param.__class__ = HookParam
@@ -165,8 +171,9 @@ class ElixirModule(nn.Module):
buffer_size = 0
for submodule in self.modules():
sum_param_size = 0
for param in submodule.parameters(recurse=False):
if not param.requires_grad or self.fetcher.is_in_fused(param):
recurse_flag = is_leaf_module(submodule)
for param in submodule.parameters(recurse=recurse_flag):
if not param.requires_grad:
continue
assert param.dtype == self.dtype
sum_param_size += param.numel()
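A quick self-contained illustration of why nn.MultiheadAttention is special-cased as a leaf module above: its output projection lives on a child module, so iterating its parameters with recurse=False would miss part of them when sizing the buffer (exact parameter names may vary across PyTorch versions):

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=2)
direct = [name for name, _ in mha.named_parameters(recurse=False)]
full = [name for name, _ in mha.named_parameters(recurse=True)]
print(direct)    # e.g. ['in_proj_weight', 'in_proj_bias']
print(full)      # additionally includes 'out_proj.weight' and 'out_proj.bias'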

View File

@@ -91,6 +91,10 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
# allocate memory before training
self.__zero_step()
if self.clipping_flag:
for param_chunk in self.param_chunk_set:
param_chunk.enable_l2_norm_flag()
def __zero_step(self):
torch.cuda.empty_cache()
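The per-chunk L2-norm flags enabled above feed gradient clipping; the arithmetic such chunk-wise norms support is the standard clip-by-norm rule, sketched here on toy buffers (this is the general formula, not Elixir's internal implementation):

import math
import torch

# toy stand-ins for per-chunk gradient buffers
grads_per_chunk = [torch.randn(1024), torch.randn(2048)]
max_norm = 1.0

chunk_sq_norms = [float(torch.linalg.vector_norm(g)) ** 2 for g in grads_per_chunk]
global_norm = math.sqrt(sum(chunk_sq_norms))             # ||g|| across all chunks
clip_coef = min(1.0, max_norm / (global_norm + 1e-6))    # scale applied to every gradient
print(global_norm, clip_coef)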

View File

@@ -19,13 +19,18 @@ def data_gen_fn():
output_transform_fn = lambda x: x
def output_bert(x):
return dict(pooler_output=x.get('pooler_output'))
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
# register the BERT variants
model_zoo.register(name='transformers_bert',
model_fn=lambda: transformers.BertModel(config),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
output_transform_fn=output_bert,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_pretraining',
model_fn=lambda: transformers.BertForPreTraining(config),

View File

@@ -0,0 +1,105 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import ElixirPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def run_fn(model_fn, data_gen_fn, output_transform_fn):
os_config = dict(initial_scale=64, max_norm=1.0)
plugin = ElixirPlugin(optimizer_config=os_config)
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.step()
def check_elixir_plugin(early_stop: bool = True):
"""check elixir plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
passed_info = {}
failed_info = {}
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
# have not been tested with torchrec
if name.startswith('torchrec'):
continue
# dm_nfnet is not supported because of the skipinit_gain parameter in its NormFreeBlock
# there is `out.mul_(self.skipinit_gain)`, which should be changed to `out *= self.skipinit_gain`
if name in ['timm_dm_nfnet']:
continue
# Elixir requires that every parameter which requires gradients actually receives a gradient after the backward pass,
# so the following models are unsupported: they use layer drop, where randomly selected layers
# are skipped during computation and their parameters never receive gradients
if name in ['torchaudio_wav2vec2_base', 'torchaudio_hubert_base']:
continue
# our criterion function is too simple to produce gradients for all parameters,
# so the following models are not supported;
# users should provide complete input data so that every parameter takes part in the computation
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'transformers_albert',
'transformers_albert_for_pretraining', 'transformers_bert_for_pretraining',
'transformers_gpt_double_heads', 'transformers_t5', 'transformers_t5_for_conditional_generation',
'transformers_t5_encoder_model'):
continue
# nn.RNN is not supported yet
if name in ('torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron'):
continue
try:
run_fn(model_fn, data_gen_fn, output_transform_fn)
passed_info[name] = 'passed'
except Exception as e:
failed_info[name] = str(e)
print(f"failed model name: {name}")
if early_stop:
raise e
torch.cuda.empty_cache()
if dist.get_rank() == 0:
print(f'Passed models({len(passed_info)}): {list(passed_info.keys())}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_elixir_plugin(early_stop=early_stop)
@pytest.mark.skip(reason="skip this test now")
@rerun_if_address_is_in_use()
def test_elixir_plugin(early_stop: bool = True):
spawn(run_dist, 1, early_stop=early_stop)
if __name__ == '__main__':
test_elixir_plugin(early_stop=True)

View File

@@ -1,89 +0,0 @@
import torch
import torch.distributed as dist
import colossalai
from colossalai.elixir import ElixirModule, ElixirOptimizer
from colossalai.elixir.search import minimum_waste_search
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def check_elixir_compatibility(early_stop: bool = True):
"""check gemini plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
passed_models = []
failed_info = {} # (model_name, error) pair
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
# These models lead to CUDA error
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
'torchaudio_wav2vec2_base', 'torchaudio_hubert_base', 'torchvision_convnext_base'):
continue
try:
print(name)
global_size = dist.get_world_size()
global_group = dist.GroupMember.WORLD
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
sr = minimum_waste_search(
# pre-commit: do not rearrange
m=model,
group_size=global_size,
unified_dtype=torch.float16,
prefetch=False,
verbose=True)
model = ElixirModule(model, sr, global_group, prefetch=False, dtype=torch.float16)
optimizer = ElixirOptimizer(model, optimizer, initial_scale=32)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
optimizer.backward(loss)
optimizer.step()
passed_models.append(name)
del model, optimizer, criterion, data, output, loss
except Exception as e:
failed_info[name] = e
if early_stop:
raise e
torch.cuda.empty_cache()
if dist.get_rank() == 0:
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_elixir_compatibility(early_stop=early_stop)
@rerun_if_address_is_in_use()
def exam_compatibility(early_stop: bool = True):
spawn(run_dist, 2, early_stop=early_stop)
if __name__ == '__main__':
exam_compatibility(early_stop=False)