[zero] fix unit-tests (#2039)

This commit is contained in:
HELSON
2022-11-30 10:40:31 +08:00
committed by GitHub
parent eb7742a4bb
commit 17a3c685b0
4 changed files with 44 additions and 44 deletions

View File

@@ -1,7 +1,7 @@
import torch
def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tensor:
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
"""run_fwd_bwd
run fwd and bwd for the model
@@ -10,7 +10,6 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
data (torch.Tensor): input data
label (torch.Tensor): label
criterion (Optional[Callable]): a function of criterion
use_init_ctx (bool, optional): whether the model is initialized under the contxt of ColoInitCtx. Defaults to False.
Returns:
torch.Tensor: loss of fwd
@@ -23,8 +22,8 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
loss = model(data, label)
loss = loss.float()
if use_init_ctx:
model.backward(loss)
if optimizer:
optimizer.backward(loss)
else:
loss.backward()
return loss