Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
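(The hunks below show what those hooks did to this training script: long call sites collapsed or exploded per black's line-length rule, one-parameter-per-line signatures with trailing commas, single quotes rewritten to double quotes, and redundant parentheses and backslash continuations dropped.)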
@@ -25,19 +25,21 @@ def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}
 
 
-def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
-                         data_iter: Iterator, booster: Booster):
+def run_forward_backward(
+    model: nn.Module,
+    optimizer: Optimizer,
+    criterion: Callable[[Any, Any], torch.Tensor],
+    data_iter: Iterator,
+    booster: Booster,
+):
     if optimizer is not None:
         optimizer.zero_grad()
     if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
         # run pipeline forward backward when enabling pp in hybrid parallel plugin
-        output_dict = booster.execute_pipeline(data_iter,
-                                               model,
-                                               criterion,
-                                               optimizer,
-                                               return_loss=True,
-                                               return_outputs=True)
-        loss, outputs = output_dict['loss'], output_dict['outputs']
+        output_dict = booster.execute_pipeline(
+            data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True
+        )
+        loss, outputs = output_dict["loss"], output_dict["outputs"]
     else:
         batch = next(data_iter)
         batch = move_to_cuda(batch, torch.cuda.current_device())
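For orientation, the calling pattern `run_forward_backward` expects is the one visible later in this diff's `train_epoch` hunk: one batch is pulled from `data_iter` per call, and the caller steps the optimizer afterwards. A minimal sketch of that driver loop, assuming `model`, `optimizer`, `criterion`, `train_dataloader`, and `booster` are already set up as in `main()`:

    # Hedged sketch: one run_forward_backward call per optimizer step.
    data_iter = iter(train_dataloader)
    for _ in range(len(train_dataloader)):
        loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster)
        optimizer.step()  # gradients were accumulated inside run_forward_backward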
@@ -49,9 +51,16 @@ def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Call
     return loss, outputs
 
 
-def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
-                lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+def train_epoch(
+    epoch: int,
+    model: nn.Module,
+    optimizer: Optimizer,
+    criterion: Callable[[Any, Any], torch.Tensor],
+    lr_scheduler: LRScheduler,
+    dataloader: DataLoader,
+    booster: Booster,
+    coordinator: DistCoordinator,
+):
     torch.cuda.synchronize()
 
     num_steps = len(dataloader)
@@ -61,12 +70,11 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C
         # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar
         tp_rank = dist.get_rank(booster.plugin.tp_group)
         dp_rank = dist.get_rank(booster.plugin.dp_group)
-        enable_pbar = tp_rank == 0 and dp_rank == 0 \
-            and booster.plugin.stage_manager.is_last_stage()
+        enable_pbar = tp_rank == 0 and dp_rank == 0 and booster.plugin.stage_manager.is_last_stage()
 
     model.train()
 
-    with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar:
+    with tqdm(range(num_steps), desc=f"Epoch [{epoch + 1}]", disable=not enable_pbar) as pbar:
         for _ in pbar:
            loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster)
            optimizer.step()
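The collapsed `enable_pbar` expression encodes a real constraint, not just style: under pipeline parallelism the loss only materializes on the last stage, so exactly one rank (dp_rank == 0, tp_rank == 0, last stage) should own the progress bar. A standalone sketch of the same predicate, with the group handles as stand-ins for `booster.plugin.tp_group` and `booster.plugin.dp_group`:

    import torch.distributed as dist

    def should_show_pbar(tp_group, dp_group, stage_manager) -> bool:
        # Exactly one rank satisfies all three conditions, so a single tqdm bar is drawn.
        return (
            dist.get_rank(tp_group) == 0
            and dist.get_rank(dp_group) == 0
            and stage_manager.is_last_stage()
        )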
@@ -74,13 +82,18 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C
 
             # Print batch loss
             if enable_pbar:
-                pbar.set_postfix({'loss': loss.item()})
+                pbar.set_postfix({"loss": loss.item()})
 
 
 @torch.no_grad()
-def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor],
-                   eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+def evaluate_model(
+    epoch: int,
+    model: nn.Module,
+    criterion: Callable[[Any, Any], torch.Tensor],
+    eval_dataloader: DataLoader,
+    booster: Booster,
+    coordinator: DistCoordinator,
+):
     torch.cuda.synchronize()
     model.eval()
     accum_loss = torch.zeros(1, device=torch.cuda.current_device())
@@ -99,13 +112,13 @@ def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any],
             to_accum = to_accum and booster.plugin.stage_manager.is_last_stage()
 
             if to_accum:
-                accum_loss += (loss / len(eval_dataloader))
+                accum_loss += loss / len(eval_dataloader)
                 logits = outputs["logits"]
                 preds = torch.argmax(logits, dim=1)
 
                 labels = batch["labels"]
                 total_num += batch["labels"].shape[0]
-                accum_correct += (torch.sum(preds == labels))
+                accum_correct += torch.sum(preds == labels)
 
     dist.all_reduce(accum_loss)
     dist.all_reduce(total_num)
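The accumulators above are per-rank partial sums; the two `dist.all_reduce` calls (SUM is the default reduce op) make them global before the final division in the next hunk. A runnable sketch of the pattern, assuming a `torchrun` launch and made-up local counts:

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")  # assumes launch via torchrun
    device = torch.cuda.current_device()
    local_correct = torch.tensor([37.0], device=device)  # hypothetical per-rank count
    local_total = torch.tensor([64.0], device=device)
    dist.all_reduce(local_correct)  # in-place SUM across all ranks
    dist.all_reduce(local_total)
    accuracy = local_correct.item() / local_total.item()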
@@ -113,13 +126,14 @@ def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any],
     avg_loss = "{:.4f}".format(accum_loss.item())
     accuracy = "{:.4f}".format(accum_correct.item() / total_num.item())
     if coordinator.is_master():
-        print(f"Evaluation result for epoch {epoch + 1}: \
-            average_loss={avg_loss}, \
-            accuracy={accuracy}.")
+        print(
+            f"Evaluation result for epoch {epoch + 1}: \
+            average_loss={avg_loss}, \
+            accuracy={accuracy}."
+        )
 
 
 def main():
     args = parse_demo_args()
 
     # Launch ColossalAI
@@ -136,14 +150,14 @@ def main():
     transformers.utils.logging.set_verbosity_error()
 
     # Reset tp_size and pp_size to 1 if not using hybrid parallel.
-    if args.plugin != 'hybrid_parallel':
+    if args.plugin != "hybrid_parallel":
         args.tp_size = 1
         args.pp_size = 1
 
     # Prepare Dataset
     image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path)
-    train_dataset = BeansDataset(image_processor, args.tp_size, split='train')
-    eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation')
+    train_dataset = BeansDataset(image_processor, args.tp_size, split="train")
+    eval_dataset = BeansDataset(image_processor, args.tp_size, split="validation")
     num_labels = train_dataset.num_labels
 
     # Load pretrained ViT model
@@ -151,9 +165,9 @@ def main():
     config.num_labels = num_labels
     config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
     config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
-    model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
-                                                      config=config,
-                                                      ignore_mismatched_sizes=True)
+    model = ViTForImageClassification.from_pretrained(
+        args.model_name_or_path, config=config, ignore_mismatched_sizes=True
+    )
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
 
     # Enable gradient checkpointing
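`ignore_mismatched_sizes=True` is load-bearing here: the pretrained checkpoint's classification head was trained for a different number of classes than `num_labels`, so its weights cannot be loaded; with the flag set, the backbone loads normally and a freshly initialized head of the right width is used. A hedged sketch with a placeholder checkpoint name:

    from transformers import ViTConfig, ViTForImageClassification

    ckpt = "google/vit-base-patch16-224"  # placeholder, not the script's args.model_name_or_path
    config = ViTConfig.from_pretrained(ckpt)
    config.num_labels = 3  # e.g. three target classes instead of the checkpoint's 1000
    model = ViTForImageClassification.from_pretrained(ckpt, config=config, ignore_mismatched_sizes=True)
    # backbone weights load; the (hidden_size, 3) classifier head is re-initialized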
@@ -162,37 +176,35 @@ def main():
 
     # Set plugin
     booster_kwargs = {}
-    if args.plugin == 'torch_ddp_fp16':
-        booster_kwargs['mixed_precision'] = 'fp16'
-    if args.plugin.startswith('torch_ddp'):
+    if args.plugin == "torch_ddp_fp16":
+        booster_kwargs["mixed_precision"] = "fp16"
+    if args.plugin.startswith("torch_ddp"):
         plugin = TorchDDPPlugin()
-    elif args.plugin == 'gemini':
+    elif args.plugin == "gemini":
         plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
-    elif args.plugin == 'low_level_zero':
+    elif args.plugin == "low_level_zero":
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
-    elif args.plugin == 'hybrid_parallel':
-        plugin = HybridParallelPlugin(tp_size=args.tp_size,
-                                      pp_size=args.pp_size,
-                                      num_microbatches=None,
-                                      microbatch_size=1,
-                                      enable_all_optimization=True,
-                                      precision='fp16',
-                                      initial_scale=1)
+    elif args.plugin == "hybrid_parallel":
+        plugin = HybridParallelPlugin(
+            tp_size=args.tp_size,
+            pp_size=args.pp_size,
+            num_microbatches=None,
+            microbatch_size=1,
+            enable_all_optimization=True,
+            precision="fp16",
+            initial_scale=1,
+        )
     else:
         raise ValueError(f"Plugin with name {args.plugin} is not supported!")
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
     # Prepare dataloader
-    train_dataloader = plugin.prepare_dataloader(train_dataset,
-                                                 batch_size=args.batch_size,
-                                                 shuffle=True,
-                                                 drop_last=True,
-                                                 collate_fn=beans_collator)
-    eval_dataloader = plugin.prepare_dataloader(eval_dataset,
-                                                batch_size=args.batch_size,
-                                                shuffle=True,
-                                                drop_last=True,
-                                                collate_fn=beans_collator)
+    train_dataloader = plugin.prepare_dataloader(
+        train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=beans_collator
+    )
+    eval_dataloader = plugin.prepare_dataloader(
+        eval_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=beans_collator
+    )
 
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
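A detail worth noting in the `HybridParallelPlugin` arguments: `num_microbatches=None` plus `microbatch_size=1` leaves the microbatch count to be derived from whatever batch the pipeline receives (my reading of this parameter pair, so treat it as an assumption). The implied arithmetic:

    batch_size = 32  # hypothetical per-step batch reaching the pipeline
    microbatch_size = 1  # as passed to HybridParallelPlugin above
    num_microbatches = batch_size // microbatch_size  # 32 microbatches per step
    assert batch_size % microbatch_size == 0, "batch must split evenly into microbatches"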
@@ -204,17 +216,15 @@ def main():
     # Set lr scheduler
     total_steps = len(train_dataloader) * args.num_epoch
     num_warmup_steps = int(args.warmup_ratio * total_steps)
-    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
-                                           total_steps=(len(train_dataloader) * args.num_epoch),
-                                           warmup_steps=num_warmup_steps)
+    lr_scheduler = CosineAnnealingWarmupLR(
+        optimizer=optimizer, total_steps=(len(train_dataloader) * args.num_epoch), warmup_steps=num_warmup_steps
+    )
 
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model,
-                                                                                 optimizer=optimizer,
-                                                                                 criterion=criterion,
-                                                                                 dataloader=train_dataloader,
-                                                                                 lr_scheduler=lr_scheduler)
+    model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(
+        model=model, optimizer=optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler
+    )
 
     # Finetuning
     logger.info(f"Start finetuning", ranks=[0])
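For reference, `CosineAnnealingWarmupLR` combines a warmup ramp with cosine annealing over `total_steps`. A standalone sketch of that schedule shape; the linear warmup and the decay toward zero are assumptions about this scheduler, not facts taken from the diff:

    import math

    def lr_at_step(step: int, base_lr: float, total_steps: int, warmup_steps: int) -> float:
        if step < warmup_steps:
            return base_lr * (step + 1) / warmup_steps  # ramp up toward base_lr
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))  # anneal toward 0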