Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] Pytree fix (#4533)
* pytree test
* test bert
* test bert
* test bert
* revise
* add register
* add register
@@ -1,9 +1,59 @@
-from typing import Any, List, Optional
+from collections import OrderedDict
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.cuda
 from torch.nn import Module
-from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import (
+    SUPPORTED_NODES,
+    LeafSpec,
+    TreeSpec,
+    _is_leaf,
+    _register_pytree_node,
+    tree_flatten,
+    tree_map,
+    tree_unflatten,
+)
+
+
+# This registration is for torch versions under 1.13.1; it may be removed in the future.
+def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
+    return list(d.values()), list(d.keys())
+
+
+def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
+    return OrderedDict((key, value) for key, value in zip(context, values))
+
+
+_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+
+
+def tree_map_hf(fn: Any, pytree: Any):
+    flat_args, spec = tree_flatten_hf(pytree)
+    return tree_unflatten([fn(i) for i in flat_args], spec)
+
+
+# Use this flatten function to handle ModelOutput class instances.
+def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
+    """Flattens a pytree into a list of values and a TreeSpec that can be used
+    to reconstruct the pytree.
+    """
+    if isinstance(pytree, OrderedDict):
+        node_type = OrderedDict
+        flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
+        child_pytrees, context = flatten_fn(pytree)
+
+        # Recursively flatten the children
+        result: List[Any] = []
+        children_specs: List['TreeSpec'] = []
+        for child in child_pytrees:
+            flat, child_spec = tree_flatten_hf(child)
+            result += flat
+            children_specs.append(child_spec)
+        return result, TreeSpec(node_type, context, children_specs)
+    else:
+        result, tree_spec = tree_flatten(pytree)
+        return result, tree_spec
 
 
 def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
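For context (not part of the diff): a minimal sketch of what the OrderedDict registration and `tree_flatten_hf` buy us. Stock `torch.utils._pytree.tree_flatten` looks nodes up by exact type, so a subclass of `OrderedDict`, such as a HuggingFace `ModelOutput`, is treated as a single leaf; `tree_flatten_hf` switches to an `isinstance` check so such outputs flatten field by field. The sketch assumes it runs in the same module as the code above (so `tree_flatten_hf` and `tree_map_hf` are in scope); `DummyOutput` is a hypothetical stand-in for a transformers `ModelOutput`.

```python
# A minimal sketch, assuming the same module scope as the diff above.
from collections import OrderedDict

import torch
from torch.utils._pytree import tree_flatten


class DummyOutput(OrderedDict):
    """Hypothetical stand-in for a HuggingFace ModelOutput (an OrderedDict subclass)."""


out = DummyOutput(logits=torch.randn(2, 4), hidden=torch.randn(2, 8))

# Stock tree_flatten keys SUPPORTED_NODES by exact type, so the subclass is one leaf.
leaves, _ = tree_flatten(out)
print(len(leaves))  # 1: the whole DummyOutput is a single leaf

# tree_flatten_hf uses isinstance, so the fields are flattened individually.
leaves_hf, spec = tree_flatten_hf(out)
print(len(leaves_hf))  # 2: logits and hidden become separate leaves

# tree_map_hf applies fn to every leaf and rebuilds the structure
# (as a plain OrderedDict, since unflattening goes through _odict_unflatten).
halved = tree_map_hf(lambda t: t / 2, out)
print(type(halved), list(halved.keys()))
```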
@@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
     return x
 
 
-def merge_batch(data: List[Any]) -> Any:
+def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
 
     Args:
@@ -118,15 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
     flattened_data = []
     tree_spec = None
     for d in data:
-        elems, tree_spec = tree_flatten(d)
+        # elems should be an instance of OrderedDict
+        elems, tree_spec = tree_flatten_hf(d)
        flattened_data.append(elems)
     merged_data = []
+
     for elem_batch in zip(*flattened_data):
         if isinstance(elem_batch[0], torch.Tensor):
             if len(elem_batch[0].shape) == 0:  # set loss to None in pipeline outputs
                 merged_data.append(None)
             else:
-                merged_data.append(torch.cat(elem_batch, dim=0))
+                merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
         else:
             merged_data.append(list(elem_batch))
     return tree_unflatten(merged_data, tree_spec)
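
Again for context (not part of the diff): a hedged sketch of the updated `merge_batch` on two hypothetical micro-batch outputs. Scalar (zero-dim) tensors such as the loss are replaced with `None`, and everything else is concatenated along `batch_size_dim`; the default of 0 keeps existing call sites unchanged, while e.g. `batch_size_dim=1` would concatenate along the sequence dimension instead. Assumes the same module scope as above.

```python
# A minimal sketch, assuming the same module scope as the diff above.
from collections import OrderedDict

import torch

micro_batches = [
    OrderedDict(loss=torch.tensor(0.5), logits=torch.randn(2, 4)),
    OrderedDict(loss=torch.tensor(0.7), logits=torch.randn(2, 4)),
]

merged = merge_batch(micro_batches, batch_size_dim=0)
print(merged["loss"])           # None: zero-dim losses are set to None in pipeline outputs
print(merged["logits"].shape)   # torch.Size([4, 4]): micro-batches concatenated on dim 0
```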