mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[zero] fix unit-tests (#2039)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
    """Run one forward and one backward pass of ``model`` and return the loss.

    The model is invoked as ``model(data, label)`` and is expected to return
    the loss directly (see the call below).

    Args:
        model (Callable): model to run; called as ``model(data, label)`` and
            expected to return a loss tensor.
        data (torch.Tensor): input data
        label (torch.Tensor): label
        criterion (Optional[Callable]): a function of criterion.
            NOTE(review): currently unused in this helper — the model itself
            returns the loss; kept for signature compatibility.
        optimizer (optional): if given, ``optimizer.backward(loss)`` is used
            (e.g. a ZeRO/wrapped optimizer that owns the backward pass);
            otherwise plain ``loss.backward()`` is called. Defaults to None.

    Returns:
        torch.Tensor: loss of fwd
    """
    loss = model(data, label)

    # Cast to fp32 so backward on a (possibly half-precision) loss is stable.
    loss = loss.float()
    if optimizer:
        # Wrapped optimizers (e.g. ZeroOptimizer) manage the backward pass
        # themselves because they own gradient partitioning/scaling.
        optimizer.backward(loss)
    else:
        loss.backward()
    return loss
|
Reference in New Issue
Block a user