[example] update ViT example using booster api (#3940)

This commit is contained in:
Baizhou Zhang
2023-06-12 15:02:27 +08:00
committed by GitHub
parent 1aadeedeea
commit b3ab7fbabf
17 changed files with 582 additions and 598 deletions

View File

@@ -74,17 +74,8 @@ def main():
transformers.utils.logging.set_verbosity_error()
# Build OPT model
# Initialize the model under ColoInitContext if using GeminiPlugin
config = AutoConfig.from_pretrained(args.model_name_or_path)
if args.plugin == 'gemini':
shard_pg = ProcessGroup(tp_degree=world_size)
default_dist_spec = ShardSpec([-1], [world_size])
with ColoInitContext(device='cpu',
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(config)
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
@@ -116,7 +107,9 @@ def main():
collate_fn=netflix_collator)
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
optimizer = HybridAdam(model.parameters(),
lr=(args.learning_rate * world_size),
weight_decay=args.weight_decay)
# Set lr scheduler
total_steps = len(dataloader) * args.num_epoch