[fix] fix linear (no tp) ops func name;

This commit is contained in:
duanjunwen
2024-10-31 08:18:28 +00:00
parent d2e05a99b3
commit 5f0924361d
6 changed files with 19 additions and 41 deletions

View File

@@ -366,10 +366,10 @@ def main():
)
loss = outputs["loss"]
if args.pp_style == "zbv":
if dist.get_rank() == 0:
if coordinator.is_master():
print(f"Step {step} loss: {loss}")
else:
if dist.get_rank() == dist.get_world_size() - 1:
if coordinator.is_last_process():
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()