Merge branch 'main' into dev/zero_bubble

This commit is contained in:
duanjunwen
2024-11-01 03:10:53 +00:00
60 changed files with 1690 additions and 834 deletions

View File

@@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
import torch
import torch.cuda
from packaging.version import Version
from torch.nn import Module
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten
# this register are for torch under version 1.13.1, maybe removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
return OrderedDict((key, value) for key, value in zip(context, values))
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
if Version(torch.__version__) <= Version("1.13.1"):
try:
from torch.utils._pytree import register_pytree_node as _register_pytree_node
except ImportError:
from torch.utils._pytree import _register_pytree_node
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
def tree_map_hf(fn: Any, pytree: Any):

View File

@@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
if "backward_tensor_keys" not in output_obj:
for k, grad in output_obj_grad.items():
optimizer.backward_by_grad(output_obj[k], grad)
keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
tensors_to_backward = []
grads_to_backward = []
for k in keys:
tensors_to_backward.append(output_obj[k])
grads_to_backward.append(output_obj_grad[k])
if len(tensors_to_backward) == 1:
optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
else:
for k, grad in output_obj_grad.items():
output_obj[k].grad = grad
for k in output_obj["backward_tensor_keys"]:
tensor_to_backward = output_obj[k]
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
# Collect the grad of the input_obj.
input_obj_grad = None

View File

@@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
if "backward_tensor_keys" not in output_obj:
for k, grad in output_obj_grad.items():
optimizer.backward_by_grad(output_obj[k], grad)
keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
tensors_to_backward = []
grads_to_backward = []
for k in keys:
tensors_to_backward.append(output_obj[k])
grads_to_backward.append(output_obj_grad[k])
if len(tensors_to_backward) == 1:
optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
else:
for k, grad in output_obj_grad.items():
output_obj[k].grad = grad
for k in output_obj["backward_tensor_keys"]:
tensor_to_backward = output_obj[k]
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
# Collect the grad of the input_obj.
input_obj_grad = None