[hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch (a sketch of the assumed behavior appears after the diffs below)

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning (an illustrative sketch follows the commit metadata)

* style: remove `return_outputs=False` as it is the default value
Wenhao Chen authored 2024-03-25 12:31:09 +08:00, committed by GitHub
parent 5fcd7795cd
commit bb0a668fee
24 changed files with 28 additions and 36 deletions
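
The warning mentioned in the commit message discourages return_outputs=True: the schedule must keep every micro-batch's output alive until the whole step finishes, which is exactly the extra memory this hotfix avoids in the examples. Below is a minimal illustrative sketch of that pattern, not ColossalAI's actual implementation; the function body, the loop, and the warning text are all assumed, and only the parameter names come from the diffs below.

    from warnings import warn

    def forward_backward_step(model, data_iter, criterion, optimizer,
                              return_loss=False, return_outputs=False):
        if return_outputs:
            # Hypothetical warning in the spirit of this hotfix: retaining
            # all micro-batch outputs raises peak memory usage.
            warn("return_outputs=True retains every micro-batch output until "
                 "the step completes and may increase memory consumption.")
        outputs, total_loss = [], 0.0
        for batch in data_iter:
            out = model(batch)
            loss = criterion(out)
            loss.backward()
            total_loss += float(loss)
            if return_outputs:
                outputs.append(out)  # this list is the extra memory cost
        optimizer.step()
        ret = {"loss": total_loss} if return_loss else {}
        if return_outputs:
            ret["outputs"] = outputs
        return ret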


@@ -104,7 +104,7 @@ def run_pp(
         torch_loss.backward()
         pp_ret = schedule.forward_backward_step(
-            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         # check loss
@@ -134,7 +134,7 @@ def run_pp(
         torch_loss = criterion(torch_output)
         pp_ret = schedule.forward_backward_step(
-            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         if stage_manager.is_last_stage(ignore_chunk=True):
             assert torch.allclose(torch_loss, pp_ret["loss"])


@@ -100,7 +100,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
         torch_loss = criterion(torch_output)
         torch_loss.backward()
         pp_ret = schedule.forward_backward_step(
-            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         # check loss
@@ -130,7 +130,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
         torch_loss = criterion(torch_output)
         pp_ret = schedule.forward_backward_step(
-            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         if stage_manager.is_last_stage():
             assert torch.allclose(torch_loss, pp_ret["loss"])
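
For context on the merge_batch simplification from the first bullet: with return_outputs=True, the schedule has to merge the per-micro-batch outputs back into one full-batch result at the end of the step. A minimal sketch of what such a helper typically does (assumed behavior for illustration, not the actual ColossalAI code):

    import torch

    def merge_batch(outputs):
        # Assumed behavior: concatenate per-micro-batch tensors along the
        # batch dimension; dict outputs are merged key by key.
        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim=0)
        if isinstance(outputs[0], dict):
            return {k: merge_batch([o[k] for o in outputs]) for k in outputs[0]}
        raise TypeError(f"unsupported output type: {type(outputs[0])}")

For example, merge_batch([torch.ones(2, 4), torch.ones(2, 4)]) returns a (4, 4) tensor. Skipping that collection entirely, by leaving return_outputs at its default False, lets each micro-batch's output be freed as soon as it has been used, which is the memory saving this hotfix applies to the examples.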