[Gemini] add GeminiAdamOptimizer (#1960)

Author: Jiarui Fang
Date: 2022-11-16 14:44:28 +08:00
Committed by: GitHub
Parent: 7066dfbf82
Commit: f7e276fa71
12 changed files with 66 additions and 44 deletions

@@ -3,10 +3,8 @@ This example shows how to use Colossal-AI to run huggingface GPT training in dis
## GPT
We use the GPT2 model from huggingface transformers. The input data is randomly generated.
## Our Modifications
-The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
-The Colossal-AI leverages Tensor Parallel and Gemini.
+The `train_gpt_demo.py` provides three distributed plans, i.e. ColossalAI, PyTorch DDP and ZeRO.
+The ColossalAI leverages Tensor Parallel and Gemini.
## Quick Start
You can launch training by using the following bash script.

@@ -10,11 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel
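
The module path `colossalai.nn.optimizer.gemini_optimizer` comes straight from the import added above, but the diff of that new file is not part of this excerpt. Judging by the optimizer change further down, `GeminiAdamOptimizer` folds the former two-step `HybridAdam` + `ZeroOptimizer` construction into a single call. The class below is only a hypothetical sketch of such a wrapper; the name `GeminiAdamSketch` and the explicit `lr`/`initial_scale` signature are illustrative, not the commit's actual code.

```python
# Hypothetical sketch, NOT the commit's source for GeminiAdamOptimizer.
import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer


class GeminiAdamSketch(ZeroOptimizer):
    """Illustrative wrapper combining HybridAdam and ZeroOptimizer in one constructor.

    `model` is expected to be the Gemini/ZeroDDP-wrapped module, as produced by
    the example's `gemini_zero_dpp` helper.
    """

    def __init__(self, model: torch.nn.Module, lr: float = 1e-3, initial_scale: float = 2**5):
        # Build the underlying Adam optimizer over the wrapped model's parameters ...
        adam = HybridAdam(model.parameters(), lr=lr)
        # ... then hand it to ZeroOptimizer, which takes over the ZeRO/Gemini
        # bookkeeping and loss scaling (initial_scale), exactly as the example
        # previously did by hand.
        super().__init__(adam, model, initial_scale=initial_scale)
```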
@@ -222,7 +223,7 @@ def main():
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
# build GPT model
-with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
+with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
model = gpt2_medium(checkpoint=True)
pg = default_pg
@@ -232,8 +233,9 @@ def main():
model = gemini_zero_dpp(model, pg, args.placement)
# build optimizer
-optimizer = HybridAdam(model.parameters(), lr=1e-3)
-optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
+optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
+# optimizer = HybridAdam(model.parameters(), lr=1e-3)
+# optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
elif args.distplan == "ddp":
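
Read as plain Python rather than a diff, the optimizer swap in the hunk above amounts to the following. This is a minimal sketch: it assumes `model` is already the Gemini/ZeroDDP-wrapped GPT2 module returned by the example's `gemini_zero_dpp` helper, and the two function names are mine, added only to contrast the before and after.

```python
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer


def build_optimizer_before(model):
    # Pre-commit pattern: construct HybridAdam, then wrap it in ZeroOptimizer by hand.
    adam = HybridAdam(model.parameters(), lr=1e-3)
    return ZeroOptimizer(adam, model, initial_scale=2**5)


def build_optimizer_after(model):
    # Post-commit pattern: one call builds the Adam optimizer and the ZeRO wrapper together.
    return GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
```

Either way, the rest of the script keeps using the returned optimizer object as before; only the construction site changes.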