Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[example] update vit example for hybrid parallel plugin (#4641)
* update vit example for hybrid plugin
* reset tp/pp size
* fix dataloader iteration bug
* update optimizer passing in evaluation/add grad_accum
* change criterion
* wrap tqdm
* change grad_accum to grad_checkpoint
* fix pbar
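Taken together, these changes wire the ViT benchmark into the new plugin flow. The sketch below condenses the pattern the diff establishes, under the assumption of a 4-GPU run (tp_size=2 × pp_size=2) and an already-initialized distributed environment; it illustrates the flow and is not a drop-in replacement for the example script.

```python
# Condensed sketch of the flow this commit sets up (assumes 4 GPUs: tp_size=2 * pp_size=2,
# and that colossalai.launch_from_torch() has already initialized the process group).
from transformers import ViTConfig, ViTForImageClassification

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=None,
                              microbatch_size=1, enable_all_optimization=True,
                              precision='fp16', initial_scale=1)

model = ViTForImageClassification(ViTConfig(num_labels=10))
optimizer = HybridAdam(model.parameters(), lr=1e-3)

def criterion(outputs, inputs):
    # HuggingFace models attach the loss to their output object when labels are passed.
    return outputs.loss

# The criterion is now passed through boost() so the pipeline schedule can call it.
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
```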
@@ -1,14 +1,14 @@
 import time
 
 import torch
-import tqdm
 import transformers
 from args import parse_benchmark_args
+from tqdm import tqdm
 from transformers import ViTConfig, ViTForImageClassification
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
@@ -24,7 +24,7 @@ def format_num(num: int, bytes=False):
         num /= factor
 
 
-def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
+def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
     pixel_values = torch.randn(batch_size,
                                num_channels,
                                height,
@@ -32,7 +32,7 @@ def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
                                device=torch.cuda.current_device(),
                                dtype=torch.float)
     labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
-    return pixel_values, labels
+    return dict(pixel_values=pixel_values, labels=labels)
 
 
 def colo_memory_cap(size_in_GB):
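Returning a dict instead of a `(pixel_values, labels)` tuple is what lets one batch format serve both execution paths introduced further down: the data-parallel path splats it into the HuggingFace forward, and the pipeline path wraps it in an iterator. A minimal illustration, using the names from this diff (and assuming `model`, `booster`, `criterion`, and `optimizer` were boosted as shown later):

```python
batch = get_data_batch(batch_size=8, num_labels=10)

# Data-parallel path: the dict keys match the model's forward signature.
outputs = model(**batch)

# Pipeline path: execute_pipeline consumes an iterator of such dicts.
outputs = booster.execute_pipeline(iter([batch]), model, criterion, optimizer,
                                   return_loss=True, return_outputs=True)
```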
@@ -70,7 +70,8 @@ def main():
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
 
     # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
+    if args.grad_checkpoint:
+        model.gradient_checkpointing_enable()
 
     # Set plugin
     booster_kwargs = {}
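Gradient checkpointing becomes opt-in via `args.grad_checkpoint` (this is the `grad_accum` → `grad_checkpoint` rename from the commit message). The flag is defined in the example's `args.py`, which is outside this diff; a plausible definition would be a plain boolean switch:

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical sketch; the real definition lives in args.py, not shown in this diff.
parser.add_argument('--grad_checkpoint', action='store_true',
                    help='recompute activations during backward to save memory')
```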
@@ -82,34 +83,57 @@ def main():
         plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    elif args.plugin == 'hybrid_parallel':
+        plugin = HybridParallelPlugin(tp_size=2,
+                                      pp_size=2,
+                                      num_microbatches=None,
+                                      microbatch_size=1,
+                                      enable_all_optimization=True,
+                                      precision='fp16',
+                                      initial_scale=1)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
 
+    # Set criterion (loss function)
+    def criterion(outputs, inputs):
+        return outputs.loss
+
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, _, _ = booster.boost(model, optimizer)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
 
     # Start training.
     logger.info(f"Start testing", ranks=[0])
-    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
 
     torch.cuda.synchronize()
     model.train()
     start_time = time.time()
 
-    for _ in range(args.max_train_steps):
+    with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
+        for _ in pbar:
+            optimizer.zero_grad()
+            batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)
 
-        pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
-        optimizer.zero_grad()
-        outputs = model(pixel_values=pixel_values, labels=labels)
-        loss = outputs['loss']
-        booster.backward(loss, optimizer)
-        optimizer.step()
+            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+                # run pipeline forward backward
+                batch = iter([batch])
+                outputs = booster.execute_pipeline(batch,
+                                                   model,
+                                                   criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+            else:
+                outputs = model(**batch)
+                loss = criterion(outputs, None)
+                # Backward
+                booster.backward(loss, optimizer)
 
-        torch.cuda.synchronize()
-        progress_bar.update(1)
+            optimizer.step()
 
     torch.cuda.synchronize()
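Two details of the new loop are easy to miss. `booster.execute_pipeline` drives forward and backward itself across the pipeline stages, which is why that branch has no separate `booster.backward(loss, optimizer)` call; and the single batch is wrapped in `iter([batch])` because the method expects a data iterator rather than a bare batch. If the benchmark later needs the loss value back, a guarded read along these lines should work (the last-stage guard is an assumption for illustration; earlier stages do not hold the full loss):

```python
outputs = booster.execute_pipeline(iter([batch]), model, criterion, optimizer,
                                   return_loss=True, return_outputs=True)
# On a pipeline schedule the loss is only meaningful on the last stage.
if booster.plugin.stage_manager.is_last_stage():
    loss = outputs['loss']
```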
@@ -124,6 +148,8 @@ def main():
                 f"maximum memory usage per gpu: {max_mem}.",
                 ranks=[0])
 
+    torch.cuda.empty_cache()
+
 
 if __name__ == "__main__":
     main()
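The `max_mem` value in the log line above is the per-GPU peak. Its computation sits outside the lines shown in this hunk, but given the `format_num(num, bytes=False)` helper seen earlier, it presumably reduces to something like:

```python
# Assumed reconstruction; the actual call is outside the lines shown in this diff.
max_mem = format_num(torch.cuda.max_memory_allocated(), bytes=True)
```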