[example] update ViT example using booster api (#3940)

Baizhou Zhang authored on 2023-06-12 15:02:27 +08:00; committed by GitHub
parent 1aadeedeea
commit b3ab7fbabf
17 changed files with 582 additions and 598 deletions


@@ -67,17 +67,8 @@ def main():
         colo_memory_cap(args.mem_cap)
 
     # Build OPT model
-    # Initialize the model under ColoInitContext if using GeminiPlugin
     config = AutoConfig.from_pretrained(args.model_name_or_path)
-    if args.plugin == 'gemini':
-        shard_pg = ProcessGroup(tp_degree=world_size)
-        default_dist_spec = ShardSpec([-1], [world_size])
-        with ColoInitContext(device='cpu',
-                             default_dist_spec=default_dist_spec,
-                             default_pg=shard_pg):
-            model = OPTForCausalLM(config)
-    else:
-        model = OPTForCausalLM(config)
+    model = OPTForCausalLM(config=config)
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
 
     # Enable gradient checkpointing
@@ -91,10 +82,10 @@ def main():
         plugin = TorchDDPPlugin()
     elif args.plugin == 'gemini':
         plugin = GeminiPlugin(device=get_current_device(),
-                        placement_policy='cpu',
-                        pin_memory=True,
-                        strict_ddp_mode=True,
-                        initial_scale=2**5)
+                              placement_policy='cpu',
+                              pin_memory=True,
+                              strict_ddp_mode=True,
+                              initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])


@@ -74,17 +74,8 @@ def main():
     transformers.utils.logging.set_verbosity_error()
 
     # Build OPT model
-    # Initialize the model under ColoInitContext if using GeminiPlugin
     config = AutoConfig.from_pretrained(args.model_name_or_path)
-    if args.plugin == 'gemini':
-        shard_pg = ProcessGroup(tp_degree=world_size)
-        default_dist_spec = ShardSpec([-1], [world_size])
-        with ColoInitContext(device='cpu',
-                             default_dist_spec=default_dist_spec,
-                             default_pg=shard_pg):
-            model = OPTForCausalLM(config)
-    else:
-        model = OPTForCausalLM(config)
+    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
 
     # Enable gradient checkpointing
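This second file now loads pretrained weights directly instead of constructing the model from a config alone. A short sketch of the loading step together with the gradient-checkpointing switch the trailing comment refers to, assuming args.model_name_or_path as in the example; gradient_checkpointing_enable() is the standard transformers toggle, not something introduced by this commit:

from transformers import AutoConfig, OPTForCausalLM

config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)

# Recompute activations in backward to cut activation memory on large models.
model.gradient_checkpointing_enable()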
@@ -116,7 +107,9 @@ def main():
                             collate_fn=netflix_collator)
 
     # Set optimizer
-    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
+    optimizer = HybridAdam(model.parameters(),
+                           lr=(args.learning_rate * world_size),
+                           weight_decay=args.weight_decay)
 
     # Set lr scheduler
     total_steps = len(dataloader) * args.num_epoch
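With the booster in place, the backward pass goes through the booster rather than a bare loss.backward(), so each plugin (TorchDDP, Gemini, low-level zero) can apply its own gradient handling. A minimal sketch of one training step, assuming the boosted model, optimizer and lr_scheduler from the earlier sketch and HF-style batches whose forward returns a .loss:

model.train()
for epoch in range(args.num_epoch):
    for batch in dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        # Let the plugin route the backward pass (grad scaling, sharding, etc.).
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()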