[llama] fix neftune & pbar with start_step (#5364)

This commit is contained in:
Camille Zhong 2024-02-05 18:04:23 +08:00 committed by GitHub
parent a4cec1715b
commit 44ca61a22b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@@ -17,7 +17,7 @@ import torch
 def unwrap(model):
     if hasattr(model, "module"):
-        return unwrap_model(model.module)
+        return model.unwrap()
     else:
         return model

View File

@@ -329,9 +329,9 @@ def main() -> None:
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
-        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
+        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
         total_loss = torch.tensor(0.0, device=get_current_device())
-        for step, batch in enumerate(dataloader):
+        for step, batch in enumerate(dataloader, start=start_step):
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
             batch_output = model(**batch)