Mirror of https://github.com/hpcaitech/ColossalAI.git
[doc] Update booster user documents. (#4669)
* update booster_api.md
* update booster_checkpoint.md
* update booster_plugins.md
* move transformers importing inside function
* fix Dict typing
* fix autodoc bug
* small fix
@@ -1,6 +1,6 @@
 import warnings
 from contextlib import contextmanager
-from typing import Any, Callable, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -24,29 +24,31 @@ class Booster:
     Booster is a high-level API for training neural networks. It provides a unified interface for
     training with different precision, accelerator, and plugin.
 
     Examples:
-        ```python
-        colossalai.launch(...)
-        plugin = GeminiPlugin(...)
-        booster = Booster(precision='fp16', plugin=plugin)
-
-        model = GPT2()
-        optimizer = HybridAdam(model.parameters())
-        dataloader = Dataloader(Dataset)
-        lr_scheduler = LinearWarmupScheduler()
-        criterion = GPTLMLoss()
-
-        model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
-
-        for epoch in range(max_epochs):
-            for input_ids, attention_mask in dataloader:
-                outputs = model(input_ids, attention_mask)
-                loss = criterion(outputs.logits, input_ids)
-                booster.backward(loss, optimizer)
-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.zero_grad()
-        ```
+        ```python
+        # Following is pseudocode
+
+        colossalai.launch(...)
+        plugin = GeminiPlugin(...)
+        booster = Booster(precision='fp16', plugin=plugin)
+
+        model = GPT2()
+        optimizer = HybridAdam(model.parameters())
+        dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+        lr_scheduler = LinearWarmupScheduler()
+        criterion = GPTLMLoss()
+
+        model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)
+
+        for epoch in range(max_epochs):
+            for input_ids, attention_mask in dataloader:
+                outputs = model(input_ids.cuda(), attention_mask.cuda())
+                loss = criterion(outputs.logits, input_ids)
+                booster.backward(loss, optimizer)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+        ```
 
     Args:
         device (str or torch.device): The device to run the training. Default: None.
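The updated docstring above is explicitly marked as pseudocode. For readers who want something concrete, the following is a minimal sketch of the same flow with a plain PyTorch model and `TorchDDPPlugin`; the toy model, dataset, and hyperparameters are illustrative assumptions (not part of this commit), and it is meant to be launched under `torchrun` on a machine with GPUs.

```python
# Minimal runnable sketch of the Booster flow shown in the docstring above.
# The model/dataset here are illustrative assumptions, not part of the commit.
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})          # assumes launch via torchrun

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2)).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
dataloader = plugin.prepare_dataloader(dataset, batch_size=8, shuffle=True)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(2):
    for inputs, labels in dataloader:
        outputs = model(inputs.cuda())
        loss = criterion(outputs, labels.cuda())
        booster.backward(loss, optimizer)        # not loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```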
@@ -60,7 +62,7 @@ class Booster:
 
     def __init__(self,
                  device: Optional[str] = None,
-                 mixed_precision: Union[MixedPrecision, str] = None,
+                 mixed_precision: Optional[Union[MixedPrecision, str]] = None,
                  plugin: Optional[Plugin] = None) -> None:
         if plugin is not None:
             assert isinstance(
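With the corrected annotation, `mixed_precision` is explicitly optional: it can be omitted, given as a string such as `'fp16'` (as in the docstring example), or given as a `MixedPrecision` instance. A brief sketch of the common forms, assuming `colossalai.launch` has already been called; the variable names are illustrative.

```python
# Sketch: accepted forms of the mixed_precision argument (None, a string, or handled by the plugin).
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

booster_fp32 = Booster()                        # mixed_precision defaults to None (full precision)
booster_fp16 = Booster(mixed_precision='fp16')  # string form, as in the docstring example
# A MixedPrecision object can be passed instead of the string; some plugins
# (e.g. GeminiPlugin) manage precision themselves, so mixed_precision is left unset here.
booster_gemini = Booster(plugin=GeminiPlugin())
```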
@@ -110,14 +112,19 @@ class Booster:
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
-        Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
+        Wrap and inject features to the passed in model, optimizer, criterion, lr_scheduler, and dataloader.
 
         Args:
-            model (nn.Module): The model to be boosted.
-            optimizer (Optimizer): The optimizer to be boosted.
-            criterion (Callable): The criterion to be boosted.
-            dataloader (DataLoader): The dataloader to be boosted.
-            lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
+            model (nn.Module): Convert model into a wrapped model for distributive training.
+                The model might be decorated or partitioned by plugin's strategy after execution of this method.
+            optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributive training.
+                The optimizer's param groups or states might be decorated or partitioned by plugin's strategy after execution of this method. Defaults to None.
+            criterion (Callable, optional): The function that calculates loss. Defaults to None.
+            dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
+            lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None.
+
+        Returns:
+            List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
         """
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(FrankLeeeee): consider multi-dataloader case
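Since every argument except `model` is now documented as optional, `boost` can be called with only the objects a script actually has; the others come back as `None`. A small sketch, assuming a `booster` already exists from the setup sketch above:

```python
# Sketch: boosting only a model and an optimizer; criterion/dataloader/lr_scheduler stay None.
import torch.nn as nn
from colossalai.nn.optimizer import HybridAdam

model = nn.Linear(16, 4).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# `booster` is assumed to exist from the earlier setup sketch; boost() still returns five items.
model, optimizer, *_ = booster.boost(model, optimizer)
```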
@@ -138,10 +145,10 @@ class Booster:
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
-        """Backward pass.
+        """Execution of backward during training step.
 
         Args:
-            loss (torch.Tensor): The loss to be backpropagated.
+            loss (torch.Tensor): The loss for backpropagation.
             optimizer (Optimizer): The optimizer to be updated.
         """
         # TODO(frank lee): implement this method with plugin
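The renamed docstring makes the intent clearer: during a training step the loss should flow through `booster.backward` rather than `loss.backward()`, so the plugin and mixed-precision module can hook in (for example, for loss scaling under fp16). A one-step sketch, reusing the boosted names from the sketches above:

```python
# Sketch of a single training step with boosted objects (names reused from the sketches above).
for inputs, labels in dataloader:
    loss = criterion(model(inputs.cuda()), labels.cuda())
    booster.backward(loss, optimizer)   # replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    break                               # one step is enough for the sketch
```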
@@ -153,9 +160,31 @@ class Booster:
                          criterion: Callable[[Any, Any], torch.Tensor],
                          optimizer: Optional[Optimizer] = None,
                          return_loss: bool = True,
-                         return_outputs: bool = False) -> dict:
-        # run pipeline forward backward pass
-        # return loss or outputs if needed
+                         return_outputs: bool = False) -> Dict[str, Any]:
+        """
+        Execute forward & backward when utilizing pipeline parallel.
+        Return loss or Huggingface style model outputs if needed.
+
+        Warning: This function is tailored for the scenario of pipeline parallel.
+        As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward())
+        when doing pipeline parallel training with booster, which will cause unexpected errors.
+
+        Args:
+            data_iter(Iterator): The iterator for getting the next batch of data. Usually there are two ways to obtain this argument:
+                1. wrap the dataloader to iterator through: iter(dataloader)
+                2. get the next batch from dataloader, and wrap this batch to iterator: iter([batch])
+            model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that supports pipeline.
+            criterion: (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+                'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
+            optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
+            return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
+            return_output (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.
+
+        Returns:
+            Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
+                ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
+                ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None.
+        """
         assert isinstance(self.plugin,
                           PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
         return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
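The new docstring spells out the calling convention. The following is a hedged sketch of one pipeline-parallel training step, under the assumptions that `model`, `optimizer`, and `dataloader` were boosted with a pipeline-capable plugin (for example `HybridParallelPlugin`) and that the model returns Huggingface-style outputs with a `.loss` field.

```python
# Sketch: one pipeline-parallel step via execute_pipeline (not model(...)/loss.backward()).
# Assumes the objects were boosted with a pipeline-capable plugin such as HybridParallelPlugin.
criterion = lambda outputs, inputs: outputs.loss   # two-argument criterion, per the docstring

for batch in dataloader:
    data_iter = iter([batch])                      # wrap a single batch into an iterator
    ret = booster.execute_pipeline(data_iter, model, criterion, optimizer, return_loss=True)
    loss = ret['loss']                             # may be None on non-final pipeline stages
    optimizer.step()
    optimizer.zero_grad()
```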
@@ -175,7 +204,7 @@ class Booster:
         assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model, optimizer)
 
-    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
+    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.
 
         Args:
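The context lines above show `no_sync` delegating to the plugin after a `support_no_sync()` check; in practice it is used as a context manager for gradient accumulation. The sketch below is illustrative only: the accumulation count is an assumption, and plugins that do not support it will trip the assert shown above.

```python
# Sketch: gradient accumulation with booster.no_sync (only for plugins where support_no_sync() is True).
accumulation_steps = 4  # illustrative
for step, (inputs, labels) in enumerate(dataloader):
    loss = criterion(model(inputs.cuda()), labels.cuda()) / accumulation_steps
    if (step + 1) % accumulation_steps != 0:
        with booster.no_sync(model):            # skip gradient sync on intermediate micro-steps
            booster.backward(loss, optimizer)
    else:
        booster.backward(loss, optimizer)       # sync on the last micro-step
        optimizer.step()
        optimizer.zero_grad()
```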
@@ -195,7 +224,7 @@ class Booster:
                    gather_dtensor: bool = True,
                    prefix: Optional[str] = None,
                    size_per_shard: int = 1024,
-                   use_safetensors: bool = False):
+                   use_safetensors: bool = False) -> None:
         """Save model to checkpoint.
 
         Args:
@@ -203,7 +232,7 @@ class Booster:
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
             shard (bool, optional): Whether to save checkpoint a sharded way.
-                If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+                If true, the checkpoint will be a folder with the same format as Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
             gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
             prefix (str, optional): A prefix added to parameter and buffer
                 names to compose the keys in state_dict. Defaults to None.
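To make the sharded vs. unsharded distinction in the updated docstring concrete, here is a short sketch; the checkpoint paths are illustrative assumptions, and `booster` and `model` are assumed to come from the earlier sketches.

```python
# Sketch: sharded save (a Huggingface-style folder, as described above), then reload.
booster.save_model(model, './ckpt/model', shard=True, size_per_shard=1024)
booster.load_model(model, './ckpt/model')   # the sharded directory is passed back to load_model

# Unsharded save: a single file, optionally in safetensors format.
booster.save_model(model, './ckpt/model.safetensors', shard=False, use_safetensors=True)
```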
@@ -218,7 +247,7 @@ class Booster:
                                        size_per_shard=size_per_shard,
                                        use_safetensors=use_safetensors)
 
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
         """Load optimizer from checkpoint.
 
         Args:
@@ -237,7 +266,7 @@ class Booster:
                        shard: bool = False,
                        gather_dtensor: bool = True,
                        prefix: Optional[str] = None,
-                       size_per_shard: int = 1024):
+                       size_per_shard: int = 1024) -> None:
         """
         Save optimizer to checkpoint.
 
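Optimizer checkpointing mirrors the model API shown above; a minimal sketch with an illustrative path, assuming `booster` and the boosted `optimizer` from the earlier sketches:

```python
# Sketch: save the (possibly sharded/partitioned) optimizer states, then load them back.
booster.save_optimizer(optimizer, './ckpt/optimizer', shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, './ckpt/optimizer')
```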
@@ -254,7 +283,7 @@ class Booster:
         """
         self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
 
-    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
         """Save lr scheduler to checkpoint.
 
         Args:
@@ -263,7 +292,7 @@ class Booster:
         """
         self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
 
-    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
         """Load lr scheduler from checkpoint.
 
         Args:
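The lr-scheduler checkpoint methods take a single file path; a minimal sketch with an illustrative path, assuming `booster` and the boosted `lr_scheduler` from the earlier sketches:

```python
# Sketch: lr-scheduler state is saved to and restored from a single file.
booster.save_lr_scheduler(lr_scheduler, './ckpt/lr_scheduler.pt')
booster.load_lr_scheduler(lr_scheduler, './ckpt/lr_scheduler.pt')
```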