mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[gemini] accelerate inference (#3641)
* [gemini] support don't scatter after inference
* [chat] update colossalai strategy
* [chat] fix opt benchmark
* [chat] update opt benchmark
* [gemini] optimize inference
* [test] add gemini inference test
* [chat] fix unit test ci
* [chat] fix ci
* [chat] fix ci
* [chat] skip checkpoint test
This commit is contained in:
@@ -9,11 +9,11 @@ from . import (
|
||||
resnet,
|
||||
simple_net,
|
||||
)
|
||||
from .utils import run_fwd_bwd
|
||||
from .utils import run_fwd, run_fwd_bwd
|
||||
|
||||
from . import albert # isort:skip
|
||||
|
||||
__all__ = [
|
||||
'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet',
|
||||
'simple_net', 'run_fwd_bwd', 'albert', 'beit'
|
||||
'simple_net', 'run_fwd_bwd', 'albert', 'beit', 'run_fwd'
|
||||
]
|
||||
|
@@ -1,2 +1,2 @@
|
||||
from .dummy_data_generator import DummyDataGenerator
|
||||
from .executor import run_fwd_bwd
|
||||
from .executor import run_fwd, run_fwd_bwd
|
||||
|
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
|
||||
"""run_fwd_bwd
|
||||
run fwd and bwd for the model
|
||||
def run_fwd(model, data, label, criterion) -> torch.Tensor:
|
||||
"""run_fwd
|
||||
run fwd for the model
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): a PyTorch model
|
||||
@@ -22,6 +22,23 @@ def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
|
||||
loss = model(data, label)
|
||||
|
||||
loss = loss.float()
|
||||
return loss
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
|
||||
"""run_fwd_bwd
|
||||
run fwd and bwd for the model
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): a PyTorch model
|
||||
data (torch.Tensor): input data
|
||||
label (torch.Tensor): label
|
||||
criterion (Optional[Callable]): a function of criterion
|
||||
|
||||
Returns:
|
||||
torch.Tensor: loss of fwd
|
||||
"""
|
||||
loss = run_fwd(model, data, label, criterion)
|
||||
if optimizer:
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user