mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-29 00:38:02 +00:00
[llama] fix neftune & pbar with start_step (#5364)
This commit is contained in:
parent
a4cec1715b
commit
44ca61a22b
@ -17,7 +17,7 @@ import torch
|
|||||||
|
|
||||||
def unwrap(model):
|
def unwrap(model):
|
||||||
if hasattr(model, "module"):
|
if hasattr(model, "module"):
|
||||||
return unwrap_model(model.module)
|
return model.unwrap()
|
||||||
else:
|
else:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -329,9 +329,9 @@ def main() -> None:
|
|||||||
|
|
||||||
for epoch in range(start_epoch, args.num_epochs):
|
for epoch in range(start_epoch, args.num_epochs):
|
||||||
dataloader.sampler.set_epoch(epoch=epoch)
|
dataloader.sampler.set_epoch(epoch=epoch)
|
||||||
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
|
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
|
||||||
total_loss = torch.tensor(0.0, device=get_current_device())
|
total_loss = torch.tensor(0.0, device=get_current_device())
|
||||||
for step, batch in enumerate(dataloader):
|
for step, batch in enumerate(dataloader, start=start_step):
|
||||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||||
|
|
||||||
batch_output = model(**batch)
|
batch_output = model(**batch)
|
||||||
|
Loading…
Reference in New Issue
Block a user