mirror of https://github.com/hpcaitech/ColossalAI.git
[Gemini] add GeminiAdamOptimizer (#1960)
@@ -3,10 +3,8 @@ This example shows how to use Colossal-AI to run huggingface GPT training in dis
 ## GPT

 We use the GPT2 model from huggingface transformers. The input data is randomly generated.

 ## Our Modifications

-The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
-The Colossal-AI leverages Tensor Parallel and Gemini.
+The `train_gpt_demo.py` provides three distributed plans, i.e. ColossalAI, PyTorch DDP and ZeRO.
+The ColossalAI leverages Tensor Parallel and Gemini.

 ## Quick Start

 You can launch training by using the following bash script.
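For orientation, the plan is selected at runtime through a `--distplan` flag; only the `ddp` branch and `args.tp_degree` are visible in the hunks below, so the following is just a hypothetical sketch of how the three choices could be exposed on the command line (the other choice strings and the defaults are assumptions):

```python
import argparse

# Hypothetical CLI wiring for the three plans named in the README.
# Only args.distplan == "ddp" and args.tp_degree appear in this diff;
# the remaining choice strings and defaults are assumptions for illustration.
parser = argparse.ArgumentParser()
parser.add_argument("--distplan", type=str, default="colossalai",
                    choices=["colossalai", "ddp", "zero"],
                    help="which distributed training plan to use")
parser.add_argument("--tp_degree", type=int, default=1,
                    help="tensor parallel degree used by the ColossalAI plan")
args = parser.parse_args()
```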
@@ -10,11 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import GPT2Config, GPT2LMHeadModel
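The added imports carry the substance of the change: `GeminiAdamOptimizer` packages the Adam optimizer and the ZeRO wrapper behind one constructor. Below is a minimal sketch of what such a wrapper could look like, assuming it simply composes `HybridAdam` with `ZeroOptimizer`; the call pattern mirrors the optimizer hunk further down, but this is not claimed to be the library's exact implementation.

```python
import torch
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer


class GeminiAdamOptimizer(ZeroOptimizer):
    """Sketch of a convenience wrapper: build HybridAdam over the model's
    parameters, then wrap it in ZeroOptimizer so Gemini manages optimizer
    states and loss scaling. The default initial_scale is an assumed value."""

    def __init__(self, model: torch.nn.Module, initial_scale: float = 2**16, **adam_kwargs) -> None:
        adam = HybridAdam(model.parameters(), **adam_kwargs)
        super().__init__(adam, model, initial_scale=initial_scale)
```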
@@ -222,7 +223,7 @@ def main():
         default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None

         # build GPT model
-        with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
+        with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
             model = gpt2_medium(checkpoint=True)

         pg = default_pg
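The model is now materialized on the CPU inside `ColoInitContext` instead of directly on CUDA, which lets Gemini decide parameter placement afterwards without first holding the full weights on a single GPU (that rationale is an inference, not stated in the commit). A small, self-contained sketch of the same initialization pattern; `ProcessGroup(tp_degree=...)` is assumed from the surrounding example, and `GPT2LMHeadModel` stands in for `gpt2_medium`:

```python
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils.model.colo_init_context import ColoInitContext
from transformers import GPT2Config, GPT2LMHeadModel

tp_degree = 2                                      # illustrative value
default_pg = ProcessGroup(tp_degree=tp_degree)     # assumed constructor keyword
default_dist_spec = ShardSpec([-1], [tp_degree])   # shard the last dim across TP ranks

# Parameters are created as ColoParameters on the CPU with the given shard spec;
# Gemini/ZeroDDP later decides where (and in what precision) they actually live.
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
    model = GPT2LMHeadModel(GPT2Config())
```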
@@ -232,8 +233,9 @@ def main():
         model = gemini_zero_dpp(model, pg, args.placement)

         # build optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-        optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
+        optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
+        # optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])

     elif args.distplan == "ddp":
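Side by side, the change collapses the two-step optimizer construction into a single call. The training-step lines at the end are an assumption based on the usual ZeRO-wrapped optimizer API (the backward pass is routed through the optimizer so it can apply loss scaling), not code from this diff; `criterion`, `input_ids`, and `attention_mask` are placeholders.

```python
# Before: build Adam first, then wrap it in the ZeRO optimizer explicitly.
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)

# After: one constructor does both, with the same lr and initial loss scale.
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)

# Assumed usage inside the training loop: the wrapped optimizer owns gradient
# handling, so the backward pass goes through optimizer.backward(loss)
# rather than loss.backward().
optimizer.zero_grad()
outputs = model(input_ids, attention_mask=attention_mask)
loss = criterion(outputs, input_ids)
optimizer.backward(loss)
optimizer.step()
```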