mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[pipeline] refactor 1f1b schedule (#4115)
* [api] update optimizer wrapper to fit pipeline * [pipeline] add base schedule * [pipeline] add 1f1b schedule * [test] add pipeline schedule utils test * [pipeline] fix import
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
@@ -53,6 +54,9 @@ class OptimizerWrapper:
|
||||
"""
|
||||
loss.backward(*args, **kwargs)
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
torch.autograd.backward(tensor, grad)
|
||||
|
||||
def state_dict(self):
|
||||
"""
|
||||
Returns the optimizer state.
|
||||
|
Reference in New Issue
Block a user