[example] update vit example for hybrid parallel plugin (#4641)

* update vit example for hybrid plugin

* reset tp/pp size

* fix dataloader iteration bug

* update optimizer passing in evaluation/add grad_accum

* change criterion

* wrap tqdm

* change grad_accum to grad_checkpoint

* fix pbar
Author: Baizhou Zhang
Date: 2023-09-07 17:38:45 +08:00
Committed by: GitHub
Parent: 660eed9124
Commit: 295b38fecf
10 changed files with 246 additions and 192 deletions

Diff of the example's benchmark script (1 of the 10 changed files shown):

@@ -1,14 +1,14 @@
 import time

 import torch
-import tqdm
 import transformers
 from args import parse_benchmark_args
+from tqdm import tqdm
 from transformers import ViTConfig, ViTForImageClassification

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
@@ -24,7 +24,7 @@ def format_num(num: int, bytes=False):
         num /= factor


-def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
+def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
     pixel_values = torch.randn(batch_size,
                                num_channels,
                                height,
@@ -32,7 +32,7 @@ def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
                                device=torch.cuda.current_device(),
                                dtype=torch.float)
     labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
-    return pixel_values, labels
+    return dict(pixel_values=pixel_values, labels=labels)


 def colo_memory_cap(size_in_GB):
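
The switch from a (pixel_values, labels) tuple to a dict is what lets the same batch feed both execution paths later in the training loop: the dict can be unpacked as model keyword arguments, or handed whole to the pipeline schedule as one microbatch. A minimal sketch of the convention (CPU tensors here for brevity; the helper name mirrors the diff):

```python
import torch

def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
    # Synthetic ViT input: random images plus integer class labels.
    pixel_values = torch.randn(batch_size, num_channels, height, width)
    labels = torch.randint(0, num_labels, (batch_size,))
    return dict(pixel_values=pixel_values, labels=labels)

batch = get_data_batch(batch_size=8, num_labels=10)
# Non-pipeline path: forward with keyword unpacking.
#   outputs = model(**batch)
# Pipeline path: the schedule consumes an iterator of such dicts.
#   booster.execute_pipeline(iter([batch]), model, criterion, optimizer, ...)
```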
@@ -70,7 +70,8 @@ def main():
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

     # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
+    if args.grad_checkpoint:
+        model.gradient_checkpointing_enable()

     # Set plugin
     booster_kwargs = {}
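
The unconditional gradient_checkpointing_enable() becomes opt-in via a new --grad_checkpoint flag (the last commit message bullet renamed it from grad_accum). The args.py diff is not included above; a hypothetical sketch of how the flag could be declared, assuming a plain argparse setup:

```python
import argparse

def parse_benchmark_args():
    parser = argparse.ArgumentParser()
    # Hypothetical flag matching args.grad_checkpoint used in main();
    # gradient checkpointing trades extra forward compute for activation memory.
    parser.add_argument("--grad_checkpoint", action="store_true",
                        help="Enable gradient checkpointing to reduce activation memory.")
    # ... remaining benchmark arguments elided ...
    return parser.parse_args()
```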
@@ -82,34 +83,57 @@ def main():
         plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    elif args.plugin == 'hybrid_parallel':
+        plugin = HybridParallelPlugin(tp_size=2,
+                                      pp_size=2,
+                                      num_microbatches=None,
+                                      microbatch_size=1,
+                                      enable_all_optimization=True,
+                                      precision='fp16',
+                                      initial_scale=1)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])

     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

+    # Set criterion (loss function)
+    def criterion(outputs, inputs):
+        return outputs.loss
+
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, _, _ = booster.boost(model, optimizer)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

     # Start training.
     logger.info(f"Start testing", ranks=[0])

-    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
-
     torch.cuda.synchronize()
     model.train()
     start_time = time.time()

-    for _ in range(args.max_train_steps):
+    with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
+        for _ in pbar:
+            optimizer.zero_grad()
+            batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)

-        pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
-        optimizer.zero_grad()
-        outputs = model(pixel_values=pixel_values, labels=labels)
-        loss = outputs['loss']
-        booster.backward(loss, optimizer)
-        optimizer.step()
+            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+                # run pipeline forward backward
+                batch = iter([batch])
+                outputs = booster.execute_pipeline(batch,
+                                                   model,
+                                                   criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+            else:
+                outputs = model(**batch)
+                loss = criterion(outputs, None)
+                # Backward
+                booster.backward(loss, optimizer)

-        torch.cuda.synchronize()
-        progress_bar.update(1)
+            optimizer.step()

+    torch.cuda.synchronize()

     # Compute Statistics
     end_time = time.time()
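
Putting the pieces together: under hybrid_parallel, the booster wraps criterion during boost(), and whenever the plugin carries a pipeline stage_manager the schedule owns the whole forward/backward for a step, so booster.backward() is only called on the non-pipeline path. A condensed, self-contained sketch of that control flow, assuming a 4-GPU torchrun launch (tp_size=2 x pp_size=2) and the 2023-era ColossalAI API used in this diff:

```python
import torch
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()

model = ViTForImageClassification(ViTConfig(num_labels=10))
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# The pipeline schedule calls criterion(outputs, inputs) on every
# microbatch, so it must accept both arguments even if inputs is unused.
def criterion(outputs, inputs):
    return outputs.loss

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, microbatch_size=1,
                              precision='fp16', initial_scale=1)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

def get_data_batch(batch_size, num_labels):
    # Synthetic dict-style batch, as in the diff above.
    device = torch.cuda.current_device()
    return dict(pixel_values=torch.randn(batch_size, 3, 224, 224, device=device),
                labels=torch.randint(0, num_labels, (batch_size,), device=device))

model.train()
for _ in tqdm(range(10), desc="Training Step", disable=not coordinator.is_master()):
    optimizer.zero_grad()
    batch = get_data_batch(batch_size=8, num_labels=10)
    if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
        # Pipeline path: the schedule consumes an iterator and runs
        # forward AND backward internally, microbatch by microbatch.
        booster.execute_pipeline(iter([batch]), model, criterion, optimizer, return_loss=True)
    else:
        # Plain path: manual forward, then booster-managed backward.
        loss = criterion(model(**batch), None)
        booster.backward(loss, optimizer)
    optimizer.step()
```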
@@ -124,6 +148,8 @@ def main():
                 f"maximum memory usage per gpu: {max_mem}.",
                 ranks=[0])

+    torch.cuda.empty_cache()
+

 if __name__ == "__main__":
     main()
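
A note on running this configuration: with tp_size=2 and pp_size=2 the world size must be a multiple of 4, since HybridParallelPlugin derives the data-parallel size as world_size // (tp_size * pp_size). Leaving num_microbatches=None while setting microbatch_size=1 lets the schedule infer the microbatch count from the batch size. The trailing torch.cuda.empty_cache() releases cached allocator blocks after the benchmark so a subsequent run's peak-memory statistic starts from a clean slate.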