diff --git a/colossalai/booster/__init__.py b/colossalai/booster/__init__.py
index d475676ba..3b3f45bb0 100644
--- a/colossalai/booster/__init__.py
+++ b/colossalai/booster/__init__.py
@@ -2,4 +2,3 @@ from .accelerator import Accelerator
 from .booster import Booster
 from .environment_table import EnvironmentTable
 from .plugin import Plugin
-from .precision import Precision
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 4aae200a0..7b351ae34 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -1,37 +1,95 @@
 from contextlib import contextmanager
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
 
 __all__ = ['Booster']
 
 
 class Booster:
+    """
+    Booster is a high-level API for training neural networks. It provides a unified interface for
+    training with different precision, accelerator, and plugin settings.
+
+    Examples:
+        >>> colossalai.launch(...)
+        >>> plugin = GeminiPlugin(stage=3, ...)
+        >>> booster = Booster(mixed_precision='fp16', plugin=plugin)
+        >>>
+        >>> model = GPT2()
+        >>> optimizer = Adam(model.parameters())
+        >>> dataloader = Dataloader(Dataset)
+        >>> lr_scheduler = LinearWarmupScheduler()
+        >>> criterion = GPTLMLoss()
+        >>>
+        >>> model, optimizer, criterion, lr_scheduler, dataloader = booster.boost(model, optimizer, criterion, lr_scheduler, dataloader)
+        >>>
+        >>> for epoch in range(max_epochs):
+        >>>     for input_ids, attention_mask in dataloader:
+        >>>         outputs = model(input_ids, attention_mask)
+        >>>         loss = criterion(outputs.logits, input_ids)
+        >>>         booster.backward(loss, optimizer)
+        >>>         optimizer.step()
+        >>>         lr_scheduler.step()
+        >>>         optimizer.zero_grad()
+
+    Args:
+        device (str or torch.device): The device to run the training on. Default: 'cuda'.
+        mixed_precision (str or MixedPrecision): The mixed precision mode to train with. Default: None.
+            If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
+            'fp16' uses PyTorch AMP while 'fp16_apex' uses Nvidia Apex.
+        plugin (Plugin): The plugin to run the training with. Default: None.
+    """
 
     def __init__(self,
                  device: Union[str, torch.device] = 'cuda',
-                 precision: str = 'fp32',
-                 grad_clipping_type: str = 'norm',
-                 grad_clipping_value: float = 0.0,
+                 mixed_precision: Union[MixedPrecision, str] = None,
                  plugin: Optional[Plugin] = None) -> None:
-        # TODO: implement this method
-        pass
+        # validate and set precision
+        if mixed_precision is None:
+            # no mixed precision requested, keep the model in full precision
+            self.mixed_precision = None
+        elif isinstance(mixed_precision, str):
+            # the user will take the default arguments for amp training
+            self.mixed_precision = mixed_precision_factory(mixed_precision)
+        elif isinstance(mixed_precision, MixedPrecision):
+            # the user can customize the arguments by passing the precision object
+            self.mixed_precision = mixed_precision
+        else:
+            raise ValueError(
+                f'Expected the argument mixed_precision to be a string or an instance of MixedPrecision, but got {type(mixed_precision)}.'
+            )
 
-    def boost(
-        self, *args: Union[nn.Module, Optimizer, LRScheduler, DataLoader]
-    ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
-        # TODO: implement this method
-        pass
+    def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_scheduler: LRScheduler,
+              dataloader: DataLoader) -> List[Union[nn.Module, Optimizer, Callable, LRScheduler, DataLoader]]:
+        """
+        Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
+
+        Args:
+            model (nn.Module): The model to be boosted.
+            optimizer (Optimizer): The optimizer to be boosted.
+            criterion (Callable): The criterion to be boosted.
+            lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
+            dataloader (DataLoader): The dataloader to be boosted.
+        """
+        # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
+        # TODO(lsg): Add plugin control logic
+        # e.g.
+        # if self.plugin is not None and self.plugin.control_boost:
+        #     ...
+        # transform model for mixed precision
+        if self.mixed_precision is not None:
+            model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
+        return model, optimizer, criterion, lr_scheduler, dataloader
 
     def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
-        # TODO: implement this method
-        pass
+        # TODO: implement this method with plugin
+        optimizer.backward(loss)
 
     def execute_pipeline(self,
                          data_iter: Iterator,
diff --git a/colossalai/booster/interface/__init__.py b/colossalai/booster/interface/__init__.py
new file mode 100644
index 000000000..8892a13e1
--- /dev/null
+++ b/colossalai/booster/interface/__init__.py
@@ -0,0 +1,3 @@
+from .optimizer import OptimizerWrapper
+
+__all__ = ['OptimizerWrapper']
diff --git a/colossalai/booster/interface/optimizer.py b/colossalai/booster/interface/optimizer.py
new file mode 100644
index 000000000..dd9acab17
--- /dev/null
+++ b/colossalai/booster/interface/optimizer.py
@@ -0,0 +1,121 @@
+from typing import Union
+
+import torch.nn as nn
+from torch import Tensor
+from torch.optim import Optimizer
+
+
+class OptimizerWrapper:
+    """
+    A standard interface for optimizers wrapped by the Booster.
+
+    Args:
+        optim (Optimizer): The optimizer to be wrapped.
+    """
+
+    def __init__(self, optim: Optimizer):
+        self.optim = optim
+
+    @property
+    def parameters(self):
+        params = []
+
+        for group in self.param_groups:
+            params += group['params']
+        return params
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    def add_param_group(self, *args, **kwargs):
+        return self.optim.add_param_group(*args, **kwargs)
+
+    def step(self, *args, **kwargs):
+        """
+        Performs a single optimization step.
+        """
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        """
+        Clears the gradients of all optimized `torch.Tensor`.
+        """
+        self.optim.zero_grad(*args, **kwargs)
+
+    def backward(self, loss: Tensor, *args, **kwargs):
+        """
+        Performs a backward pass on the loss.
+        """
+        loss.backward(*args, **kwargs)
+
+    def state_dict(self):
+        """
+        Returns the optimizer state.
+        """
+        return self.optim.state_dict()
+
+    def load_state_dict(self, *args, **kwargs):
+        """
+        Loads the optimizer state.
+        """
+        self.optim.load_state_dict(*args, **kwargs)
+
+    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
+        """
+        Clips gradient of an iterable of parameters at specified min and max values.
+
+        Args:
+            clip_value (float or int): maximum allowed value of the gradients.
+                Gradients are clipped in the range [-clip_value, clip_value].
+
+        Note:
+            In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_value_ to use the
+            faster implementation. Please refer to the PyTorch documentation for more details.
+        """
+        nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)
+
+    def clip_grad_by_norm(self,
+                          max_norm: Union[float, int],
+                          norm_type: Union[float, int] = 2.0,
+                          error_if_nonfinite: bool = False,
+                          *args,
+                          **kwargs) -> Tensor:
+        """
+        Clips gradient norm of an iterable of parameters.
+
+        Args:
+            max_norm (float or int): max norm of the gradients
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+            error_if_nonfinite (bool): if True, an error is raised if the total norm is non-finite. Default: False
+
+        Note:
+            In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_norm_ to use the
+            faster implementation. Please refer to the PyTorch documentation for more details.
+        """
+        norm = nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
+        return norm
+
+    def scale_loss(self, loss: Tensor):
+        """
+        Scales the loss for mixed precision training.
+
+        Note: Only available for optimizers with mixed precision training.
+
+        Args:
+            loss (Tensor): The loss to be scaled.
+        """
+        raise NotImplementedError(
+            "The method scale_loss is only available for optimizers with mixed precision training")
+
+    def unscale_grad(self):
+        """
+        Unscales the gradients for mixed precision training.
+
+        Note: Only available for optimizers with mixed precision training.
+        """
+        raise NotImplementedError(
+            "The method unscale_grad is only available for optimizers with mixed precision training")
diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py
new file mode 100644
index 000000000..3cf0ad28c
--- /dev/null
+++ b/colossalai/booster/mixed_precision/__init__.py
@@ -0,0 +1,33 @@
+from .bf16 import BF16MixedPrecision
+from .fp8 import FP8MixedPrecision
+from .fp16_apex import FP16ApexMixedPrecision
+from .fp16_torch import FP16TorchMixedPrecision
+from .mixed_precision_base import MixedPrecision
+
+__all__ = [
+    'MixedPrecision', 'mixed_precision_factory', 'FP16ApexMixedPrecision', 'FP16TorchMixedPrecision',
+    'BF16MixedPrecision', 'FP8MixedPrecision'
+]
+
+_mixed_precision_mapping = {
+    'fp16': FP16TorchMixedPrecision,
+    'fp16_apex': FP16ApexMixedPrecision,
+    'bf16': BF16MixedPrecision,
+    'fp8': FP8MixedPrecision
+}
+
+
+def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
+    """
+    Factory method to create a mixed precision object.
+
+    Args:
+        mixed_precision_type (str): mixed precision type, one of 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
+ """ + + if mixed_precision_type in _mixed_precision_mapping: + return _mixed_precision_mapping[mixed_precision_type]() + else: + raise ValueError( + f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}' + ) diff --git a/colossalai/booster/mixed_precision/bf16.py b/colossalai/booster/mixed_precision/bf16.py new file mode 100644 index 000000000..4a840fea6 --- /dev/null +++ b/colossalai/booster/mixed_precision/bf16.py @@ -0,0 +1,5 @@ +from .mixed_precision_base import MixedPrecision + + +class BF16MixedPrecision(MixedPrecision): + pass diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py new file mode 100644 index 000000000..266a75073 --- /dev/null +++ b/colossalai/booster/mixed_precision/fp16_apex.py @@ -0,0 +1,5 @@ +from .mixed_precision_base import MixedPrecision + + +class FP16ApexMixedPrecision(MixedPrecision): + pass diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py new file mode 100644 index 000000000..054f78d2e --- /dev/null +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -0,0 +1,122 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer + +from ..interface import OptimizerWrapper +from .mixed_precision_base import MixedPrecision + +__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule'] + + +class TorchAMPOptimizer(OptimizerWrapper): + """ + Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP. + + Args: + optim (Optimizer): Optimizer to wrap. + init_scale (float): Initial scale factor. Default: 2**16. + growth_factor (float): Factor by which the scale is multiplied during + :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite + this iteration. Default: 2.0. + backoff_factor (float): Factor by which the scale is multiplied during + :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite + this iteration. Default: 0.5. + growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step` + calls that may cause the scale to increase. Default: 2000. 
+ """ + + def __init__(self, + optim: Optimizer, + init_scale: float = 2.**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000) -> None: + super().__init__(optim) + self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval) + + def backward(self, loss: Tensor, *args, **kwargs) -> None: + scaled_loss = self.scale_loss(loss) + scaled_loss.backward(*args, **kwargs) + + def step(self, *args, **kwargs) -> Optional[float]: + return self.scaler.step(self.optim, *args, **kwargs) + + def scale_loss(self, loss: Tensor) -> Tensor: + return self.scaler.scale(loss) + + def unscale_grad(self) -> None: + self.scaler.unscale_(self.optim) + + def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: + self.unscale_grad() + super().clip_grad_by_value(clip_value, *args, **kwargs) + + def clip_grad_by_norm(self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs) -> None: + self.unscale_grad() + super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs) + + +class TorchAMPModule(nn.Module): + """ + Module wrapper for mixed precision training in FP16 using PyTorch AMP. + + Args: + module (nn.Module): Module to wrap. + """ + + def __init__(self, module: nn.Module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + with torch.cuda.amp.autocast(): + return self.module(*args, **kwargs) + + +class FP16TorchMixedPrecision(MixedPrecision): + """ + Precision for mixed precision training in FP16 using PyTorch AMP. + + Args: + init_scale (float): Initial scale factor. Default: 2**16. + growth_factor (float): Factor by which the scale is multiplied during + :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite + this iteration. Default: 2.0. + backoff_factor (float): Factor by which the scale is multiplied during + :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite + this iteration. Default: 0.5. + growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step` + calls that may cause the scale to increase. Default: 2000. 
+ """ + + def __init__(self, + init_scale: float = 2.**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000) -> None: + super().__init__() + self.torch_amp_kwargs = dict(init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval) + + def configure(self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + model = TorchAMPModule(model) + optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) + if criterion is not None: + criterion = TorchAMPModule(criterion) + return model, optimizer, criterion diff --git a/colossalai/booster/mixed_precision/fp8.py b/colossalai/booster/mixed_precision/fp8.py new file mode 100644 index 000000000..28847345d --- /dev/null +++ b/colossalai/booster/mixed_precision/fp8.py @@ -0,0 +1,5 @@ +from .mixed_precision_base import MixedPrecision + + +class FP8MixedPrecision(MixedPrecision): + pass diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py new file mode 100644 index 000000000..d1e8acc82 --- /dev/null +++ b/colossalai/booster/mixed_precision/mixed_precision_base.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod +from typing import Callable, Tuple + +import torch.nn as nn +from torch.optim import Optimizer + +from ..interface import OptimizerWrapper + + +class MixedPrecision(ABC): + """ + An abstract class for mixed precision training. + """ + + @abstractmethod + def configure(self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + # TODO: implement this method + pass diff --git a/colossalai/booster/precision.py b/colossalai/booster/precision.py deleted file mode 100644 index 8a391d9e4..000000000 --- a/colossalai/booster/precision.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -import torch.nn as nn -from torch.optim import Optimizer - -__all__ = ['Precision'] - - -class Precision: - - def __init__(self, precision_type: torch.dtype, grad_clipping_type: str, grad_clipping_value: float): - self.precision_type = precision_type - self.grad_clipping_type = grad_clipping_type - self.grad_clipping_value = grad_clipping_value - - def setup_model(self, model: nn.Module) -> nn.Module: - # TODO: implement this method - pass - - def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - # TODO: implement this method - # inject grad clipping and unscale loss - pass - - def scale_loss(self, loss: torch.Tensor) -> torch.Tensor: - pass diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index a92a46e36..2a100c981 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -6,7 +6,7 @@ from ..registry import ModelAttribute, model_zoo # =============================== # Register single-sentence GPT # =============================== -BATCH_SIZE = 2 +BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined. 
 SEQ_LENGTH = 16
diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py
new file mode 100644
index 000000000..c56fcae58
--- /dev/null
+++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py
@@ -0,0 +1,23 @@
+import torch
+from torch.optim import Adam
+
+from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
+from tests.kit.model_zoo import model_zoo
+
+
+def test_torch_amp():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+        model = model_fn().cuda()
+        optimizer = Adam(model.parameters(), lr=1e-3)
+        criterion = lambda x: x.mean()
+        data = data_gen_fn()
+        data = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data.items()}
+        mixed_precision = FP16TorchMixedPrecision()
+        model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
+        output = model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+        optimizer.backward(loss)
+        optimizer.clip_grad_by_norm(1.0)
+        optimizer.step()