fix llama pretrain (#5287)

This commit is contained in:
flybird11111
2024-01-19 17:49:02 +08:00
committed by GitHub
parent 6a56967855
commit f7e3f82a7e


@@ -273,11 +273,10 @@ def main():
     dataloader.sampler.set_start_index(sampler_start_idx)
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch)
-        step_nums = num_steps_per_epoch - start_step
         dataloader_iter = iter(dataloader)
 
         with tqdm(
-            range(step_nums),
+            range(start_step, num_steps_per_epoch),
             desc=f"Epoch {epoch}",
             disable=not print_flag,
             total=num_steps_per_epoch,
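
For readers skimming the diff: the old code built a shrunken range(step_nums) of length num_steps_per_epoch - start_step, so on resume the step indices it yielded no longer matched the absolute position in the epoch, and the progress bar could not reach its stated total. Iterating range(start_step, num_steps_per_epoch) keeps the step counter aligned with the epoch while total=num_steps_per_epoch keeps the bar full-length. Below is a minimal, self-contained sketch of that resume pattern; the run_epochs helper, the plain list standing in for the dataloader, and the initial=start_step keyword are illustrative assumptions, not part of the patched script.

from tqdm import tqdm

def run_epochs(dataloader, num_epochs, start_epoch=0, start_step=0, print_flag=True):
    # Resume-aware loop: the resumed epoch starts at `start_step`,
    # later epochs run in full from step 0.
    num_steps_per_epoch = len(dataloader)
    for epoch in range(start_epoch, num_epochs):
        dataloader_iter = iter(dataloader)
        with tqdm(
            range(start_step, num_steps_per_epoch),  # iterate only the remaining steps
            desc=f"Epoch {epoch}",
            disable=not print_flag,
            total=num_steps_per_epoch,  # bar length covers the whole epoch
            initial=start_step,         # display assumption: bar resumes at start_step
        ) as pbar:
            for step in pbar:
                batch = next(dataloader_iter)  # in the real script the sampler is already offset to the resume point
                # ... forward / backward / optimizer.step() would go here ...
        start_step = 0  # only the first (resumed) epoch starts part-way through

# Toy usage: a plain list stands in for the dataloader.
if __name__ == "__main__":
    run_epochs(dataloader=list(range(8)), num_epochs=2, start_epoch=0, start_step=3)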