[zero] low level optim supports ProcessGroup (#2464)

Authored by Jiarui Fang on 2023-01-13 10:05:58 +08:00, committed by GitHub
parent e6943e2d11
commit 867c8c2d3a
8 changed files with 106 additions and 52 deletions

@@ -290,14 +290,19 @@ def main():
         from torch.distributed.optim import ZeroRedundancyOptimizer
         optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
     elif args.distplan.startswith("zero"):
+        pg = ProcessGroup()
         model = model.half()
-        partition_flag = args.distplan == "zero2"
+        partition_flag = (args.distplan == "zero2")
         optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-        optimizer = LowLevelZeroOptimizer(optimizer,
-                                          reduce_bucket_size=12 * 1024 * 1024,
-                                          overlap_communication=True,
-                                          partition_grad=partition_flag,
-                                          verbose=True)
+        optimizer = LowLevelZeroOptimizer(
+            optimizer,
+            pg=pg,
+            reduce_bucket_size=12 * 1024 * 1024,
+            overlap_communication=True,
+            partition_grad=partition_flag,
+            verbose=True,
+        )

     # model is shared after TP
     numel = get_model_size(model)
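
For readers of this diff, below is a minimal, self-contained sketch of the new call pattern. The import paths, the SimpleNet model, and the build_zero_optimizer helper are assumptions made for illustration; only the ProcessGroup / LowLevelZeroOptimizer call mirrors what the diff actually changes, and the code assumes an already-initialized distributed environment.

# Minimal usage sketch of the new call pattern (not part of the diff).
# The import paths and the SimpleNet model below are assumptions for illustration;
# running this requires an initialized distributed context (e.g. torch.distributed / colossalai.launch).
import torch
import torch.nn as nn

from colossalai.tensor import ProcessGroup                       # assumed import path
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer  # assumed import path


class SimpleNet(nn.Module):
    """Placeholder standing in for the GPT model in the example script."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.fc(x)


def build_zero_optimizer(distplan: str = "zero2"):
    # The commit's key change: a ProcessGroup is created explicitly and
    # handed to LowLevelZeroOptimizer via the new `pg` keyword argument.
    pg = ProcessGroup()

    model = SimpleNet().cuda().half()
    partition_flag = (distplan == "zero2")     # "zero2" also partitions gradients

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    optimizer = LowLevelZeroOptimizer(
        optimizer,
        pg=pg,                                  # new argument introduced by this commit
        reduce_bucket_size=12 * 1024 * 1024,
        overlap_communication=True,
        partition_grad=partition_flag,
        verbose=True,
    )
    return model, optimizer

Passing the ProcessGroup explicitly, rather than relying on a default group, presumably lets the low-level ZeRO optimizer reduce gradients over whichever data-parallel group the caller constructs, e.g. the DP axis of a tensor-parallel layout.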