[example] update gpt gemini example ci test (#2477)
@@ -65,6 +65,7 @@ def parse_args():
         default="gpt2_medium",
         help="model model scale",
     )
+    parser.add_argument("--steps", type=int, default=10, help="num of training steps")
     args = parser.parse_args()
     return args

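For context, this first hunk exposes the step count as a CLI flag so the CI test can run a short job instead of a hard-coded 10 iterations. A minimal, self-contained reproduction of the argparse behaviour (the ["--steps", "4"] argv is an illustrative value, not taken from the CI script):

import argparse

# Reproduce the parser change above: --steps controls the number of training
# iterations instead of a hard-coded constant.
parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=10, help="num of training steps")

args = parser.parse_args(["--steps", "4"])  # e.g. a short CI run
assert args.steps == 4                      # 4 - 1 warmup step = 3 valid steps (odd, see next hunk)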
@@ -236,7 +237,7 @@ def main():
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257

-    NUM_STEPS = 10
+    NUM_STEPS = args.steps
     WARMUP_STEPS = 1
     assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
     assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
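The NUM_STEPS/WARMUP_STEPS assertions in the context lines matter because the example reports the median of the post-warmup step times, and an odd sample count guarantees the median is an actual measured step rather than an average of two. A small sketch of that reasoning (the timings below are made-up values for illustration):

import statistics

NUM_STEPS = 10     # matches the default of the new --steps flag
WARMUP_STEPS = 1
step_times = [0.31, 0.12, 0.11, 0.13, 0.12, 0.14, 0.11, 0.12, 0.13, 0.12]  # hypothetical seconds

valid = step_times[WARMUP_STEPS:]            # drop the warmup measurement
assert WARMUP_STEPS < NUM_STEPS
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1   # odd count -> the median is a real sample
print(statistics.median(valid))              # prints 0.12 for the list above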
@@ -290,14 +291,12 @@ def main():
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
    elif args.distplan.startswith("zero"):
        pg = ProcessGroup()
        model = model.half()
        partition_flag = (args.distplan == "zero2")
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        optimizer = LowLevelZeroOptimizer(
            optimizer,
            pg=pg,
            reduce_bucket_size=12 * 1024 * 1024,
            overlap_communication=True,
            partition_grad=partition_flag,
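As context for this last hunk: the zero1/zero2 branch wraps a plain torch.optim.Adam with ColossalAI's LowLevelZeroOptimizer, and the only difference between the two plans in this snippet is the partition_grad flag. A minimal sketch of that selection logic, reusing only the names and values that already appear in the hunk (the helper function itself is hypothetical):

# Hypothetical helper: map the --distplan value onto the ZeRO settings used above.
# "zero1" keeps full gradients on every rank; "zero2" additionally partitions them.
def zero_config(distplan: str) -> dict:
    assert distplan in ("zero1", "zero2"), "only the ZeRO plans are handled here"
    partition_flag = (distplan == "zero2")
    return dict(
        reduce_bucket_size=12 * 1024 * 1024,  # bucket gradient reductions into ~12 MiB chunks
        overlap_communication=True,           # overlap reduction with backward computation
        partition_grad=partition_flag,        # ZeRO-2 shards gradients, ZeRO-1 does not
    )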