Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-31 23:35:47 +00:00)

[booster] implemented mixed precision class (#3151)

* [booster] implemented mixed precision class
* polish code

This commit is contained in:
parent ecd643f1e4
commit ed19290560
colossalai/booster/__init__.py
@@ -2,4 +2,3 @@ from .accelerator import Accelerator
 from .booster import Booster
 from .environment_table import EnvironmentTable
 from .plugin import Plugin
-from .precision import Precision
colossalai/booster/booster.py
@@ -1,37 +1,95 @@
from contextlib import contextmanager
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin

__all__ = ['Booster']


class Booster:
    """
    Booster is a high-level API for training neural networks. It provides a unified interface for
    training with different precision, accelerator, and plugin.

    Examples:
        >>> colossalai.launch(...)
        >>> plugin = GeminiPlugin(stage=3, ...)
        >>> booster = Booster(mixed_precision='fp16', plugin=plugin)
        >>>
        >>> model = GPT2()
        >>> optimizer = Adam(model.parameters())
        >>> dataloader = Dataloader(Dataset)
        >>> lr_scheduler = LinearWarmupScheduler()
        >>> criterion = GPTLMLoss()
        >>>
        >>> model, optimizer, criterion, lr_scheduler, dataloader = booster.boost(model, optimizer, criterion, lr_scheduler, dataloader)
        >>>
        >>> for epoch in range(max_epochs):
        >>>     for input_ids, attention_mask in dataloader:
        >>>         outputs = model(input_ids, attention_mask)
        >>>         loss = criterion(outputs.logits, input_ids)
        >>>         booster.backward(loss, optimizer)
        >>>         optimizer.step()
        >>>         lr_scheduler.step()
        >>>         optimizer.zero_grad()

    Args:
        device (str or torch.device): The device to run the training. Default: 'cuda'.
        mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
            If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
            'fp16' would use PyTorch AMP while 'fp16_apex' would use Nvidia Apex.
        plugin (Plugin): The plugin to run the training. Default: None.
    """

    def __init__(self,
                 device: Union[str, torch.device] = 'cuda',
                 mixed_precision: Union[MixedPrecision, str] = None,
                 plugin: Optional[Plugin] = None) -> None:
        # validate and set precision
        if mixed_precision is None:
            # no mixed precision is applied
            self.mixed_precision = None
        elif isinstance(mixed_precision, str):
            # the user will take the default arguments for amp training
            self.mixed_precision = mixed_precision_factory(mixed_precision)
        elif isinstance(mixed_precision, MixedPrecision):
            # the user can customize the arguments by passing the precision object
            self.mixed_precision = mixed_precision
        else:
            raise ValueError(
                f'Expected the argument mixed_precision to be a string or an instance of MixedPrecision, but got {type(mixed_precision)}.'
            )

    def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_scheduler: LRScheduler,
              dataloader: DataLoader) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
        """
        Boost the model, optimizer, criterion, lr_scheduler, and dataloader.

        Args:
            model (nn.Module): The model to be boosted.
            optimizer (Optimizer): The optimizer to be boosted.
            criterion (Callable): The criterion to be boosted.
            lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
            dataloader (DataLoader): The dataloader to be boosted.
        """
        # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
        # TODO(lsg): Add plugin control logic
        # e.g.
        # if self.plugin is not None and self.plugin.control_boost:
        #     ...
        # transform model for mixed precision
        model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
        return model, optimizer, criterion, lr_scheduler, dataloader

    def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
        # TODO: implement this method with plugin
        optimizer.backward(loss)

    def execute_pipeline(self,
                         data_iter: Iterator,
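For orientation, here is a minimal sketch of how the new Booster front end is meant to be driven. It uses a toy nn.Linear model and synthetic tensors instead of the GPT example from the docstring; the distributed launch and the plugin argument are omitted, passing None for lr_scheduler and dataloader relies on boost simply passing them through in this commit, and a CUDA device is assumed because the 'fp16' path uses torch.cuda.amp. This is illustrative only, not part of the commit.

import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.booster import Booster

# toy model and synthetic batch (illustrative only)
model = nn.Linear(32, 4).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = lambda logits, target: nn.functional.cross_entropy(logits, target)

booster = Booster(mixed_precision='fp16')    # resolves to FP16TorchMixedPrecision via the factory

# lr_scheduler and dataloader are not wrapped in this commit; boost returns them unchanged
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion, None, None)

x = torch.randn(8, 32, device='cuda')
y = torch.randint(0, 4, (8,), device='cuda')
loss = criterion(model(x), y)          # forward runs under torch.cuda.amp.autocast
booster.backward(loss, optimizer)      # delegates to optimizer.backward, which scales the loss
optimizer.step()
optimizer.zero_grad()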
colossalai/booster/interface/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper']
colossalai/booster/interface/optimizer.py (new file, 121 lines)
@@ -0,0 +1,121 @@
from typing import Union

import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer


class OptimizerWrapper:
    """
    A standard interface for optimizers wrapped by the Booster.

    Args:
        optim (Optimizer): The optimizer to be wrapped.
    """

    def __init__(self, optim: Optimizer):
        self.optim = optim

    @property
    def parameters(self):
        params = []

        for group in self.param_groups:
            params += group['params']
        return params

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def defaults(self):
        return self.optim.defaults

    def add_param_group(self, *args, **kwargs):
        return self.optim.add_param_group(*args, **kwargs)

    def step(self, *args, **kwargs):
        """
        Performs a single optimization step.
        """
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        """
        Clears the gradients of all optimized `torch.Tensor`s.
        """
        self.optim.zero_grad(*args, **kwargs)

    def backward(self, loss: Tensor, *args, **kwargs):
        """
        Performs a backward pass on the loss.
        """
        loss.backward(*args, **kwargs)

    def state_dict(self):
        """
        Returns the optimizer state.
        """
        return self.optim.state_dict()

    def load_state_dict(self, *args, **kwargs):
        """
        Loads the optimizer state.
        """
        self.optim.load_state_dict(*args, **kwargs)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        """
        Clips gradient of an iterable of parameters at specified min and max values.

        Args:
            clip_value (float or int): maximum allowed value of the gradients. Gradients are clipped in the range
                [-clip_value, clip_value].

        Note:
            In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_value_ to use the
            faster implementation. Please refer to the PyTorch documentation for more details.
        """
        nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)

    def clip_grad_by_norm(self,
                          max_norm: Union[float, int],
                          norm_type: Union[float, int] = 2.0,
                          error_if_nonfinite: bool = False,
                          *args,
                          **kwargs) -> Tensor:
        """
        Clips gradient norm of an iterable of parameters.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
            error_if_nonfinite (bool): if True, an error is raised if the total norm is non-finite. Default: False

        Note:
            In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_norm_ to use the
            faster implementation. Please refer to the PyTorch documentation for more details.
        """
        norm = nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
        return norm

    def scale_loss(self, loss: Tensor):
        """
        Scales the loss for mixed precision training.

        Note: Only available for optimizers with mixed precision training.

        Args:
            loss (Tensor): The loss to be scaled.
        """
        raise NotImplementedError(
            "The method scale_loss is only available for optimizers with mixed precision training")

    def unscale_grad(self):
        """
        Unscale the gradients for mixed precision training.

        Note: Only available for optimizers with mixed precision training.
        """
        raise NotImplementedError(
            "The method unscale_grad is only available for optimizers with mixed precision training")
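As a quick illustration of the pass-through behaviour, the following sketch (a toy SGD setup, not part of the commit) wraps a plain optimizer and exercises the delegating methods together with the base gradient-clipping helpers:

import torch
import torch.nn as nn
from torch.optim import SGD

from colossalai.booster.interface import OptimizerWrapper

model = nn.Linear(16, 2)
optim = OptimizerWrapper(SGD(model.parameters(), lr=0.1))

loss = model(torch.randn(4, 16)).sum()
optim.backward(loss)                               # base wrapper: plain loss.backward()
grad_norm = optim.clip_grad_by_norm(max_norm=1.0)  # clips over optim.parameters and returns the total norm
optim.step()
optim.zero_grad()
state = optim.state_dict()                         # identical to the inner optimizer's state dict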
colossalai/booster/mixed_precision/__init__.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from .bf16 import BF16MixedPrecision
from .fp8 import FP8MixedPrecision
from .fp16_apex import FP16ApexMixedPrecision
from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision

__all__ = [
    'MixedPrecision', 'mixed_precision_factory', 'FP16ApexMixedPrecision', 'FP16TorchMixedPrecision',
    'BF16MixedPrecision', 'FP8MixedPrecision'
]

_mixed_precision_mapping = {
    'fp16': FP16TorchMixedPrecision,
    'fp16_apex': FP16ApexMixedPrecision,
    'bf16': BF16MixedPrecision,
    'fp8': FP8MixedPrecision
}


def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
    """
    Factory method to create a mixed precision object.

    Args:
        mixed_precision_type (str): mixed precision type, including 'fp16', 'fp16_apex', 'bf16', and 'fp8'.
    """

    if mixed_precision_type in _mixed_precision_mapping:
        return _mixed_precision_mapping[mixed_precision_type]()
    else:
        raise ValueError(
            f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
        )
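A short sketch of how the factory is expected to be used; the names below all come from the module above, nothing here is new API:

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision, mixed_precision_factory

amp = mixed_precision_factory('fp16')
assert isinstance(amp, FP16TorchMixedPrecision)

# unsupported strings raise a ValueError listing the registered keys
try:
    mixed_precision_factory('fp64')
except ValueError as e:
    print(e)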
colossalai/booster/mixed_precision/bf16.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision


class BF16MixedPrecision(MixedPrecision):
    pass
colossalai/booster/mixed_precision/fp16_apex.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision


class FP16ApexMixedPrecision(MixedPrecision):
    pass
colossalai/booster/mixed_precision/fp16_torch.py (new file, 122 lines)
@@ -0,0 +1,122 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer

from ..interface import OptimizerWrapper
from .mixed_precision_base import MixedPrecision

__all__ = ['FP16TorchMixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']


class TorchAMPOptimizer(OptimizerWrapper):
    """
    Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP.

    Args:
        optim (Optimizer): Optimizer to wrap.
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite
            this iteration. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite
            this iteration. Default: 0.5.
        growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step`
            calls that may cause the scale to increase. Default: 2000.
    """

    def __init__(self,
                 optim: Optimizer,
                 init_scale: float = 2.**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> None:
        super().__init__(optim)
        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
                                                growth_factor=growth_factor,
                                                backoff_factor=backoff_factor,
                                                growth_interval=growth_interval)

    def backward(self, loss: Tensor, *args, **kwargs) -> None:
        scaled_loss = self.scale_loss(loss)
        scaled_loss.backward(*args, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        return self.scaler.step(self.optim, *args, **kwargs)

    def scale_loss(self, loss: Tensor) -> Tensor:
        return self.scaler.scale(loss)

    def unscale_grad(self) -> None:
        self.scaler.unscale_(self.optim)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        self.unscale_grad()
        super().clip_grad_by_value(clip_value, *args, **kwargs)

    def clip_grad_by_norm(self,
                          max_norm: Union[float, int],
                          norm_type: Union[float, int] = 2.0,
                          error_if_nonfinite: bool = False,
                          *args,
                          **kwargs) -> None:
        self.unscale_grad()
        super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)


class TorchAMPModule(nn.Module):
    """
    Module wrapper for mixed precision training in FP16 using PyTorch AMP.

    Args:
        module (nn.Module): Module to wrap.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.cuda.amp.autocast():
            return self.module(*args, **kwargs)


class FP16TorchMixedPrecision(MixedPrecision):
    """
    Precision for mixed precision training in FP16 using PyTorch AMP.

    Args:
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite
            this iteration. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite
            this iteration. Default: 0.5.
        growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step`
            calls that may cause the scale to increase. Default: 2000.
    """

    def __init__(self,
                 init_scale: float = 2.**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> None:
        super().__init__()
        self.torch_amp_kwargs = dict(init_scale=init_scale,
                                     growth_factor=growth_factor,
                                     backoff_factor=backoff_factor,
                                     growth_interval=growth_interval)

    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
        model = TorchAMPModule(model)
        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
        if criterion is not None:
            criterion = TorchAMPModule(criterion)
        return model, optimizer, criterion
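To make the scaling flow concrete, here is a minimal sketch that goes through configure and one update step; the model and data are toy placeholders and a CUDA device is assumed. Clipping triggers unscale_grad first, matching the overrides above.

import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision

model = nn.Linear(32, 32).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()

model, optimizer, criterion = FP16TorchMixedPrecision().configure(model, optimizer, criterion)

out = model(torch.randn(8, 32, device='cuda'))   # forward runs under torch.cuda.amp.autocast
loss = criterion(out)
optimizer.backward(loss)             # scales the loss before calling backward
optimizer.clip_grad_by_value(0.1)    # unscales the gradients, then clips
optimizer.step()                     # GradScaler.step skips the update if inf/nan gradients are found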
colossalai/booster/mixed_precision/fp8.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision


class FP8MixedPrecision(MixedPrecision):
    pass
colossalai/booster/mixed_precision/mixed_precision_base.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Callable, Tuple

import torch.nn as nn
from torch.optim import Optimizer

from ..interface import OptimizerWrapper


class MixedPrecision(ABC):
    """
    An abstract class for mixed precision training.
    """

    @abstractmethod
    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
        # TODO: implement this method
        pass
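Concrete precisions subclass MixedPrecision and override configure. As a sketch of the minimal contract, a hypothetical no-op FP32 variant (not part of this commit) could look like this:

from typing import Callable, Tuple

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.booster.interface import OptimizerWrapper
from colossalai.booster.mixed_precision import MixedPrecision


class NoOpFP32MixedPrecision(MixedPrecision):
    """Hypothetical example: keeps the model in FP32 and only wraps the optimizer."""

    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
        # no dtype conversion; just expose the standard OptimizerWrapper interface
        return model, OptimizerWrapper(optimizer), criterion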
colossalai/booster/precision.py (deleted, 25 lines)
@@ -1,25 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import Optimizer

__all__ = ['Precision']


class Precision:

    def __init__(self, precision_type: torch.dtype, grad_clipping_type: str, grad_clipping_value: float):
        self.precision_type = precision_type
        self.grad_clipping_type = grad_clipping_type
        self.grad_clipping_value = grad_clipping_value

    def setup_model(self, model: nn.Module) -> nn.Module:
        # TODO: implement this method
        pass

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        # TODO: implement this method
        # inject grad clipping and unscale loss
        pass

    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        pass
@@ -6,7 +6,7 @@ from ..registry import ModelAttribute, model_zoo
 # ===============================
 # Register single-sentence GPT
 # ===============================
-BATCH_SIZE = 2
+BATCH_SIZE = 1    # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined.
 SEQ_LENGTH = 16
tests/test_booster/test_mixed_precision/test_fp16_torch.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import torch
from torch.optim import Adam

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
from tests.kit.model_zoo import model_zoo


def test_torch_amp():
    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        model = model_fn().cuda()
        optimizer = Adam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        data = data_gen_fn()
        data = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data.items()}
        mixed_precision = FP16TorchMixedPrecision()
        model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
        output = model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])
        optimizer.backward(loss)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()