mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[test] merge old components to test to model zoo (#4945)
* [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test
This commit is contained in:
58
tests/kit/model_zoo/executor.py
Normal file
58
tests/kit/model_zoo/executor.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
|
||||
|
||||
def run_fwd(
|
||||
model: Module, data: Dict, output_transform_fn: Callable, criterion: Optional[Callable] = None
|
||||
) -> torch.Tensor:
|
||||
"""run_fwd
|
||||
run fwd for the model
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): a PyTorch model
|
||||
data (torch.Tensor): input data
|
||||
label (torch.Tensor): label
|
||||
criterion (Optional[Callable]): a function of criterion
|
||||
|
||||
Returns:
|
||||
torch.Tensor: loss of fwd
|
||||
"""
|
||||
outputs = model(**data)
|
||||
outputs = output_transform_fn(outputs)
|
||||
if criterion:
|
||||
loss = criterion(outputs)
|
||||
else:
|
||||
loss = next(iter(outputs.values())).sum()
|
||||
return loss
|
||||
|
||||
|
||||
def run_fwd_bwd(
|
||||
model: Module,
|
||||
data: Dict,
|
||||
output_transform_fn: Callable,
|
||||
criterion: Optional[Callable] = None,
|
||||
optimizer: Optional[Union[Optimizer, OptimizerWrapper]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""run_fwd_bwd
|
||||
run fwd and bwd for the model
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): a PyTorch model
|
||||
data (torch.Tensor): input data
|
||||
label (torch.Tensor): label
|
||||
criterion (Optional[Callable]): a function of criterion
|
||||
|
||||
Returns:
|
||||
torch.Tensor: loss of fwd
|
||||
"""
|
||||
loss = run_fwd(model, data, output_transform_fn, criterion)
|
||||
if optimizer:
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
return loss
|
Reference in New Issue
Block a user