mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] low level optim supports ProcessGroup (#2464)
@@ -290,14 +290,19 @@ def main():
         from torch.distributed.optim import ZeroRedundancyOptimizer
         optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
     elif args.distplan.startswith("zero"):
+        pg = ProcessGroup()
         model = model.half()
-        partition_flag = args.distplan == "zero2"
+        partition_flag = (args.distplan == "zero2")
         optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-        optimizer = LowLevelZeroOptimizer(optimizer,
-                                          reduce_bucket_size=12 * 1024 * 1024,
-                                          overlap_communication=True,
-                                          partition_grad=partition_flag,
-                                          verbose=True)
+
+        optimizer = LowLevelZeroOptimizer(
+            optimizer,
+            pg=pg,
+            reduce_bucket_size=12 * 1024 * 1024,
+            overlap_communication=True,
+            partition_grad=partition_flag,
+            verbose=True,
+        )
 
     # model is shared after TP
     numel = get_model_size(model)
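For illustration, a minimal, self-contained sketch of the call pattern after this change. The helper name build_zero_optimizer and the import paths are assumptions made for this sketch (module locations vary across ColossalAI versions); the LowLevelZeroOptimizer arguments themselves mirror the diff above.

import torch
# Assumed import paths for this sketch; check your ColossalAI version.
from colossalai.tensor import ProcessGroup
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer


def build_zero_optimizer(model, distplan: str = "zero2"):
    """Hypothetical helper wrapping the pattern shown in the diff."""
    pg = ProcessGroup()                       # default group, constructed as in the diff
    model = model.half()                      # the low-level ZeRO branch trains in fp16
    partition_flag = (distplan == "zero2")    # "zero2" additionally partitions gradients
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # The `pg` keyword is what this commit adds: the optimizer now receives
    # an explicit ProcessGroup instead of relying on the implicit global one.
    zero_optimizer = LowLevelZeroOptimizer(
        optimizer,
        pg=pg,
        reduce_bucket_size=12 * 1024 * 1024,
        overlap_communication=True,
        partition_grad=partition_flag,
        verbose=True,
    )
    return model, zero_optimizer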