diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index bb10f7a00..8d5b7c8db 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -273,11 +273,10 @@ def main(): dataloader.sampler.set_start_index(sampler_start_idx) for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch) - step_nums = num_steps_per_epoch - start_step dataloader_iter = iter(dataloader) with tqdm( - range(step_nums), + range(start_step, num_steps_per_epoch), desc=f"Epoch {epoch}", disable=not print_flag, total=num_steps_per_epoch,