mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[hotfix] set return_outputs=False in examples and polish code (#5404)
* fix: simplify merge_batch * fix: use return_outputs=False to eliminate extra memory consumption * feat: add return_outputs warning * style: remove `return_outputs=False` as it is the default value
This commit is contained in:
@@ -104,7 +104,7 @@ def run_pp(
|
||||
torch_loss.backward()
|
||||
|
||||
pp_ret = schedule.forward_backward_step(
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
|
||||
)
|
||||
|
||||
# check loss
|
||||
@@ -134,7 +134,7 @@ def run_pp(
|
||||
torch_loss = criterion(torch_output)
|
||||
|
||||
pp_ret = schedule.forward_backward_step(
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
|
||||
)
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||
|
@@ -100,7 +100,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
|
||||
torch_loss = criterion(torch_output)
|
||||
torch_loss.backward()
|
||||
pp_ret = schedule.forward_backward_step(
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
|
||||
)
|
||||
|
||||
# check loss
|
||||
@@ -130,7 +130,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
|
||||
torch_loss = criterion(torch_output)
|
||||
|
||||
pp_ret = schedule.forward_backward_step(
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||
|
Reference in New Issue
Block a user