mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 22:10:37 +00:00)
[doc] update advanced tutorials, training gpt with hybrid parallelism (#4866)
* [doc] update advanced tutorials, training gpt with hybrid parallelism
* update vit tutorials
* update en/train_vit_with_hybrid_parallel.py
* fix
* resolve comments
@@ -1,10 +1,14 @@
# Train GPT Using Hybrid Parallelism
# Fine-tune GPT-2 Using Hybrid Parallelism

Author: Hongxin Liu, Yongbin Li
Author: Hongxin Liu, Yongbin Li, Mingyan Jiang

**Prerequisite:**
- [parallelism plugin](../basics/booster_plugins.md)
- [booster API](../basics/booster_api.md)

**Example Code**
- [ColossalAI-Examples GPT2](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_2)
- [ColossalAI-Examples GPT3](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_3)
- [ColossalAI-Examples GPT](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/gpt/hybridparallelism/finetune.py)

**Related Paper**
- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)

@@ -12,260 +16,192 @@ Author: Hongxin Liu, Yongbin Li

## Introduction

In the previous tutorial, we introduced how to train ViT with pipeline parallelism. In this tutorial, you will learn a more complex scenario -- training GPT with hybrid parallelism. In this case, GPT-3 is so large that even CPU memory cannot hold it. Therefore, you must split the model by yourself.
In the previous tutorial, we introduced how to train ViT with pipeline parallelism. In this tutorial, you will learn a more complex scenario -- fine-tuning GPT-2 with hybrid parallelism. In this case, GPT-2 is so large that even CPU memory cannot hold it. Therefore, you must split the model.

## Table of content

In this tutorial we will cover:

1. The definition of the GPT model, based on colossalai/model_zoo
2. Processing the dataset
3. Training GPT using hybrid parallelism
1. Initialize the hybrid parallelism plugin.
2. Define the training components of the GPT-2 model.
3. Boost the GPT-2 model with [`HybridParallelPlugin`](../basics/booster_plugins.md).
4. Train GPT-2 using hybrid parallelism.

## Import libraries

```python
import json
import os
from functools import partial  # used by the dataloader's collate_fn below
from typing import Callable
from typing import Callable, List, Optional, Union  # Optional is used by tokenize_batch below

import datasets  # provides the GLUE MRPC dataset used below
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from tqdm import tqdm
from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup
from transformers import AutoTokenizer

import colossalai
import colossalai.utils as utils
import model_zoo.gpt.gpt as col_gpt
import torch
import torch.nn as nn
from colossalai import nn as col_nn
from colossalai.amp import AMP_TYPE
from colossalai.legacy.builder.pipeline import partition_uniform
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                               PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils.timer import MultiTimer
from model_zoo.gpt import GPTLMLoss
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Tokenizer
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
```


## Define GPT model

In the previous tutorial, we introduced 3 ways to build a pipelined model. But for a huge model like GPT-3, you cannot even build the model on CPU. In this case, you must split the model by yourself.

The GPT dataloader returns `input_ids` and `attention_mask`, so we use two keyword arguments in `forward()` to get them. Note that for all stages except the first, the first positional argument of `forward()` is the output tensor from the previous stage. So `hidden_states` comes from the previous stage, and for the first stage it is `None`.

For GPT, the *word embedding layer* shares its weights with the *output head*. We provide `PipelineSharedModuleWrapper` to share parameters among pipeline stages. It takes a `list` of `int` as an argument, which means those ranks share the parameters. You can use `register_module()` or `register_parameter()` to register a module or a parameter as the shared module or parameter. If you have multiple sets of shared modules / parameters, you should have multiple `PipelineSharedModuleWrapper` instances. If the parameter is shared within **one** stage, you should not use `PipelineSharedModuleWrapper`; just use the same module / parameter instance. In this example, the *word embedding layer* is at the first stage, and the *output head* is at the last stage. Thus, they are shared among ranks `[0, pipeline_size - 1]`.

For the first stage, it maintains the embedding layer and some transformer blocks. For the last stage, it maintains some transformer blocks and the output head layer. The other stages just maintain some transformer blocks. `partition_uniform(num_layers, pipeline_size, num_chunks)` returns the parts for all ranks, where each part is a `tuple` of `(start, end)` (excluding `end`). `start == 0` means that it's the first stage, and `end == num_layers` means it's the last stage.
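
For example, with 48 layers, 4 pipeline stages, and one chunk per stage, `partition_uniform` would be expected to return one part per rank along these lines (a hypothetical illustration consistent with the description above, not output copied from the library):

```python
# hypothetical result of partition_uniform(48, 4, 1): one list of (start, end) ranges per pipeline rank
parts = [
    [(0, 12)],   # rank 0: first stage, also holds the word embedding
    [(12, 24)],  # rank 1
    [(24, 36)],  # rank 2
    [(36, 48)],  # rank 3: last stage, also holds the output head
]
```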

## Define Plugin
Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously.
```python
class PipelineGPTHybrid(nn.Module):
    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: Callable = F.gelu,
                 mlp_ratio: int = 4,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 first: bool = False,
                 last: bool = False):
        super().__init__()
        self.embedding = None
        self.norm = None
        self.head = None
        if first:
            self.embedding = col_gpt.GPTEmbedding(
                hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype)
        self.blocks = nn.ModuleList([
            col_gpt.GPTBlock(hidden_size, num_attention_heads, mlp_ratio=mlp_ratio, attention_dropout=attn_drop_rate,
                             dropout=drop_rate, dtype=dtype, checkpoint=checkpoint, activation=act_func)
            for _ in range(num_layers)
        ])
        if last:
            self.norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            self.head = col_gpt.GPTLMHead(vocab_size=vocab_size,
                                          dim=hidden_size,
                                          dtype=dtype,
                                          bias=False)

    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
        if self.embedding is not None:
            hidden_states = self.embedding(input_ids=input_ids)
        batch_size = hidden_states.shape[0]
        attention_mask = attention_mask.view(batch_size, -1)
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)  # fp16 compatibility
        attention_mask = (1.0 - attention_mask) * -10000.0
        for block in self.blocks:
            hidden_states, attention_mask = block(hidden_states, attention_mask)
        if self.norm is not None:
            hidden_states = self.head(self.norm(hidden_states))
        return hidden_states


def build_gpt_pipeline(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    logger = get_dist_logger()
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    rank = gpc.get_global_rank()
    wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
    models = []
    for start, end in parts:
        kwargs['num_layers'] = end - start
        kwargs['first'] = start == 0
        kwargs['last'] = end == num_layers
        logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
        chunk = PipelineGPTHybrid(**kwargs).to(device)
        if start == 0:
            wrapper.register_module(chunk.embedding.word_embeddings)
        elif end == num_layers:
            wrapper.register_module(chunk.head)
        models.append(chunk)
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)
    return model


def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float):
    cfg = dict(hidden_size=1600, num_attention_heads=32, checkpoint=checkpoint, dtype=dtype)
    return build_gpt_pipeline(48, num_chunks, **cfg)


def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float):
    cfg = dict(hidden_size=12288, num_attention_heads=96,
               checkpoint=checkpoint, max_position_embeddings=2048, dtype=dtype)
    return build_gpt_pipeline(96, num_chunks, **cfg)


plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=None,
    microbatch_size=1,
    enable_all_optimization=True,
    zero_stage=1,
    precision="fp16",
    initial_scale=1,
)
```
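
With `tp_size=1` and `pp_size=2`, the remaining GPUs form the data-parallel group over which ZeRO-1 shards the optimizer states. A quick sanity check of the group sizes (a sketch assuming a 4-GPU run, which is an assumption of this example):

```python
# assumption: 4 GPUs in total for this illustration
world_size = 4
tp_size, pp_size = 1, 2
dp_size = world_size // (tp_size * pp_size)  # ranks that share ZeRO-1 optimizer shards
assert dp_size == 2
```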

## Define GPT-2's Training Components

## Process the dataset

We provide a small GPT web-text dataset here. The original format is loose JSON, and we will save the processed dataset.

Before using hybrid parallelism, you need to define the components used for training.

Define the hyperparameters:
```python
class WebtextDataset(Dataset):
    def __init__(self, path, seq_len=1024) -> None:
        super().__init__()
        root = os.path.dirname(path)
        encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
        if os.path.isfile(encoded_data_cache_path):
            seq_len_, data, attention_mask = torch.load(
                encoded_data_cache_path)
            if seq_len_ == seq_len:
                self.data = data
                self.attention_mask = attention_mask
                return
        raw_data = []
        with open(path) as f:
            for line in f.readlines():
                raw_data.append(json.loads(line)['text'])
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.unk_token
        encoded_data = tokenizer(
            raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
        self.data = encoded_data['input_ids']
        self.attention_mask = encoded_data['attention_mask']
        torch.save((seq_len, self.data, self.attention_mask),
                   encoded_data_cache_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return {
            'input_ids': self.data[index],
            'attention_mask': self.attention_mask[index]
        }, self.data[index]


NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1
```

## Training GPT using hybrid parallelism

In the previous tutorial, we explained the meanings of some pipeline arguments. In this case, we can determine the shape of each output tensor that is exchanged among pipeline stages. For GPT, the shape is `(MICRO BATCH SIZE, SEQUENCE LEN, HIDDEN SIZE)`. By setting this, we can avoid exchanging the tensor shape of each stage. When you are not sure of the tensor shape, you can just leave it as `None`, and the shape will be inferred automatically. Make sure that the `dtype` of your model is correct: when you use `fp16`, the `dtype` of your model must be `torch.half`; otherwise, it must be `torch.float`. For pipeline parallelism, only `AMP_TYPE.NAIVE` is supported.

You can easily use tensor parallelism by setting `parallel` in `CONFIG`. The data parallelism size is set automatically based on the number of GPUs.

We create a distributed environment.
```python
NUM_EPOCHS = 60
SEQ_LEN = 1024
BATCH_SIZE = 192
NUM_CHUNKS = None
TENSOR_SHAPE = (1, 1024, 1600)
# only pipeline parallel
# CONFIG = dict(parallel=dict(pipeline=2), fp16=dict(mode=AMP_TYPE.NAIVE))
# pipeline + 1D model parallel
CONFIG = dict(NUM_MICRO_BATCHES=192, parallel=dict(pipeline=2, tensor=dict(mode='1d', size=2)), fp16=dict(mode=AMP_TYPE.NAIVE))


def train():
    disable_existing_loggers()
    parser = colossalai.get_default_parser()
    args = parser.parse_args()
    colossalai.launch_from_torch(config=CONFIG, backend=args.backend)
    logger = get_dist_logger()

    train_ds = WebtextDataset(os.environ['DATA'], seq_len=SEQ_LEN)
    train_dataloader = utils.get_dataloader(train_ds,
                                            seed=42,
                                            batch_size=BATCH_SIZE,
                                            pin_memory=True,
                                            shuffle=True,
                                            drop_last=True)

    use_interleaved = NUM_CHUNKS is not None
    num_chunks = 1 if not use_interleaved else NUM_CHUNKS
    model = GPT2_exlarge_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half)
    # model = GPT3_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half)
    if use_interleaved and not isinstance(model, nn.ModuleList):
        model = nn.ModuleList([model])

    criterion = GPTLMLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2,)

    engine, train_dataloader, _, _ = colossalai.initialize(model,
                                                           optimizer,
                                                           criterion,
                                                           train_dataloader=train_dataloader)
    global_batch_size = BATCH_SIZE * \
        gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
    logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(
        engine=engine,
        logger=logger,
        timer=timer
    )

    hook_list = [
        hooks.LossHook(),
        hooks.LogMetricByEpochHook(logger),
        hooks.ThroughputHook(),
        hooks.LogMetricByStepHook(),
    ]

    trainer.fit(
        train_dataloader=train_dataloader,
        epochs=NUM_EPOCHS,
        test_interval=1,
        hooks=hook_list,
        display_progress=True,
        return_output_label=False,
    )


# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=42)
coordinator = DistCoordinator()
```

<!-- doc-test-command: echo -->

Prepare the dataset. You can use `plugin.prepare_dataloader` to generate a dataloader, or customize your own dataloader.

```python
def tokenize_batch(batch, tokenizer: Optional[AutoTokenizer] = None, max_length: int = 2048):
    texts = [sample["sentence1"] + sample["sentence2"] for sample in batch]
    data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
    data = {k: v.cuda() for k, v in data.items()}
    data["labels"] = data["input_ids"].clone()
    return data


tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = datasets.load_dataset("glue", "mrpc")
train_dataloader = plugin.prepare_dataloader(
    dataset["train"],
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=512),
)
```

Prepare the GPT-2 model.

```python
cfg = AutoConfig.from_pretrained("gpt2", num_labels=2)
model = GPT2ForSequenceClassification.from_pretrained("gpt2", config=cfg).cuda()
```
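
Note that the pretrained GPT-2 tokenizer has no pad token, and `GPT2ForSequenceClassification` needs a `pad_token_id` to locate the last non-padded token in each sequence. A common workaround, mirroring what the web-text dataset code above does (the exact choice of token is an assumption here), is:

```python
# reuse an existing special token as the padding token (assumption; any consistent choice works)
tokenizer.pad_token = tokenizer.unk_token
model.config.pad_token_id = tokenizer.pad_token_id
```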

Prepare the optimizer.

```python
lr = LEARNING_RATE * coordinator.world_size
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": WEIGHT_DECAY,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
```

Prepare the lr_scheduler and criterion. Note that when hybrid parallelism with pipeline parallelism is used, a criterion function must also be defined. This function takes the input and output of the model's forward pass as parameters and returns the loss.

```python
# lr scheduler
total_steps = len(train_dataloader) * NUM_EPOCHS
num_warmup_steps = int(WARMUP_FRACTION * total_steps)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps,
)


def _criterion(outputs, inputs):
    return outputs.loss
```

## Boost the GPT-2 Model

Define a booster with `HybridParallelPlugin`. Based on the configured plugin parameters, the booster will inject one or more parallel strategies into the model. In this example, pipeline parallelism, ZeRO-1, and fp16 mixed-precision training are used.

```python
booster = Booster(plugin=plugin)
```

Boost these components with the defined booster.

```python
model, optimizer, _criterion, _, lr_scheduler = booster.boost(
    model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler
)
```


## Training GPT-2 using hybrid parallelism

In the previous tutorial, we explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training.
Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training.
```python
def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    _criterion: Callable,
    lr_scheduler: LRScheduler,
    train_dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
    total_step = len(train_dataloader)

    model.train()
    optimizer.zero_grad()
    train_dataloader_iter = iter(train_dataloader)
    with tqdm(
        range(total_step),
        desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]",
        disable=not print_flag,
    ) as pbar:
        # Forward pass
        for _ in pbar:
            if use_pipeline:
                outputs = booster.execute_pipeline(
                    train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
                )
                # Backward and optimize
                if is_pp_last_stage:
                    loss = outputs["loss"]
                    pbar.set_postfix({"loss": loss.item()})
            else:
                data = next(train_dataloader_iter)
                data = move_to_cuda(data)
                outputs = model(**data)
                loss = _criterion(outputs, None)
                # Backward
                booster.backward(loss, optimizer)
                pbar.set_postfix({"loss": loss.item()})

            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()
```
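
The loop above uses a small `move_to_cuda` helper that is not defined elsewhere in this tutorial; a minimal sketch is:

```python
def move_to_cuda(batch):
    # Move every tensor in the batch dict onto the current GPU.
    return {k: v.cuda() for k, v in batch.items()}
```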

Train the GPT-2 model.
```python
for epoch in range(NUM_EPOCHS):
    train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->
@@ -1,248 +0,0 @@

# Train ViT Using Pipeline Parallelism

Author: Hongxin Liu, Yongbin Li

**Example Code**
- [ColossalAI-Examples Pipeline Parallel ViT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer/pipeline_parallel)

**Related Paper**
- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)

## Introduction

In this tutorial, you will learn how to train Vision Transformer for image classification from scratch, using pipeline parallelism.
Pipeline parallelism is a kind of model parallelism, which is useful when your GPU memory cannot fit your model.
By using it, we split the original model into multiple stages, and each stage maintains a part of the original model.
We assume that your GPU memory cannot fit ViT-L/16, while your CPU memory can fit this model.

## Table of contents

In this tutorial we will cover:

1. The definition of the ViT model, based on [TIMM](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)
2. Processing the dataset
3. Training ViT using pipeline parallelism

## Import libraries

```python
import os
from collections import OrderedDict
from functools import partial

import colossalai
import colossalai.nn as col_nn
import torch
import torch.nn as nn
from colossalai.legacy.builder import build_pipeline_model
from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                               PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from timm.models import vision_transformer as vit
from torchvision import transforms
from torchvision.datasets import CIFAR10
```


## Define Vision Transformer model

Generally, we provide 3 ways to build a pipelined model:

1. `colossalai.legacy.builder.build_pipeline_model_from_cfg`
2. `colossalai.legacy.builder.build_pipeline_model`
3. Split the model by stages yourself

When your CPU memory can fit the model, you can use the first two methods; otherwise, you must split the model by yourself. The first two methods first build the whole model on CPU, then split the model, and finally you just move the corresponding part of the model to GPU.

`colossalai.legacy.builder.build_pipeline_model_from_cfg()` receives a config file of the model, and it can split the model uniformly (by layer) or in a balanced way (by parameter size).

If you are familiar with `PyTorch`, you can use `colossalai.legacy.builder.build_pipeline_model()`, which receives a `torch.nn.Sequential` model and splits it by layer uniformly.

In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.legacy.builder.build_pipeline_model()` to build the pipelined model, as sketched below.
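
For reference, once the model has been rewritten as `nn.Sequential`, the call has this shape (the same call is used in the training script later in this tutorial):

```python
# build the sequential ViT and split it uniformly across the pipeline stages
model = vit_large_patch16_224()
model = build_pipeline_model(model, num_chunks=1, verbose=True)
```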

When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of the pipeline, the first positional argument of `forward()` is the data tensor loaded from the data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`.

When the data is a `dict` of `Tensor`, you can use named keyword arguments in `forward()` of your model to get the data `dict`.

```python
class ViTEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, embed_layer=vit.PatchEmbed, drop_rate=0., distilled=False):
        super().__init__()
        self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.init_weights()

    def forward(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        return x

    def init_weights(self):
        vit.trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            vit.trunc_normal_(self.dist_token, std=.02)
        vit.trunc_normal_(self.cls_token, std=.02)
        self.apply(vit._init_vit_weights)


class ViTHead(nn.Module):
    def __init__(self, embed_dim=768, num_classes=1000, norm_layer=None, distilled=False, representation_size=None):
        super().__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(embed_dim)
        self.num_classes = num_classes
        self.distilled = distilled
        self.num_features = embed_dim
        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()
        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.init_weights()

    def forward(self, x):
        x = self.norm(x)
        if self.distilled:
            x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])
            if self.training and not torch.jit.is_scripting():
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.pre_logits(x[:, 0])
            x = self.head(x)
            return x

    def init_weights(self):
        self.apply(vit._init_vit_weights)


def sequential_vit(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                   num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                   drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=vit.PatchEmbed, norm_layer=None,
                   act_layer=None):
    norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
    act_layer = act_layer or nn.GELU
    embedding = ViTEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                             embed_dim=embed_dim, embed_layer=embed_layer, drop_rate=drop_rate, distilled=distilled)
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
    blocks = [vit.Block(
        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
        attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
        for i in range(depth)]
    for block in blocks:
        block.apply(vit._init_vit_weights)
    head = ViTHead(embed_dim=embed_dim, num_classes=num_classes, norm_layer=norm_layer,
                   distilled=distilled, representation_size=representation_size)
    return nn.Sequential(embedding, *blocks, head)


def vit_large_patch16_224(**kwargs):
    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
    return sequential_vit(**model_kwargs)
```

## Process the dataset

Generally, we train ViT on large datasets like ImageNet. For simplicity, we just use CIFAR-10 here, since this tutorial is only about pipeline training.

```python
def build_cifar(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(224, pad_if_needed=True),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
    test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
    return train_dataloader, test_dataloader
```

## Training ViT using pipeline

You can set the size of pipeline parallelism and the number of microbatches in the config. `NUM_CHUNKS` is useful when using an interleaved pipeline (for more details see [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)). The original batch will be split into `num_microbatches` micro batches, and each stage will load one micro batch at a time. Then we will generate an appropriate schedule for you to execute the pipeline training. If you don't need the output and label of the model, you can set `return_output_label` to `False` when calling `trainer.fit()`, which can further reduce GPU memory usage.

You should `export DATA=/path/to/cifar`.

```python
BATCH_SIZE = 16
NUM_EPOCHS = 60
NUM_CHUNKS = 1
CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))


def train():
    disable_existing_loggers()
    parser = colossalai.get_default_parser()
    args = parser.parse_args()
    colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
    logger = get_dist_logger()

    # build model
    model = vit_large_patch16_224()
    model = build_pipeline_model(model, num_chunks=NUM_CHUNKS, verbose=True)

    # build criterion
    criterion = nn.CrossEntropyLoss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)

    # build dataloader
    train_dataloader, test_dataloader = build_cifar(BATCH_SIZE)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, optimizer, criterion,
                                                                         train_dataloader, test_dataloader)
    timer = MultiTimer()

    trainer = Trainer(engine=engine, timer=timer, logger=logger)

    hook_list = [
        hooks.LossHook(),
        hooks.AccuracyHook(col_nn.metric.Accuracy()),
        hooks.LogMetricByEpochHook(logger),
    ]

    trainer.fit(train_dataloader=train_dataloader,
                epochs=NUM_EPOCHS,
                test_dataloader=test_dataloader,
                test_interval=1,
                hooks=hook_list,
                display_progress=True)
```

<!-- doc-test-command: echo -->
@@ -1,10 +1,14 @@
# Step By Step: Accelerate ViT Training With Colossal-AI (From Data Parallel to Hybrid Parallel)

Author: Yuxuan Lou
Author: Yuxuan Lou, Mingyan Jiang

**Prerequisite:**
- [parallelism plugin](../basics/booster_plugins.md)
- [booster API](../basics/booster_api.md)

**Example Code**

- [Colossal-AI Examples ViT on Cifar10](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer)
- [Colossal-AI Examples ViT on `beans`](https://github.com/hpcaitech/ColossalAI/blob/main/examples/images/vit/vit_train_demo.py)

**Related Paper**
- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf)

@@ -13,14 +17,14 @@ Author: Yuxuan Lou

## Introduction

In this example for the ViT model, Colossal-AI provides three different parallelism techniques to accelerate model training: data parallelism, pipeline parallelism, and tensor parallelism.
We will show you how to train ViT on the CIFAR-10 dataset with these parallelism techniques. To run this example, you will need 2-4 GPUs.
We will show you how to train ViT on the `beans` dataset with these parallelism techniques. To run this example, you will need 2-4 GPUs.

## Table of Contents
1. Colossal-AI installation
2. Steps to train ViT with data parallelism
3. Steps to train ViT with pipeline parallelism
4. Steps to train ViT with tensor parallelism or hybrid parallelism
2. Define the ViT model and related training components.
3. Boost the VIT Model with [`HybridParallelPlugin`](../basics/booster_plugins.md)
4. Train the VIT model using data parallelism, pipeline parallelism, and tensor parallelism.

## Colossal-AI Installation
You can install the Colossal-AI package and its dependencies with PyPI.

@@ -29,619 +33,250 @@ pip install colossalai
```


## Data Parallelism
Data parallelism is one basic way to accelerate model training. You can apply data parallelism to training in only two steps:
1. Define a configuration file
2. Change a few lines of code in the train script

### Define your configuration file (`data_parallel/config.py`)
To use Colossal-AI, the first step is to define a configuration file. There are two kinds of variables here:

1. **Colossal-AI feature specification**

There is an array of features Colossal-AI provides to speed up training (parallel mode, mixed precision, ZeRO, etc.). Each feature is defined by a corresponding field in the config file. If we apply data parallelism only, we do not need to specify the parallel mode. In this example, we use the mixed precision training natively provided by PyTorch by defining the mixed precision configuration `fp16 = dict(mode=AMP_TYPE.TORCH)`.

2. **Global hyper-parameters**

Global hyper-parameters include model-specific hyper-parameters, training settings, dataset information, etc.

## Import libraries
```python
from colossalai.amp import AMP_TYPE

# ViT Base
BATCH_SIZE = 256
DROP_RATE = 0.1
NUM_EPOCHS = 300

# mixed precision
fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

gradient_accumulation = 16
clip_grad_norm = 1.0

dali = dict(
    gpu_aug=True,
    mixup_alpha=0.2
)
```

### Modify train script (`/data_parallel/train_with_cifar10.py`)

#### Import modules
- Colossal-AI related modules
```python
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.legacy.nn.metric import Accuracy
from colossalai.legacy.trainer import Trainer, hooks
```

- Other modules
```python
import os
from typing import Any, Callable, Iterator

import torch
from timm.models import vit_base_patch16_224
import torch.distributed as dist
import torch.nn as nn
import transformers
from data import BeansDataset, beans_collator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor

from torchvision import transforms
from torchvision.datasets import CIFAR10
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
```

#### Launch Colossal-AI

In the train script, you need to initialize the distributed environment for Colossal-AI after your config file is prepared. We call this process `launch`. In Colossal-AI, we provide several launch methods to initialize the distributed backend. In most cases, you can use `colossalai.launch` and `colossalai.get_default_parser` to pass the parameters via the command line. Besides, Colossal-AI can utilize the existing launch tool provided by PyTorch, which many users are familiar with, via `colossalai.launch_from_torch`. For more details, you can view the related [documents](https://www.colossalai.org/docs/basics/launch_colossalai).
## Define the Vision Transformer (VIT) model.
Define hyperparameters.
```python
SEED = 42
MODEL_PATH = "google/vit-base-patch16-224"
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 0.0
NUM_EPOCH = 3
WARMUP_RATIO = 0.3
TP_SIZE = 2
PP_SIZE = 2
```
Create a distributed environment.
```python
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=SEED)
coordinator = DistCoordinator()
world_size = coordinator.world_size
```
Before training, you can define the relevant components of the model training process as usual, such as the model, data loaders, optimizer, and so on. It's important to note that when using pipeline parallelism, you also need to define a criterion function. This function takes the input and output of the model forward pass as inputs and returns the loss.
Prepare the dataset. `BeansDataset` is defined in [data.py](https://github.com/hpcaitech/ColossalAI/blob/main/examples/images/vit/data.py).

```python
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()
colossalai.launch_from_torch(config=args.config)

disable_existing_loggers()
logger = get_dist_logger()
image_processor = ViTImageProcessor.from_pretrained(MODEL_PATH)
train_dataset = BeansDataset(image_processor, TP_SIZE, split="train")
eval_dataset = BeansDataset(image_processor, TP_SIZE, split="validation")
num_labels = train_dataset.num_labels
```

After initialization, you can access the variables in the config file by using `colossalai.core.global_context`.

Define the VIT model:
```python
# access parameters
print(gpc.config.BATCH_SIZE)
config = ViTConfig.from_pretrained(MODEL_PATH)
config.num_labels = num_labels
config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
model = ViTForImageClassification.from_pretrained(
    MODEL_PATH, config=config, ignore_mismatched_sizes=True
)
```

#### Build Model

If only data parallelism is required, you do not need to make any changes to your model. Here, we use `vit_base_patch16_224` from `timm`.
Define the optimizer:
```python
# build model
model = vit_base_patch16_224(drop_rate=0.1, num_classes=gpc.config.NUM_CLASSES)
optimizer = HybridAdam(model.parameters(), lr=(LEARNING_RATE * world_size), weight_decay=WEIGHT_DECAY)
```

#### Build CIFAR-10 Dataloader
`colossalai.utils.get_dataloader` can help you build the dataloader easily.

Define the learning rate scheduler:
```python
def build_cifar(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(224, pad_if_needed=True),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
    test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
    return train_dataloader, test_dataloader


# build dataloader
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE)
```

#### Define optimizer, loss function and LR scheduler

Colossal-AI provides its own optimizer, loss function and LR scheduler. Those from PyTorch are also compatible.

```python
# build optimizer
optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1)

# build loss
criterion = torch.nn.CrossEntropyLoss()

# lr_scheduler
lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS)
```

#### Start Colossal-AI engine

Engine is essentially a wrapper class for model, optimizer and loss function. When we call `colossalai.initialize`, an engine object will be returned, and it has already been equipped with functionalities such as gradient clipping, gradient accumulation and ZeRO optimizer, as specified in your configuration file. Further model training is based on the Colossal-AI engine.

```python
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
    model, optimizer, criterion, train_dataloader, test_dataloader
)
total_steps = len(train_dataloader) * NUM_EPOCH
num_warmup_steps = int(WARMUP_RATIO * total_steps)
lr_scheduler = CosineAnnealingWarmupLR(
    optimizer=optimizer, total_steps=(len(train_dataloader) * NUM_EPOCH), warmup_steps=num_warmup_steps
)
```

#### Train: Trainer API
Trainer is a higher-level wrapper for the user to execute training with fewer lines of code. It is easy to create a trainer object by passing the engine object.

Besides, in the trainer, the user can customize some hooks and attach these hooks to the trainer object. A hook object will execute life-cycle methods periodically based on the training scheme. For example, the `LRSchedulerHook` will execute `lr_scheduler.step()` to update the learning rate of the model during either the `after_train_iter` or `after_train_epoch` stage.

Define the criterion function:
```python
# build trainer
trainer = Trainer(engine=engine, logger=logger)

# build hooks
hook_list = [
    hooks.LossHook(),
    hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
    hooks.LogMetricByEpochHook(logger),
    hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),

    # comment if you do not need to use the hooks below
    hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'),
    hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
]


def _criterion(outputs, inputs):
    return outputs.loss
```

Use `trainer.fit` for training:

## Boost the VIT Model
We begin by using ColossalAI's hybrid parallelism strategy to enhance the model. First, let's define a `HybridParallelPlugin` object. `HybridParallelPlugin` encapsulates various parallelism strategies in ColossalAI. Afterward, we use the `HybridParallelPlugin` object to initialize the booster and boost the VIT model.
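
A minimal sketch of that flow, mirroring the GPT-2 tutorial above (the concrete plugin arguments for the different parallelism styles are discussed in the following subsections, so treat these as placeholders):

```python
# placeholder plugin arguments; see the AMP and tensor-parallelism subsections below
plugin = HybridParallelPlugin(tp_size=TP_SIZE, pp_size=PP_SIZE, precision="fp16", initial_scale=1)
booster = Booster(plugin=plugin)
model, optimizer, _criterion, _, lr_scheduler = booster.boost(
    model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler
)
```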
### Training with AMP
In the `HybridParallelPlugin`, you can determine the training precision by setting the `precision` parameter, which supports three types: 'fp16', 'bf16', and 'fp32'. 'fp16' and 'bf16' are half-precision types. Half-precision is used in two scenarios in the `HybridParallelPlugin`:
1. When using ZeRO data parallelism, you should set it to half-precision.
2. When specifying the use of AMP (Automatic Mixed Precision) for training.
You can set related parameters when using half-precision.
`initial_scale` (float, optional): Initial loss scaling factor for AMP. Default value is 2**16.
`min_scale` (float, optional): Minimum loss scaling factor for AMP. Default value is 1.
`growth_factor` (float, optional): Multiplicative factor used to increase the loss scaling factor when using AMP. Default value is 2.
`backoff_factor` (float, optional): Multiplicative factor used to decrease the loss scaling factor when using AMP. Default value is 0.5.
`growth_interval` (integer, optional): Number of steps without overflow after which the loss scaling factor is increased when using AMP. Default value is 1000.
`hysteresis` (integer, optional): Number of overflows required before reducing the loss scaling factor when using AMP. Default value is 2.
`max_scale` (float, optional): Maximum loss scaling factor for AMP. Default value is 2**32.
Plugin example when using amp:
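
A minimal plugin configured for fp16 AMP training, using the same settings that appear later in this tutorial, looks like:

```python
# enable fp16 AMP with an initial loss scale of 1
plugin = HybridParallelPlugin(
    precision="fp16",
    initial_scale=1,
)
```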
```python
# start training
trainer.fit(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=gpc.config.NUM_EPOCHS,
    hooks=hook_list,
    display_progress=True,
    test_interval=1
)
```

### Start training
`DATA` is the filepath where the CIFAR-10 dataset will be automatically downloaded and stored.

`<NUM_GPUs>` is the number of GPUs you want to use to train ViT on CIFAR-10 with data parallelism.

```bash
export DATA=<path_to_data>
# If your torch >= 1.10.0
torchrun --standalone --nproc_per_node <NUM_GPUs> train_dp.py --config ./configs/config_data_parallel.py
# If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_dp.py --config ./configs/config_data_parallel.py
# Otherwise
# python -m torch.distributed.launch --nproc_per_node <NUM_GPUs> --master_addr <node_name> --master_port 29500 train_dp.py --config ./configs/config.py
```



## Pipeline Parallelism
Aside from data parallelism, Colossal-AI also supports pipeline parallelism. Specifically, Colossal-AI uses the 1F1B pipeline introduced by NVIDIA. For more details, you can view the related [documents](https://www.colossalai.org/tutorials/features/pipeline_parallel).

### Define your configuration file (`hybrid_parallel/configs/vit_pipeline.py`)
To apply pipeline parallelism on top of the data parallel setup, you only need to add a **parallel dict**
```python
from colossalai.amp import AMP_TYPE

parallel = dict(
    pipeline=2
)
# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE)

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
```

Other configs:
```python
# hyper parameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 768
DEPTH = 12
NUM_HEADS = 12
MLP_RATIO = 4
NUM_CLASSES = 10
CHECKPOINT = True
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1    # add 1 for cls token
```

### Build pipeline model (`/hybrid_parallel/model/vit.py`)
Colossal-AI provides two methods to build a pipeline model from an existing model.
- `colossalai.legacy.builder.build_pipeline_model_from_cfg`
- `colossalai.legacy.builder.build_pipeline_model`

Besides, you can also build a pipeline model from scratch with Colossal-AI.
```python
import math
from typing import Callable

import inspect
import torch
from colossalai import nn as col_nn
from colossalai.legacy.registry import LAYERS, MODELS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.legacy.builder.pipeline import partition_uniform
from torch import dtype, nn
from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead


@MODELS.register_module
class PipelineVisionTransformer(nn.Module):
    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 num_heads: int = 12,
                 dim: int = 768,
                 mlp_ratio: int = 4,
                 attention_dropout: float = 0.,
                 dropout: float = 0.1,
                 drop_path: float = 0.,
                 layernorm_epsilon: float = 1e-6,
                 activation: Callable = nn.functional.gelu,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch',
                 first_stage=True,
                 last_stage=True,
                 start_idx=None,
                 end_idx=None,):
        super().__init__()

        layers = []

        if first_stage:
            embed = ViTEmbedding(img_size=img_size,
                                 patch_size=patch_size,
                                 in_chans=in_chans,
                                 embedding_dim=dim,
                                 dropout=dropout,
                                 dtype=dtype,
                                 init_method=init_method)
            layers.append(embed)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]

        if start_idx is None and end_idx is None:
            start_idx = 0
            end_idx = depth

        blocks = [
            ViTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attention_dropout=attention_dropout,
                dropout=dropout,
                drop_path=dpr[i],
                activation=activation,
                dtype=dtype,
                bias=bias,
                checkpoint=checkpoint,
                init_method=init_method,
            ) for i in range(start_idx, end_idx)
        ]
        layers.extend(blocks)

        if last_stage:
            norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
            head = ViTHead(dim=dim,
                           num_classes=num_classes,
                           representation_size=representation_size,
                           dtype=dtype,
                           bias=bias,
                           init_method=init_method)
            layers.extend([norm, head])

        self.layers = nn.Sequential(
            *layers
        )

plugin = HybridParallelPlugin(
    precision="fp16",
    initial_scale=1,
)

    def forward(self, x):
        x = self.layers(x)
        return x


def _filter_kwargs(func, kwargs):
    sig = inspect.signature(func)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}


def _build_pipeline_vit(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    logger = get_dist_logger()
    if gpc.is_initialized(ParallelMode.PIPELINE):
        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    else:
        pipeline_size = 1
        pipeline_rank = 0
    rank = gpc.get_global_rank()
    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
    models = []

    for start, end in parts:
        kwargs['first_stage'] = start == 0
        kwargs['last_stage'] = end == num_layers
        kwargs['start_idx'] = start
        kwargs['end_idx'] = end
        logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
        chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
        models.append(chunk)
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)
    return model


def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    return _build_pipeline_vit(PipelineVisionTransformer, num_layers, num_chunks, device, **kwargs)
```

### Modify train script (`/hybrid_parallel/train_with_cifar10.py`)

#### Import modules
### Tensor parallelism
`HybridParallelPlugin` achieves tensor parallelism through Shardformer. In this plugin, you can set `tp_size` to determine the size of the tensor-parallel groups. Additionally, several parameters can be configured to optimize tensor parallelism when using this plugin:
`enable_all_optimization` (boolean, optional): Whether to enable all optimization methods supported by Shardformer. Currently, the optimization methods include fused normalization, flash attention, and JIT. Default is False.
`enable_fused_normalization` (boolean, optional): Whether to enable fused normalization in Shardformer. Default is False.
`enable_flash_attention` (boolean, optional): Whether to enable flash attention in Shardformer. Default is False.
`enable_jit_fused` (boolean, optional): Whether to enable JIT (Just-In-Time) fusion in Shardformer. Default is False.
`enable_sequence_parallelism` (boolean): Whether to enable sequence parallelism in Shardformer. Default is False.
`enable_sequence_overlap` (boolean): Whether to enable sequence overlap in Shardformer. Default is False.
Example of a tensor parallelism plugin:
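
A sketch of such a plugin, assuming the `TP_SIZE` hyperparameter defined earlier and the fp16 settings shown above:

```python
plugin = HybridParallelPlugin(
    tp_size=TP_SIZE,
    enable_all_optimization=True,
    precision="fp16",
    initial_scale=1,
)
```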
```python
from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                               PipelineSchedule)
from colossalai.utils import MultiTimer
import os

import colossalai

import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader
from model.vit import build_pipeline_vit
from model_zoo.vit.vit import _create_vit_model
from tqdm import tqdm

from torchvision import transforms
from torchvision.datasets import CIFAR10
```
|
||||
|
||||
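A minimal sketch, assuming four GPUs form one tensor parallel group and all of Shardformer's optimizations are wanted:

```python
plugin = HybridParallelPlugin(
    tp_size=4,
    enable_all_optimization=True
)
```
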
#### Launch Colossal-AI

`colossalai.utils.is_using_pp` can help check whether pipeline parallelism is required in the config file.

```python
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()

# launch from torch
colossalai.launch_from_torch(config=args.config)

# get logger
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])

if hasattr(gpc.config, 'LOG_PATH'):
    if gpc.get_global_rank() == 0:
        log_path = gpc.config.LOG_PATH
        if not os.path.exists(log_path):
            os.mkdir(log_path)
        logger.log_to_file(log_path)

use_pipeline = is_using_pp()
```

#### Define model

```python
# create model
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
                    patch_size=gpc.config.PATCH_SIZE,
                    dim=gpc.config.HIDDEN_SIZE,
                    depth=gpc.config.DEPTH,
                    num_heads=gpc.config.NUM_HEADS,
                    mlp_ratio=gpc.config.MLP_RATIO,
                    num_classes=gpc.config.NUM_CLASSES,
                    init_method='jax',
                    checkpoint=gpc.config.CHECKPOINT)

if use_pipeline:
    model = build_pipeline_vit(num_layers=model_kwargs['depth'], num_chunks=1, **model_kwargs)
else:
    model = _create_vit_model(**model_kwargs)
```

#### Count number of parameters

You can count model parameters on different pipeline stages easily.

```python
# count number of parameters
total_numel = 0
for p in model.parameters():
    total_numel += p.numel()
if not gpc.is_initialized(ParallelMode.PIPELINE):
    pipeline_stage = 0
else:
    pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
```

#### Build dataloader, optimizer, etc.

```python
def build_cifar(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(224, pad_if_needed=True),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
    test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
    return train_dataloader, test_dataloader


# create dataloaders
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE)

# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)

# create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                       total_steps=gpc.config.NUM_EPOCHS,
                                       warmup_steps=gpc.config.WARMUP_EPOCHS)
```

#### Start Colossal-AI engine

```python
# initialize
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
                                                                     optimizer=optimizer,
                                                                     criterion=criterion,
                                                                     train_dataloader=train_dataloader,
                                                                     test_dataloader=test_dataloader)

logger.info("Engine is built", ranks=[0])
```

#### Train: based on engine

In the data parallelism example, we show how to train a model with the Trainer API. We can also train a model directly with the engine, which lets you customize the training loop with more features.

```python
data_iter = iter(train_dataloader)

for epoch in range(gpc.config.NUM_EPOCHS):
    # training
    engine.train()

    if gpc.get_global_rank() == 0:
        description = 'Epoch {} / {}'.format(
            epoch,
            gpc.config.NUM_EPOCHS
        )
        progress = tqdm(range(len(train_dataloader)), desc=description)
    else:
        progress = range(len(train_dataloader))

    for _ in progress:
        engine.zero_grad()
        engine.execute_schedule(data_iter, return_output_label=False)
        engine.step()
        lr_scheduler.step()
```

### Pipeline Parallelism

`HybridParallelPlugin` determines the size of the pipeline parallel groups through `pp_size`. `num_microbatches` specifies the number of microbatches into which each batch is split during pipeline parallelism, and `microbatch_size` sets the size of those microbatches. The plugin prioritizes `num_microbatches` when determining the microbatch configuration.

Example of a plugin for pipeline parallelism:

```python
plugin = HybridParallelPlugin(
    pp_size=4,
    num_microbatches=None,
    microbatch_size=1
)
```

### Data Parallelism

Data parallelism in `HybridParallelPlugin` covers both the ZeRO series and Torch DDP. When `zero_stage` is 0 (the default), Torch DDP is used. Please note that Torch DDP conflicts with pipeline parallelism and cannot be combined with it. When `zero_stage` is 1, the ZeRO-1 strategy is used; when `zero_stage` is 2, the ZeRO-2 strategy is used. The ZeRO-2 strategy also cannot be combined with pipeline parallelism. If you want ZeRO-3, please use the [`GeminiPlugin`](../basics/booster_plugins.md).

When using data parallelism with the ZeRO series, please set the training precision to half precision. If you have not enabled ZeRO or pipeline parallelism and `world_size // (tp_size * pp_size)` is greater than 1, `HybridParallelPlugin` automatically enables the Torch DDP parallel strategy for you.

Here are some related parameters for configuring Torch DDP:

`broadcast_buffers` (boolean, optional): Whether to broadcast buffers at the beginning of training when using DDP. Default is True.

`ddp_bucket_cap_mb` (integer, optional): Size of the bucket (in MB) when using DDP. Default is 25.

`find_unused_parameters` (boolean, optional): Whether to search for unused parameters when using DDP. Default is False.

`check_reduction` (boolean, optional): Whether to check the reduction operation when using DDP. Default is False.

`gradient_as_bucket_view` (boolean, optional): Whether to use gradients as bucket views when using DDP. Default is False.

`static_graph` (boolean, optional): Whether to use a static graph when using DDP. Default is False.

Example of a plugin for Torch DDP:

```python
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=0,
    precision="fp16",
    initial_scale=1,
)
```

If there are 4 parallel processes, the data parallel group size for Torch DDP is 4 // (2 * 1) = 2.

ZeRO-related parameters:

`zero_bucket_size_in_m` (integer, optional): The bucket size for gradient reduction in megabytes when using ZeRO. Default is 12.

`cpu_offload` (boolean, optional): Whether to enable cpu_offload when using ZeRO. Default is False.

`communication_dtype` (torch data type, optional): The data type for communication when using ZeRO. If not specified, the data type of the parameters will be used. Default is None.

`overlap_communication` (boolean, optional): Whether to overlap communication and computation when using ZeRO. Default is True.

Example of a plugin for ZeRO-1:

```python
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    zero_stage=1,
    cpu_offload=True,
    precision="fp16",
    initial_scale=1,
)
```

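The `precision` and `initial_scale` arguments that appear in the examples above control the plugin's mixed-precision training. A minimal sketch that only enables fp16 training, with all parallel sizes left at their defaults, looks like:

```python
plugin = HybridParallelPlugin(
    precision="fp16",
    initial_scale=1,
)
```
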
### Hybrid Parallelism

You can refer to the strategies above to customize an appropriate hybrid parallelism strategy, and then use this plugin to define a booster.

```python
plugin = HybridParallelPlugin(
    tp_size=TP_SIZE,
    pp_size=PP_SIZE,
    num_microbatches=None,
    microbatch_size=1,
    enable_all_optimization=True,
    precision="fp16",
    initial_scale=1,
)
booster = Booster(plugin=plugin)
```

Next, we use `booster.boost` to inject the features encapsulated by the plugin into the model training components.

```python
model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(
    model=model, optimizer=optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler
)
```

## Train ViT using hybrid parallelism

Finally, we can use the hybrid parallelism strategy to train the model. Let's first define a training function that describes the training process. Note that if the pipeline parallelism strategy is used, you should call `booster.execute_pipeline` to perform the model training; it invokes the `scheduler` to manage the model's forward and backward passes.

```python
def run_forward_backward(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: Callable[[Any, Any], torch.Tensor],
    data_iter: Iterator,
    booster: Booster,
):
    if optimizer is not None:
        optimizer.zero_grad()
    if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
        # run pipeline forward backward when enabling pp in hybrid parallel plugin
        output_dict = booster.execute_pipeline(
            data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True
        )
        loss, outputs = output_dict["loss"], output_dict["outputs"]
    else:
        batch = next(data_iter)
        batch = move_to_cuda(batch, torch.cuda.current_device())
        outputs = model(**batch)
        loss = criterion(outputs, None)
        if optimizer is not None:
            booster.backward(loss, optimizer)
    return loss, outputs


def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    criterion: Callable[[Any, Any], torch.Tensor],
    lr_scheduler: LRScheduler,
    dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
    torch.cuda.synchronize()

    num_steps = len(dataloader)
    data_iter = iter(dataloader)
    enable_pbar = coordinator.is_master()
    if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
        # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar
        tp_rank = dist.get_rank(booster.plugin.tp_group)
        dp_rank = dist.get_rank(booster.plugin.dp_group)
        enable_pbar = tp_rank == 0 and dp_rank == 0 and booster.plugin.stage_manager.is_last_stage()
    model.train()

    with tqdm(range(num_steps), desc=f"Epoch [{epoch + 1}]", disable=not enable_pbar) as pbar:
        for _ in pbar:
            loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster)
            optimizer.step()
            lr_scheduler.step()

            # Print batch loss
            if enable_pbar:
                pbar.set_postfix({"loss": loss.item()})
```

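Start training the model: iterate over the epochs and call the training function defined above. This is a minimal sketch; `NUM_EPOCH` and the `coordinator` (a `DistCoordinator`) are assumed to have been created earlier in the script.

```python
# run the training loop epoch by epoch
for epoch in range(NUM_EPOCH):
    train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
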
### Start training

```bash
export DATA=<path_to_dataset>
# If your torch >= 1.10.0
torchrun --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./configs/config_pipeline_parallel.py
# If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./configs/config_pipeline_parallel.py
```

## Tensor Parallelism and Hybrid Parallelism

Tensor parallelism partitions each weight parameter across multiple devices in order to reduce memory load. Colossal-AI supports 1D, 2D, 2.5D and 3D tensor parallelism. Besides, you can combine tensor parallelism with pipeline parallelism and data parallelism to reach hybrid parallelism. Colossal-AI also provides an easy way to apply tensor parallelism and hybrid parallelism: on the basis of pipeline parallelism, a few lines of change in the config file are all you need.

### Define your configuration file (`/hybrid_parallel/configs/vit_1d_tp2_pp2.py`)

To use tensor parallelism, you only need to add related information to the **parallel dict**. To be specific, `TENSOR_PARALLEL_MODE` can be '1d', '2d', '2.5d' or '3d'. The sizes of the different parallelism types should satisfy `#GPUs = pipeline parallel size x tensor parallel size x data parallel size`. The `data parallel size` is computed automatically after you specify the number of GPUs, the pipeline parallel size, and the tensor parallel size. For example, with 8 GPUs, `pipeline=2` and `tensor=2`, the data parallel size is 8 / (2 x 2) = 2.

```python
from colossalai.amp import AMP_TYPE

# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'

parallel = dict(
    pipeline=2,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE)
)

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0


# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE)
```

Other configs:

```python
# hyper parameters
# BATCH_SIZE is per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 768
DEPTH = 12
NUM_HEADS = 12
MLP_RATIO = 4
NUM_CLASSES = 10
CHECKPOINT = True
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1  # add 1 for cls token
```

### Start training

```bash
export DATA=<path_to_dataset>
# If your torch >= 1.10.0
torchrun --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
# If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
```

<!-- doc-test-command: echo -->