[example] add zero1, zero2 example in GPT examples (#2146)
* [example] add zero1 and zero2 for GPT
* update readme in gpt example
* polish code
* change init value
* update readme
@@ -6,6 +6,7 @@ import torch
 import torch.nn as nn
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
+from transformers import GPT2Config, GPT2LMHeadModel
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
@@ -16,7 +17,7 @@ from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from transformers import GPT2Config, GPT2LMHeadModel
+from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
 
 
 def parse_args():
@@ -25,7 +26,7 @@ def parse_args():
         "--distplan",
         type=str,
         default='colossalai',
-        help="The distributed plan [colossalai, ddp, zero].",
+        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
     )
     parser.add_argument(
         "--tp_degree",
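The help string now enumerates the five supported plans. As a hedged illustration only (not part of this commit, which keeps a free-form string and validates it later in main()), the same constraint could be expressed with argparse choices; the plain argparse.ArgumentParser here is a simplification of the parser the example actually builds:

# Hypothetical alternative: let argparse reject unknown plans up front.
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--distplan",
        type=str,
        default="colossalai",
        choices=["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"],
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    return parser.parse_args()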
@@ -202,6 +203,9 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
 def main():
     args = parse_args()
 
+    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+        raise TypeError(f"{args.distplan} is error")
+
     BATCH_SIZE = 8
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257
@@ -237,19 +241,24 @@ def main():
         # optimizer = HybridAdam(model.parameters(), lr=1e-3)
         # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
-
-    elif args.distplan == "ddp":
-        model = gpt2_medium(checkpoint=True).cuda()
-        ddp_model = DDP(model)
-        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
-
-    elif args.distplan == "zero":
-        from torch.distributed.optim import ZeroRedundancyOptimizer
-        model = gpt2_medium(checkpoint=True).cuda()
-        ddp_model = DDP(model)
-        optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
     else:
-        raise TypeError(f"{args.distplan} is error")
+        model = gpt2_medium(checkpoint=True).cuda()
+
+    if args.distplan.startswith("torch"):
+        model = DDP(model)
+        if args.distplan.endswith("ddp"):
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+        elif args.distplan.endswith("zero"):
+            from torch.distributed.optim import ZeroRedundancyOptimizer
+            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
+    elif args.distplan.startswith("zero"):
+        partition_flag = args.distplan == "zero2"
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+        optimizer = LowLevelZeroOptimizer(optimizer,
+                                          overlap_communication=True,
+                                          partition_grad=partition_flag,
+                                          verbose=True)
+    # notice that the model is still in fp32
 
     numel = sum([p.numel() for p in model.parameters()])
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
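For zero1 and zero2 the example wraps a plain torch Adam optimizer in LowLevelZeroOptimizer; zero2 differs from zero1 only in setting partition_grad=True so that gradients are sharded in addition to optimizer states. Below is a minimal, hedged sketch of the same wiring outside the GPT example, reusing only the calls visible in this diff; the toy nn.Linear model and the launch_from_torch call are assumptions for illustration, not part of this commit:

# Minimal sketch: build a ZeRO stage-1/2 optimizer the way this example does.
# Run under torchrun so launch_from_torch can read the torch.distributed env vars.
import torch
import torch.nn as nn

import colossalai
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer


def build_zero_optimizer(model: nn.Module, distplan: str = "zero2") -> LowLevelZeroOptimizer:
    # zero2 additionally partitions gradients across ranks; zero1 keeps them replicated.
    partition_flag = distplan == "zero2"
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    return LowLevelZeroOptimizer(optimizer,
                                 overlap_communication=True,
                                 partition_grad=partition_flag,
                                 verbose=True)


if __name__ == "__main__":
    colossalai.launch_from_torch(config={})
    model = nn.Linear(1024, 1024).cuda()    # toy stand-in for gpt2_medium
    optimizer = build_zero_optimizer(model, distplan="zero2")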
@@ -265,12 +274,13 @@ def main():
         outputs = model(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
-        if args.distplan == "colossalai":
+        if args.distplan in ["colossalai", "zero1", "zero2"]:
             optimizer.backward(loss)
-        elif args.distplan in ["ddp", "zero"]:
+        elif args.distplan in ["torch_ddp", "torch_zero"]:
             loss.backward()
-
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
+        if args.distplan in ["zero1", "zero2"]:
+            optimizer.sync_grad()
         optimizer.step()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
         step_time = time() - start
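The training step now branches on the plan family: ColossalAI-managed plans (colossalai, zero1, zero2) run backward through the optimizer wrapper, the torch plans call loss.backward() directly, and the low-level ZeRO plans synchronize gradients before stepping. A hedged sketch of a single step, assuming model, optimizer and criterion were built as in the hunks above:

# Sketch of one training step per distplan family (plan names mirror the example's).
def train_step(distplan, model, optimizer, criterion, input_ids, attn_mask):
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, input_ids)

    if distplan in ["colossalai", "zero1", "zero2"]:
        # The ColossalAI optimizer wrappers hook into backward themselves.
        optimizer.backward(loss)
    elif distplan in ["torch_ddp", "torch_zero"]:
        loss.backward()

    if distplan in ["zero1", "zero2"]:
        # LowLevelZeroOptimizer reduces the sharded gradients across ranks
        # before the parameter update.
        optimizer.sync_grad()
    optimizer.step()
    return loss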