[example] update ViT example using booster api (#3940)
Repository: https://github.com/hpcaitech/ColossalAI.git (mirror)
@@ -67,17 +67,8 @@ def main():
         colo_memory_cap(args.mem_cap)
 
     # Build OPT model
-    # Initialize the model under ColoInitContext if using GeminiPlugin
     config = AutoConfig.from_pretrained(args.model_name_or_path)
-    if args.plugin == 'gemini':
-        shard_pg = ProcessGroup(tp_degree=world_size)
-        default_dist_spec = ShardSpec([-1], [world_size])
-        with ColoInitContext(device='cpu',
-                             default_dist_spec=default_dist_spec,
-                             default_pg=shard_pg):
-            model = OPTForCausalLM(config)
-    else:
-        model = OPTForCausalLM(config)
+    model = OPTForCausalLM(config=config)
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
 
     # Enable gradient checkpointing
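Note on the hunk above: the old path needed ColoInitContext, ProcessGroup and ShardSpec to lay out parameters at construction time; under the booster API the model is built as a plain torch module and the plugin takes over placement when booster.boost() is called. Below is a minimal sketch of that flow, assuming ColossalAI's Booster/GeminiPlugin/HybridAdam APIs that this diff migrates to; the checkpoint name and learning rate are illustrative, not from the diff.

# Minimal sketch (run under torchrun): model construction with the booster API.
# Booster, GeminiPlugin and HybridAdam are the ColossalAI APIs used by this
# example; 'facebook/opt-125m' and lr are illustrative stand-ins.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import AutoConfig, OPTForCausalLM

colossalai.launch_from_torch(config={})

config = AutoConfig.from_pretrained('facebook/opt-125m')
model = OPTForCausalLM(config=config)    # plain module; no ColoInitContext needed
optimizer = HybridAdam(model.parameters(), lr=2e-5)

plugin = GeminiPlugin(placement_policy='cpu', pin_memory=True,
                      strict_ddp_mode=True, initial_scale=2**5)
booster = Booster(plugin=plugin)
# boost() shards the model and wraps the optimizer for Gemini
model, optimizer, *_ = booster.boost(model, optimizer)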
@@ -91,10 +82,10 @@ def main():
         plugin = TorchDDPPlugin()
     elif args.plugin == 'gemini':
-        plugin = GeminiPlugin(device=get_current_device(),
-                              placement_policy='cpu',
-                              pin_memory=True,
-                              strict_ddp_mode=True,
-                              initial_scale=2**5)
+        plugin = GeminiPlugin(placement_policy='cpu',
+                              pin_memory=True,
+                              strict_ddp_mode=True,
+                              initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
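For reference, here is the plugin switch this hunk edits in runnable form, with the device=get_current_device() argument dropped as in the diff. The argparse wiring is an assumption modeled on typical ColossalAI examples, not taken from this commit.

# Sketch of the --plugin switch; argparse wiring is assumed.
import argparse
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin, GeminiPlugin, LowLevelZeroPlugin

parser = argparse.ArgumentParser()
parser.add_argument('--plugin', default='gemini',
                    choices=['torch_ddp', 'gemini', 'low_level_zero'])
args = parser.parse_args()

if args.plugin == 'torch_ddp':
    plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
    # the diff removes device=get_current_device() from this call
    plugin = GeminiPlugin(placement_policy='cpu', pin_memory=True,
                          strict_ddp_mode=True, initial_scale=2**5)
else:  # low_level_zero
    plugin = LowLevelZeroPlugin(initial_scale=2**5)

booster = Booster(plugin=plugin)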
@@ -74,17 +74,8 @@ def main():
     transformers.utils.logging.set_verbosity_error()
 
     # Build OPT model
-    # Initialize the model under ColoInitContext if using GeminiPlugin
     config = AutoConfig.from_pretrained(args.model_name_or_path)
-    if args.plugin == 'gemini':
-        shard_pg = ProcessGroup(tp_degree=world_size)
-        default_dist_spec = ShardSpec([-1], [world_size])
-        with ColoInitContext(device='cpu',
-                             default_dist_spec=default_dist_spec,
-                             default_pg=shard_pg):
-            model = OPTForCausalLM(config)
-    else:
-        model = OPTForCausalLM(config)
+    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
 
     # Enable gradient checkpointing
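Unlike the first file, this one keeps the pretrained weights via from_pretrained() rather than building a fresh config-initialized model. A short sketch of that pattern, plus the gradient checkpointing step the context line above refers to; gradient_checkpointing_enable() is the standard transformers call, and 'facebook/opt-125m' stands in for args.model_name_or_path.

# Sketch: load pretrained OPT weights, then enable gradient checkpointing.
from transformers import AutoConfig, OPTForCausalLM

model_name_or_path = 'facebook/opt-125m'   # stand-in for args.model_name_or_path
config = AutoConfig.from_pretrained(model_name_or_path)
model = OPTForCausalLM.from_pretrained(model_name_or_path, config=config)
model.gradient_checkpointing_enable()      # recompute activations to save memory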
@@ -116,7 +107,9 @@ def main():
                             collate_fn=netflix_collator)
 
     # Set optimizer
-    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
+    optimizer = HybridAdam(model.parameters(),
+                           lr=(args.learning_rate * world_size),
+                           weight_decay=args.weight_decay)
 
     # Set lr scheduler
     total_steps = len(dataloader) * args.num_epoch
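The optimizer gains a weight_decay term in this hunk. Below is a self-contained sketch of the optimizer and scheduler wiring around it: HybridAdam and the total_steps computation come from the diff, while the toy model, data, and the cosine schedule are stand-ins, since the example's actual scheduler is not visible in this excerpt.

# Self-contained sketch of the optimizer/scheduler setup around this hunk.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from colossalai.nn.optimizer import HybridAdam

# Illustrative stand-ins for the example's args, model and dataset.
learning_rate, weight_decay, num_epoch, world_size = 5e-5, 0.01, 2, 1
model = torch.nn.Linear(16, 16)                      # stands in for the OPT model
dataloader = DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=8)

optimizer = HybridAdam(model.parameters(),
                       lr=learning_rate * world_size,
                       weight_decay=weight_decay)
total_steps = len(dataloader) * num_epoch            # as in the hunk above
lr_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)
# model, optimizer, dataloader and lr_scheduler would then pass through
# booster.boost(...) before the training loop.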