Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-12 10:18:08 +00:00
[feat] add optim backward_b_by_grad
@@ -58,6 +58,28 @@ class OptimizerWrapper:
     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         torch.autograd.backward(tensor, grad)
 
+    def backward_b_by_grad(self, tensor: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
+        """
+        Performs the backward pass for dx; only dx = w * dy is calculated here.
+
+        Args:
+            tensor (Tensor): y or loss of the current chunk;
+            grad_tensors (Tensor): dy of the current chunk;
+            inputs (Tensor): x of the current chunk;
+            retain_graph (bool): defaults to True; the graph is retained in backward_b.
+        """
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad_tensors,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
+
+    def backward_w_by_grad(self):
+        """
+        Performs the backward pass for dw; only dw = x * dy is calculated here.
+        """
+
     def state_dict(self):
         """
         Returns the optimizer state.
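For context, the dx/dw split that `backward_b_by_grad` enables can be sketched with plain PyTorch. This is a minimal, self-contained example using hypothetical tensors (`x`, `w`, `y`, `dy`), not ColossalAI code: passing `inputs=` restricts gradient accumulation to the chunk input, and `retain_graph=True` keeps the graph alive so a later weight-gradient pass can reuse it.

```python
import torch

# Hypothetical stand-ins for one pipeline chunk: x (input), w (weight), y (output), dy (upstream grad).
w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4, requires_grad=True)
y = x @ w
dy = torch.ones_like(y)

# backward_b: accumulate only dx = dy @ w^T into x.grad; w.grad stays None because inputs=[x].
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# backward_w: reuse the retained graph to accumulate dw = x^T @ dy into w.grad.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[w])
assert w.grad is not None
```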
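`backward_w_by_grad` is left unimplemented in this commit. One possible body, mirroring `backward_b_by_grad` but targeting the parameters instead of the chunk input, is sketched below; the signature and the `params` argument are assumptions made for illustration, not the eventual ColossalAI implementation.

```python
from typing import Iterable

import torch
from torch import Tensor


def backward_w_by_grad(tensor: Tensor, grad_tensors: Tensor, params: Iterable[Tensor], retain_graph: bool = False) -> None:
    """Hypothetical sketch: accumulate dw = x * dy into the given parameters,
    consuming the graph retained by the earlier backward_b pass."""
    torch.autograd.backward(
        tensors=tensor,
        grad_tensors=grad_tensors,
        inputs=list(params),        # restrict accumulation to the weights
        retain_graph=retain_graph,  # usually False: the graph is no longer needed after the W pass
    )
```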