[pipeline] refactor 1f1b schedule (#4115)

* [api] update optimizer wrapper to fit pipeline

* [pipeline] add base schedule

* [pipeline] add 1f1b schedule

* [test] add pipeline schedule utils test

* [pipeline] fix import
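
The "[pipeline] add 1f1b schedule" item above refers to the one-forward-one-backward schedule: each pipeline stage runs a warmup of forward passes, a steady phase that alternates one forward with one backward, and a cooldown that drains the remaining backward passes. A minimal sketch of the phase counting only, where num_stages, stage_id and num_microbatches are hypothetical names for illustration, not this PR's actual API:

def one_forward_one_backward_phases(num_stages: int, stage_id: int, num_microbatches: int):
    # Warmup: stages closer to the end need fewer forwards before their first backward.
    num_warmup = min(num_stages - stage_id - 1, num_microbatches)
    # Steady state: one forward immediately followed by one backward per remaining microbatch.
    num_steady = num_microbatches - num_warmup
    # Cooldown: backwards left over from the warmup forwards.
    num_cooldown = num_warmup
    return num_warmup, num_steady, num_cooldown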
Hongxin Liu
2023-06-29 13:35:39 +08:00
parent 45fdc9b42c
commit f51ce1bc8e
6 changed files with 451 additions and 0 deletions


@@ -1,5 +1,6 @@
 from typing import Union
 
+import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
@@ -53,6 +54,9 @@ class OptimizerWrapper:
         """
         loss.backward(*args, **kwargs)
 
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+        torch.autograd.backward(tensor, grad)
+
     def state_dict(self):
         """
         Returns the optimizer state.
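
The new backward_by_grad hook matters because pipeline stages other than the last one never see a scalar loss: their backward pass has to start from the gradient of their own output, received from the downstream stage, which is what torch.autograd.backward(tensor, grad) does. A hedged usage sketch; the OptimizerWrapper construction, the stand-in model and the fake received gradient are illustrative assumptions, not part of this diff:

import torch
import torch.nn as nn
from torch.optim import Adam

model = nn.Linear(8, 8)                              # stand-in for one pipeline stage
optim = OptimizerWrapper(Adam(model.parameters()))   # class from the file above

output = model(torch.randn(4, 8))                    # forward pass on a non-last stage
grad_from_next_stage = torch.ones_like(output)       # placeholder for the gradient received over the pipeline
optim.backward_by_grad(output, grad_from_next_stage) # backward without a scalar loss
optim.step()                                         # then the usual optimizer step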