Mirror of https://github.com/hpcaitech/ColossalAI.git
Merge branch 'main' into dev/zero_bubble
@@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
 import torch
 import torch.cuda
+from packaging.version import Version
 from torch.nn import Module
-from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten


 # this register are for torch under version 1.13.1, maybe removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
     return OrderedDict((key, value) for key, value in zip(context, values))


-_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+if Version(torch.__version__) <= Version("1.13.1"):
+    try:
+        from torch.utils._pytree import register_pytree_node as _register_pytree_node
+    except ImportError:
+        from torch.utils._pytree import _register_pytree_node
+    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)


 def tree_map_hf(fn: Any, pytree: Any):
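Context on the registration change above: on torch releases at or below 1.13.1, OrderedDict is not a registered pytree container, so tree_flatten/tree_map would treat a whole OrderedDict as a single leaf; the version gate limits the manual registration to those releases, and the try/except tolerates both the public and private spellings of the registration helper. A minimal sketch of the behavior this registration enables; the micro_batch contents and field names are illustrative, not taken from the repo:

from collections import OrderedDict

import torch
from torch.utils._pytree import tree_flatten, tree_map

# With OrderedDict registered as a pytree node, the tree utilities recurse into it
# instead of treating the whole dict as one opaque leaf.
micro_batch = OrderedDict(hidden_states=torch.ones(2, 4), attention_mask=torch.ones(2, 4))

leaves, spec = tree_flatten(micro_batch)               # expect 2 tensor leaves, not 1 OrderedDict leaf
moved = tree_map(lambda t: t.to("cpu"), micro_batch)   # maps over each tensor, keeps the dict structure
print(len(leaves), type(moved))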
@@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            if "backward_tensor_keys" not in output_obj:
-                for k, grad in output_obj_grad.items():
-                    optimizer.backward_by_grad(output_obj[k], grad)
+            keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+            tensors_to_backward = []
+            grads_to_backward = []
+            for k in keys:
+                tensors_to_backward.append(output_obj[k])
+                grads_to_backward.append(output_obj_grad[k])
+            if len(tensors_to_backward) == 1:
+                optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
             else:
-                for k, grad in output_obj_grad.items():
-                    output_obj[k].grad = grad
-                for k in output_obj["backward_tensor_keys"]:
-                    tensor_to_backward = output_obj[k]
-                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+                optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)

         # Collect the grad of the input_obj.
         input_obj_grad = None
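For reference, the rewritten branch above gathers all output tensors and their incoming gradients into two lists and hands them to a single backward_by_grad call, so autograd walks the shared graph once rather than once per key. A minimal sketch of the list-form call, under the assumption that backward_by_grad ultimately forwards to torch.autograd.backward; names and values are illustrative only:

import torch

# Two outputs sharing one graph; backprop them together in a single call,
# mirroring optimizer.backward_by_grad(tensors_to_backward, grads_to_backward).
x = torch.randn(4, requires_grad=True)
out_a, out_b = x * 2, x * 3

tensors_to_backward = [out_a, out_b]
grads_to_backward = [torch.ones_like(out_a), torch.ones_like(out_b)]

torch.autograd.backward(tensors_to_backward, grads_to_backward)
print(x.grad)  # every element is 2 + 3 = 5, accumulated in one traversal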
@@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            if "backward_tensor_keys" not in output_obj:
-                for k, grad in output_obj_grad.items():
-                    optimizer.backward_by_grad(output_obj[k], grad)
+            keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+            tensors_to_backward = []
+            grads_to_backward = []
+            for k in keys:
+                tensors_to_backward.append(output_obj[k])
+                grads_to_backward.append(output_obj_grad[k])
+            if len(tensors_to_backward) == 1:
+                optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
             else:
-                for k, grad in output_obj_grad.items():
-                    output_obj[k].grad = grad
-                for k in output_obj["backward_tensor_keys"]:
-                    tensor_to_backward = output_obj[k]
-                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+                optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)

         # Collect the grad of the input_obj.
         input_obj_grad = None
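The same rewrite is applied to OneForwardOneBackwardSchedule. One detail worth noting is the key-selection fallback: explicit backward_tensor_keys take precedence, otherwise every key present in output_obj_grad is backpropagated. A toy check of that fallback using plain dicts with dummy values, purely illustrative:

# Explicit keys win when "backward_tensor_keys" is present in the output dict.
output_obj = {"loss": 1.0, "logits": 2.0, "backward_tensor_keys": ["loss"]}
output_obj_grad = {"loss": 10.0, "logits": 20.0}
assert list(output_obj.get("backward_tensor_keys", output_obj_grad.keys())) == ["loss"]

# Without them, fall back to backpropagating every tensor that received a grad.
output_obj.pop("backward_tensor_keys")
assert list(output_obj.get("backward_tensor_keys", output_obj_grad.keys())) == ["loss", "logits"]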