[example] add zero1, zero2 example in GPT examples (#2146)

* [example] add zero1 and zero2 for GPT

* update readme in gpt example

* polish code

* change init value

* update readme
Author: HELSON
Date: 2022-12-20 14:30:27 +08:00
Committed by: GitHub
Parent: 1cce6e36ca
Commit: a7d95b7024
5 changed files with 40 additions and 27 deletions
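
In short, this diff renames the old ddp and zero plans to torch_ddp and torch_zero, and adds two new plans, zero1 and zero2, which wrap a plain torch.optim.Adam in ColossalAI's LowLevelZeroOptimizer; for the new plans the training loop calls optimizer.sync_grad() before optimizer.step().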


@@ -6,6 +6,7 @@ import torch
 import torch.nn as nn
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
+from transformers import GPT2Config, GPT2LMHeadModel
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
@@ -16,7 +17,7 @@ from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from transformers import GPT2Config, GPT2LMHeadModel
+from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
 
 
 def parse_args():
@@ -25,7 +26,7 @@ def parse_args():
         "--distplan",
         type=str,
         default='colossalai',
-        help="The distributed plan [colossalai, ddp, zero].",
+        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
     )
     parser.add_argument(
         "--tp_degree",
@@ -202,6 +203,9 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
 def main():
     args = parse_args()
 
+    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+        raise TypeError(f"{args.distplan} is error")
+
     BATCH_SIZE = 8
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257
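
Note: the membership check above could equivalently live in the parser itself via argparse's choices, which rejects an unknown plan at parse time. A minimal standalone sketch of the same validation (not part of this commit):

import argparse

# Standalone sketch, not from the commit: argparse enforces the plan
# list itself, so an invalid --distplan fails before main() runs.
parser = argparse.ArgumentParser()
parser.add_argument("--distplan",
                    type=str,
                    default="colossalai",
                    choices=["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"],
                    help="The distributed plan.")

args = parser.parse_args(["--distplan", "zero2"])
assert args.distplan == "zero2"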
@@ -237,19 +241,24 @@ def main():
         # optimizer = HybridAdam(model.parameters(), lr=1e-3)
         # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
-    elif args.distplan == "ddp":
-        model = gpt2_medium(checkpoint=True).cuda()
-        ddp_model = DDP(model)
-        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
-
-    elif args.distplan == "zero":
-        from torch.distributed.optim import ZeroRedundancyOptimizer
-        model = gpt2_medium(checkpoint=True).cuda()
-        ddp_model = DDP(model)
-        optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
-
     else:
-        raise TypeError(f"{args.distplan} is error")
+        model = gpt2_medium(checkpoint=True).cuda()
+
+        if args.distplan.startswith("torch"):
+            model = DDP(model)
+            if args.distplan.endswith("ddp"):
+                optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+            elif args.distplan.endswith("zero"):
+                from torch.distributed.optim import ZeroRedundancyOptimizer
+                optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
+        elif args.distplan.startswith("zero"):
+            partition_flag = args.distplan == "zero2"
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+            optimizer = LowLevelZeroOptimizer(optimizer,
+                                              overlap_communication=True,
+                                              partition_grad=partition_flag,
+                                              verbose=True)
 
+    # notice that the model is still in fp32
     numel = sum([p.numel() for p in model.parameters()])
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
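
The partition_flag above is what distinguishes the two new plans: with partition_grad=False the wrapper implements ZeRO stage 1 (optimizer states sharded across data-parallel ranks), while partition_grad=True additionally shards the reduced gradients, i.e. ZeRO stage 2. A minimal sketch of the wrapping pattern, assuming a ColossalAI process group has been launched (e.g. under torchrun) and substituting a toy module for GPT-2:

import torch
import colossalai
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer

# Sketch only: assumes this runs under torchrun with the usual env vars set.
colossalai.launch_from_torch(config={})

model = torch.nn.Linear(1024, 1024).cuda()    # stand-in for gpt2_medium
base_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# partition_grad=False -> ZeRO-1 (shard optimizer states only)
# partition_grad=True  -> ZeRO-2 (also shard the reduced gradients)
optimizer = LowLevelZeroOptimizer(base_optimizer,
                                  overlap_communication=True,
                                  partition_grad=True,
                                  verbose=True)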
@@ -265,12 +274,13 @@ def main():
         outputs = model(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
-        if args.distplan == "colossalai":
+        if args.distplan in ["colossalai", "zero1", "zero2"]:
             optimizer.backward(loss)
-        elif args.distplan in ["ddp", "zero"]:
+        elif args.distplan in ["torch_ddp", "torch_zero"]:
             loss.backward()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
-
+        if args.distplan in ["zero1", "zero2"]:
+            optimizer.sync_grad()
         optimizer.step()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
         step_time = time() - start
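
The ordering in the loop matters for the new plans: optimizer.backward(loss) routes the backward pass through the wrapper so gradients can be reduced into per-rank buckets, and optimizer.sync_grad() must run after backward and before optimizer.step() so each rank holds up-to-date gradients for its shard. A hypothetical single step, reusing model and optimizer from the sketch after the previous hunk:

x = torch.randn(8, 1024, device="cuda")
loss = model(x).sum()              # toy loss standing in for the GPT-2 LM loss

optimizer.backward(loss)           # replaces plain loss.backward() under zero1/zero2
optimizer.sync_grad()              # required before step(), as in the diff above
optimizer.step()                   # update parameters from the sharded states
optimizer.zero_grad()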