[Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

Author:       Edenzzzz
Date:         2024-07-19 10:10:08 +08:00
Committed by: GitHub
Parent:       e86127925a
Commit:       8cc8f645cd
3 changed files with 47 additions and 21 deletions


@@ -1,4 +1,5 @@
 import time
+from contextlib import nullcontext

 import torch
 import tqdm
@@ -8,9 +9,11 @@ from transformers import AutoConfig, OPTForCausalLM
 from transformers.utils.versions import require_version

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
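Note on the two new imports (commentary, not part of the diff): LazyInitContext defers parameter allocation while the model object is being constructed, and get_accelerator() supplies the device that materialized parameters should eventually live on. A minimal sketch of the behavior, using an illustrative checkpoint name and an explicit materialize call that this example itself does not need because the booster handles it:

```python
from colossalai.accelerator import get_accelerator
from colossalai.lazy import LazyInitContext
from transformers import AutoConfig, OPTForCausalLM

# Parameters created inside the context are lazy placeholders rather than
# real tensors, so building even a large OPT model allocates little memory.
config = AutoConfig.from_pretrained("facebook/opt-1.3b")  # illustrative checkpoint
with LazyInitContext(default_device=get_accelerator().get_current_device()):
    model = OPTForCausalLM(config=config)

# When no plugin materializes the model during booster.boost(), the lazy
# parameters can be turned into real tensors explicitly (assumed API).
model = LazyInitContext.materialize(model)
```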
@@ -62,14 +65,6 @@ def main():
     if args.mem_cap > 0:
         colo_memory_cap(args.mem_cap)

-    # Build OPT model
-    config = AutoConfig.from_pretrained(args.model_name_or_path)
-    model = OPTForCausalLM(config=config)
-    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
-    # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
-
     # Set plugin
     booster_kwargs = {}
     if args.plugin == "torch_ddp_fp16":
@@ -82,6 +77,19 @@ def main():
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])

+    # Build OPT model
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, (GeminiPlugin))
+        else nullcontext()
+    )
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
+    with init_ctx:
+        model = OPTForCausalLM(config=config)
+    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)
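For context beyond this hunk: with GeminiPlugin, the lazily built parameters are only materialized (and sharded) when the model passes through booster.boost(), which is what keeps peak startup memory low. A minimal end-to-end sketch of that flow under assumed settings (illustrative checkpoint name and learning rate; run under torchrun or colossalai run):

```python
from contextlib import nullcontext

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from transformers import AutoConfig, OPTForCausalLM

colossalai.launch_from_torch()  # distributed init; expects torchrun-style env vars

plugin = GeminiPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin)

# Build the model lazily, mirroring the hunk above (checkpoint name is illustrative).
config = AutoConfig.from_pretrained("facebook/opt-1.3b")
init_ctx = (
    LazyInitContext(default_device=get_accelerator().get_current_device())
    if isinstance(plugin, GeminiPlugin)
    else nullcontext()
)
with init_ctx:
    model = OPTForCausalLM(config=config)
model.gradient_checkpointing_enable()

optimizer = HybridAdam(model.parameters(), lr=1e-4)  # illustrative learning rate

# boost() wraps model and optimizer for the plugin; with Gemini this is where
# the lazy parameters become real, sharded tensors.
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```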