From f7e3f82a7e49057593932bbd8d9f1797c7584bf6 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 19 Jan 2024 17:49:02 +0800 Subject: [PATCH] fix llama pretrain (#5287) --- examples/language/llama2/pretrain.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index bb10f7a00..8d5b7c8db 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -273,11 +273,10 @@ def main(): dataloader.sampler.set_start_index(sampler_start_idx) for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch) - step_nums = num_steps_per_epoch - start_step dataloader_iter = iter(dataloader) with tqdm( - range(step_nums), + range(start_step, num_steps_per_epoch), desc=f"Epoch {epoch}", disable=not print_flag, total=num_steps_per_epoch,