mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-28 16:28:10 +00:00
[shardformer] update bert finetune example with HybridParallelPlugin (#4584)
* [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] fix epoch change * [shardformer] broadcast add pp group * [shardformer] fix opt test hanging * fix * test * test * [shardformer] zero1+pp and the corresponding tests (#4517) * pause * finish pp+zero1 * Update test_shard_vit.py * [shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516) * fix overlap bug and support bert, add overlap as an option in shardconfig * support overlap for chatglm and bloom * [shardformer] fix emerged bugs after updating transformers (#4526) * test * fix test * fix test * remove print * add fix * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] Add overlap support for gpt2 (#4535) * add overlap support for gpt2 * remove unused code * remove unused code * [shardformer] support pp+tp+zero1 tests (#4531) * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] fix submodule replacement bug when enabling pp (#4544) * [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540) * implement sharded optimizer saving * add more param info * finish implementation of sharded optimizer saving * fix bugs in optimizer sharded saving * add pp+zero test * param group loading * greedy loading of optimizer * fix bug when loading * implement optimizer sharded saving * add optimizer test & arrange checkpointIO utils * fix gemini sharding state_dict * add verbose option * add loading of master params * fix typehint * fix master/working mapping in fp16 amp * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] fix epoch change * [shardformer] broadcast add pp group * rebase feature/shardformer * update pipeline * [shardformer] fix * [shardformer] fix * [shardformer] bert finetune fix * [shardformer] add all_reduce operation to loss add all_reduce operation to loss * [shardformer] make compatible with pytree. make compatible with pytree. * [shardformer] disable tp disable tp * [shardformer] add 3d plugin to ci test * [shardformer] update num_microbatches to None * [shardformer] update microbatchsize * [shardformer] update assert * update scheduler * update scheduler --------- Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
This commit is contained in:
parent
24c0768795
commit
0a94fcd351
@ -325,7 +325,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
self.schedule = None
|
self.schedule = None
|
||||||
assert zero_stage in (0, 1, 2)
|
assert zero_stage in (0, 1, 2)
|
||||||
if self.pp_size > 1:
|
if self.pp_size > 1:
|
||||||
assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
|
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
|
||||||
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
|
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
|
||||||
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
|
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
|
||||||
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
|
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
|
||||||
|
@ -46,6 +46,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||||||
self.batch: Optional[Any] = None
|
self.batch: Optional[Any] = None
|
||||||
self.batch_size: Optional[int] = None
|
self.batch_size: Optional[int] = None
|
||||||
self.microbatch_offset: Optional[int] = None
|
self.microbatch_offset: Optional[int] = None
|
||||||
|
self._use_microbatch_size = num_microbatches is None
|
||||||
|
|
||||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||||
"""Load a batch from data iterator.
|
"""Load a batch from data iterator.
|
||||||
@ -60,7 +61,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||||||
self.batch = batch
|
self.batch = batch
|
||||||
self.batch_size = get_batch_size(batch)
|
self.batch_size = get_batch_size(batch)
|
||||||
self.microbatch_offset = 0
|
self.microbatch_offset = 0
|
||||||
if self.num_microbatches is not None:
|
if not self._use_microbatch_size:
|
||||||
assert self.batch_size % self.num_microbatches == 0, \
|
assert self.batch_size % self.num_microbatches == 0, \
|
||||||
"Batch size should divided by the number of microbatches"
|
"Batch size should divided by the number of microbatches"
|
||||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from typing import List, Union
|
from contextlib import nullcontext
|
||||||
|
from typing import Callable, List, Union
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from data import GLUEDataBuilder
|
from data import GLUEDataBuilder
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Adam, Optimizer
|
||||||
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@ -18,8 +20,9 @@ from transformers import (
|
|||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
@ -32,14 +35,26 @@ LEARNING_RATE = 2.4e-5
|
|||||||
WEIGHT_DECAY = 0.01
|
WEIGHT_DECAY = 0.01
|
||||||
WARMUP_FRACTION = 0.1
|
WARMUP_FRACTION = 0.1
|
||||||
|
|
||||||
|
output_transform_fn = lambda x: x
|
||||||
|
criterion = lambda x: x.loss
|
||||||
|
|
||||||
|
|
||||||
def move_to_cuda(batch):
|
def move_to_cuda(batch):
|
||||||
return {k: v.cuda() for k, v in batch.items()}
|
return {k: v.cuda() for k, v in batch.items()}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
|
def evaluate_model(
|
||||||
eval_splits: List[str], coordinator: DistCoordinator):
|
model: nn.Module,
|
||||||
|
optimizer,
|
||||||
|
criterion,
|
||||||
|
test_dataloader: Union[DataLoader, List[DataLoader]],
|
||||||
|
num_labels: int,
|
||||||
|
task_name: str,
|
||||||
|
eval_splits: List[str],
|
||||||
|
booster: Booster,
|
||||||
|
coordinator: DistCoordinator,
|
||||||
|
):
|
||||||
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
|
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -47,23 +62,66 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
|
|||||||
accum_loss = torch.zeros(1, device=get_current_device())
|
accum_loss = torch.zeros(1, device=get_current_device())
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
batch = move_to_cuda(batch)
|
batch = move_to_cuda(batch)
|
||||||
outputs = model(**batch)
|
|
||||||
val_loss, logits = outputs[:2]
|
|
||||||
accum_loss.add_(val_loss)
|
|
||||||
|
|
||||||
if num_labels > 1:
|
|
||||||
preds = torch.argmax(logits, axis=1)
|
|
||||||
elif num_labels == 1:
|
|
||||||
preds = logits.squeeze()
|
|
||||||
|
|
||||||
labels = batch["labels"]
|
labels = batch["labels"]
|
||||||
|
batch_size = batch["input_ids"].shape[0]
|
||||||
|
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
|
||||||
|
pg_mesh = booster.plugin.pg_mesh
|
||||||
|
pp_group = booster.plugin.pp_group
|
||||||
|
current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
|
||||||
|
current_rank = dist.get_rank()
|
||||||
|
#TODO pass dataloader to execute_pipeline directly
|
||||||
|
batch = iter([batch])
|
||||||
|
outputs = booster.execute_pipeline(batch,
|
||||||
|
model,
|
||||||
|
criterion,
|
||||||
|
optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
return_outputs=True)
|
||||||
|
|
||||||
metric.add_batch(predictions=preds, references=labels)
|
if booster.plugin.stage_manager.is_last_stage():
|
||||||
|
val_loss = outputs["loss"]
|
||||||
|
|
||||||
|
logits = outputs["outputs"]["logits"]
|
||||||
|
|
||||||
|
accum_loss.add_(val_loss)
|
||||||
|
|
||||||
|
if num_labels > 1:
|
||||||
|
preds = torch.argmax(logits, axis=1)
|
||||||
|
elif num_labels == 1:
|
||||||
|
preds = logits.squeeze()
|
||||||
|
|
||||||
|
dist.broadcast(preds, src=current_rank, group=pp_group)
|
||||||
|
dist.broadcast(val_loss, src=current_rank, group=pp_group)
|
||||||
|
|
||||||
|
metric.add_batch(predictions=preds, references=labels)
|
||||||
|
elif current_rank in current_pp_group_ranks:
|
||||||
|
val_loss = torch.empty((1,), device=get_current_device())
|
||||||
|
preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())
|
||||||
|
|
||||||
|
dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
|
||||||
|
dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)
|
||||||
|
|
||||||
|
accum_loss.add_(val_loss)
|
||||||
|
metric.add_batch(predictions=preds, references=labels)
|
||||||
|
|
||||||
|
else:
|
||||||
|
batch = move_to_cuda(batch)
|
||||||
|
outputs = model(**batch)
|
||||||
|
val_loss, logits = outputs[:2]
|
||||||
|
accum_loss.add_(val_loss)
|
||||||
|
|
||||||
|
if num_labels > 1:
|
||||||
|
preds = torch.argmax(logits, axis=1)
|
||||||
|
elif num_labels == 1:
|
||||||
|
preds = logits.squeeze()
|
||||||
|
|
||||||
|
metric.add_batch(predictions=preds, references=labels)
|
||||||
|
|
||||||
results = metric.compute()
|
results = metric.compute()
|
||||||
dist.all_reduce(accum_loss.div_(len(dataloader)))
|
dist.all_reduce(accum_loss.div_(len(dataloader)))
|
||||||
if coordinator.is_master():
|
if coordinator.is_master() and results is not None:
|
||||||
results['loss'] = accum_loss.item() / coordinator.world_size
|
results['loss'] = accum_loss.item() / coordinator.world_size
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
if isinstance(test_dataloader, DataLoader):
|
if isinstance(test_dataloader, DataLoader):
|
||||||
@ -77,25 +135,43 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
|
|||||||
return final_results
|
return final_results
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
|
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
|
||||||
booster: Booster, coordinator: DistCoordinator):
|
train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
|
is_pp_last_stage = hasattr(
|
||||||
|
booster.plugin,
|
||||||
|
"stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage()
|
||||||
|
with tqdm(train_dataloader,
|
||||||
|
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
|
||||||
|
disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
|
||||||
for batch in pbar:
|
for batch in pbar:
|
||||||
# Forward pass
|
# Forward pass
|
||||||
batch = move_to_cuda(batch)
|
batch = move_to_cuda(batch)
|
||||||
outputs = model(**batch)
|
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
|
||||||
loss = outputs[0]
|
#TODO pass train_dataloader to execute_pipeline directly
|
||||||
|
batch = iter([batch])
|
||||||
|
outputs = booster.execute_pipeline(batch,
|
||||||
|
model,
|
||||||
|
_criterion,
|
||||||
|
optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
return_outputs=True)
|
||||||
|
# Backward and optimize
|
||||||
|
if booster.plugin.stage_manager.is_last_stage():
|
||||||
|
loss = outputs['loss']
|
||||||
|
pbar.set_postfix({'loss': loss.item()})
|
||||||
|
else:
|
||||||
|
outputs = model(**batch)
|
||||||
|
loss = _criterion(outputs, None)
|
||||||
|
# Backward
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
pbar.set_postfix({'loss': loss.item()})
|
||||||
|
|
||||||
# Backward and optimize
|
|
||||||
booster.backward(loss, optimizer)
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
# Print log info
|
|
||||||
pbar.set_postfix({'loss': loss.item()})
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# ==============================
|
# ==============================
|
||||||
@ -107,7 +183,7 @@ def main():
|
|||||||
'--plugin',
|
'--plugin',
|
||||||
type=str,
|
type=str,
|
||||||
default='torch_ddp',
|
default='torch_ddp',
|
||||||
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
|
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
|
||||||
help="plugin to use")
|
help="plugin to use")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_type",
|
"--model_type",
|
||||||
@ -116,6 +192,7 @@ def main():
|
|||||||
help="bert or albert",
|
help="bert or albert",
|
||||||
)
|
)
|
||||||
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
|
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
|
||||||
|
parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.model_type == 'bert':
|
if args.model_type == 'bert':
|
||||||
@ -145,6 +222,17 @@ def main():
|
|||||||
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
|
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
|
||||||
elif args.plugin == 'low_level_zero':
|
elif args.plugin == 'low_level_zero':
|
||||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||||
|
elif args.plugin == 'hybrid_parallel':
|
||||||
|
|
||||||
|
# modify the param accordingly for finetuning test cases
|
||||||
|
plugin = HybridParallelPlugin(tp_size=1,
|
||||||
|
pp_size=2,
|
||||||
|
num_microbatches=None,
|
||||||
|
microbatch_size=1,
|
||||||
|
enable_all_optimization=True,
|
||||||
|
zero_stage=1,
|
||||||
|
precision='fp16',
|
||||||
|
initial_scale=1)
|
||||||
|
|
||||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||||
|
|
||||||
@ -165,8 +253,9 @@ def main():
|
|||||||
# bert pretrained model
|
# bert pretrained model
|
||||||
|
|
||||||
cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
|
cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
|
||||||
|
|
||||||
if model_name == "bert-base-uncased":
|
if model_name == "bert-base-uncased":
|
||||||
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
|
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
|
||||||
elif model_name == "albert-xxlarge-v2":
|
elif model_name == "albert-xxlarge-v2":
|
||||||
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
|
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
|
||||||
else:
|
else:
|
||||||
@ -196,19 +285,27 @@ def main():
|
|||||||
num_training_steps=total_steps,
|
num_training_steps=total_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _criterion(outputs, inputs):
|
||||||
|
outputs = output_transform_fn(outputs)
|
||||||
|
loss = criterion(outputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Boost with ColossalAI
|
# Boost with ColossalAI
|
||||||
# ==============================
|
# ==============================
|
||||||
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
|
model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
|
||||||
|
optimizer,
|
||||||
|
criterion=_criterion,
|
||||||
|
lr_scheduler=lr_scheduler)
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Train model
|
# Train model
|
||||||
# ==============================
|
# ==============================
|
||||||
for epoch in range(NUM_EPOCHS):
|
for epoch in range(NUM_EPOCHS):
|
||||||
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
|
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
|
||||||
|
|
||||||
results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
|
results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
|
||||||
coordinator)
|
data_builder.eval_splits, booster, coordinator)
|
||||||
|
|
||||||
if coordinator.is_master():
|
if coordinator.is_master():
|
||||||
print(results)
|
print(results)
|
||||||
|
@ -3,6 +3,6 @@ set -xe
|
|||||||
|
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
|
||||||
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
|
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
|
||||||
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
|
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
|
||||||
done
|
done
|
||||||
|
Loading…
Reference in New Issue
Block a user