mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[shardformer] update pipeline parallel document (#4725)
* [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
# Pipeline Parallel
|
||||
|
||||
Author: Guangyang Lu, Hongxin Liu, Yongbin Li
|
||||
Author: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang
|
||||
|
||||
**Prerequisite**
|
||||
- [Define Your Configuration](../basics/define_your_config.md)
|
||||
- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
|
||||
- [Configure Parallelization](../basics/configure_parallelization.md)
|
||||
- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
|
||||
- [Use Booster to Training](../basics/booster_api.md)
|
||||
- [Shardformer](../features/shardformer.md)
|
||||
- [Plugin of Booster](../basics/booster_plugins.md)
|
||||
|
||||
**Example Code**
|
||||
- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel)
|
||||
- [Fine-tune Bert with pipeline](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py)
|
||||
|
||||
**Related Paper**
|
||||
- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)
|
||||
@@ -17,7 +18,7 @@ Author: Guangyang Lu, Hongxin Liu, Yongbin Li
|
||||
|
||||
## Quick introduction
|
||||
|
||||
In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use ResNet and Cifar as example.
|
||||
In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use bert model and glue dataset as example.
|
||||
|
||||
## Table Of Content
|
||||
|
||||
@@ -25,7 +26,7 @@ In this tutorial we will cover:
|
||||
|
||||
1. Introduction of 1F1B pipeline.
|
||||
2. Usage of non-interleaved and interleaved schedule.
|
||||
3. Training ResNet with pipeline.
|
||||
3. Finetune Bert with pipeline.
|
||||
|
||||
## Introduction of 1F1B pipeline
|
||||
|
||||
@@ -60,101 +61,158 @@ In this schedule, each device can perform computation for multiple subsets of la
|
||||
|
||||
This mode is both memory-efficient and time-efficient.
|
||||
|
||||
## Usage of non-interleaved and interleaved schedule
|
||||
## Colossal-AI's Implementation
|
||||
|
||||
In Colossal-AI, we provided both non-interleaved(as `PipelineSchedule`) and interleaved schedule(as `InterleavedPipelineSchedule`).
|
||||
In Colossal-AI, pipeline parallelism relies on the `scheduler` and [`Shardformer`](../features/shardformer.md). We provide both non-interleaved (`OneForwardOneBackwardSchedule`) and interleaved (`InterleavedSchedule`) schedules. While `Shardformer` implements layer splitting for models and replaces the `forward` function of the model to make it compatible with the scheduler.
|
||||
|
||||
You just need to set `NUM_MICRO_BATCHES` in config file and set `NUM_CHUNKS` in config file if you want to use Interleaved Pipeline Schedule. If you certainly know the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `TENSOR_SHAPE` in config file to further reduce communication. Otherwise, you can just ignore `tensor_shape`, and the shape will be exchanged over pipeline stages automatically. Then we will generate an appropriate schedule for you.
|
||||
In Colossal-AI, the `HybridParallelPlugin` encapsulates pipeline execution strategies. It manages pipeline parallel communication groups and a scheduler. When boosting the model with this plugin, the model's layers are split by calling the `shardformer.optimize` function, and then `execute_pipeline` is called to execute the model in segments using `OneForwardOneBackwardSchedule` which is default scheduler used in `HybridParallelPlugin`, and `InterleavedSchedule` will be integrated later.
|
||||
|
||||
## Training ResNet with pipeline
|
||||
You can customize your parallel strategy by setting parameters for the `HybridParallelPlugin`.
|
||||
|
||||
Let's build the `ResNet` model first with Colossal PipelinableContext:
|
||||
For more usage details, please refer to the [documentation](../basics/booster_plugins.md) for `HybridParallelPlugin`.
|
||||
|
||||
## Fine-tune Bert with pipeline
|
||||
|
||||
First, we define the necessary training components, including model, dataloader, optimizer, lr_scheduler, criterion:
|
||||
```python
|
||||
import os
|
||||
from typing import Callable, List, Optional, Type, Union
|
||||
import argparse
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from data import GLUEDataBuilder
|
||||
from torch.optim import Adam, Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
AlbertForSequenceClassification,
|
||||
AutoConfig,
|
||||
BertForSequenceClassification,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.legacy.trainer import Trainer, hooks
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
from torchvision.models import resnet50
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
# Define some config
|
||||
BATCH_SIZE = 64
|
||||
NUM_EPOCHS = 2
|
||||
NUM_CHUNKS = 1
|
||||
CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
|
||||
NUM_EPOCHS = 3
|
||||
BATCH_SIZE = 32
|
||||
LEARNING_RATE = 2.4e-5
|
||||
WEIGHT_DECAY = 0.01
|
||||
WARMUP_FRACTION = 0.1
|
||||
|
||||
# Train
|
||||
disable_existing_loggers()
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
|
||||
logger = get_dist_logger()
|
||||
pipelinable = PipelinableContext()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# build model
|
||||
with pipelinable:
|
||||
model = resnet50()
|
||||
```
|
||||
def move_to_cuda(batch):
|
||||
return {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
Define an execution sequence.
|
||||
```python
|
||||
exec_seq = [
|
||||
'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool',
|
||||
(lambda x: torch.flatten(x, 1), "behind"), 'fc'
|
||||
]
|
||||
pipelinable.to_layer_list(exec_seq)
|
||||
```
|
||||
|
||||
Partition the model into pipeline.
|
||||
```python
|
||||
model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
```
|
||||
# Define 'criterion' function with two inputs, which will be passed to 'execute_pipeline'.
|
||||
def _criterion(outputs, inputs):
|
||||
return outputs.loss
|
||||
|
||||
In this tutorial, we use `Trainer` to train `ResNet`:
|
||||
```python
|
||||
# build criterion
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# optimizer
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
|
||||
# build dataloader
|
||||
root = os.environ.get('DATA', './data')
|
||||
train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32)
|
||||
|
||||
lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1)
|
||||
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion,
|
||||
train_dataloader, test_dataloader,
|
||||
lr_scheduler)
|
||||
timer = MultiTimer()
|
||||
|
||||
trainer = Trainer(engine=engine, timer=timer, logger=logger)
|
||||
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.AccuracyHook(col_nn.metric.Accuracy()),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True)
|
||||
# Define optimizer
|
||||
lr = LEARNING_RATE
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": WEIGHT_DECAY,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
epochs=NUM_EPOCHS,
|
||||
test_dataloader=test_dataloader,
|
||||
test_interval=1,
|
||||
hooks=hook_list,
|
||||
display_progress=True)
|
||||
optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
|
||||
|
||||
|
||||
# Define lr_scheduler
|
||||
total_steps = len(train_dataloader) * NUM_EPOCHS
|
||||
num_warmup_steps = int(WARMUP_FRACTION * total_steps)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=total_steps,
|
||||
)
|
||||
|
||||
|
||||
# Define Bert model
|
||||
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=cfg).cuda()
|
||||
|
||||
# Define a dataloader
|
||||
data_builder = GLUEDataBuilder(model_name,
|
||||
plugin,
|
||||
args.task,
|
||||
train_batch_size=BATCH_SIZE,
|
||||
eval_batch_size=BATCH_SIZE)
|
||||
train_dataloader = data_builder.train_dataloader()
|
||||
```
|
||||
|
||||
We use `2` pipeline stages and the batch will be split into `4` micro batches.
|
||||
Define a booster with the `HybridParallelPlugin`.
|
||||
```python
|
||||
plugin = HybridParallelPlugin(tp_size=1,
|
||||
pp_size=2,
|
||||
num_microbatches=None,
|
||||
microbatch_size=1,
|
||||
enable_all_optimization=True,
|
||||
zero_stage=1,
|
||||
precision='fp16',
|
||||
initial_scale=1)
|
||||
booster = Booster(plugin=plugin)
|
||||
```
|
||||
|
||||
Boost these train componts with the booster created.
|
||||
```python
|
||||
model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
|
||||
optimizer,
|
||||
criterion=_criterion,
|
||||
lr_scheduler=lr_scheduler)
|
||||
```
|
||||
|
||||
Train the model at last.
|
||||
|
||||
```python
|
||||
# Define a train function
|
||||
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
|
||||
train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
|
||||
|
||||
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
|
||||
total_step = len(train_dataloader)
|
||||
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
# convert train_dataloader to a iterator
|
||||
train_dataloader_iter = iter(train_dataloader)
|
||||
with tqdm(range(total_step),
|
||||
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
|
||||
disable=not (is_pp_last_stage)) as pbar:
|
||||
# Forward pass
|
||||
for _ in pbar:
|
||||
outputs = booster.execute_pipeline(train_dataloader_iter,
|
||||
model,
|
||||
_criterion,
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True)
|
||||
# Backward and optimize
|
||||
if is_pp_last_stage:
|
||||
loss = outputs['loss']
|
||||
pbar.set_postfix({'loss': loss.item()})
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
# Train model
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
|
||||
```
|
||||
|
||||
We use `2` pipeline stages and the micro batches is 1. (these parameters can be configured to an appropriate value)
|
||||
<!-- doc-test-command: echo -->
|
||||
|
Reference in New Issue
Block a user