1
0
mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-03 22:18:23 +00:00

[shardformer] Pytree fix ()

* pytree test

* test bert

* test bert

* test bert

* revise

* add register

* add register
This commit is contained in:
Jianghai 2023-09-04 17:52:23 +08:00 committed by GitHub
parent 508ca36fe3
commit 24c0768795
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 81 additions and 17 deletions
colossalai
pipeline/schedule
shardformer/policies
tests/test_shardformer/test_model

View File

@ -1,9 +1,59 @@
from typing import Any, List, Optional
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import (
SUPPORTED_NODES,
LeafSpec,
TreeSpec,
_is_leaf,
_register_pytree_node,
tree_flatten,
tree_map,
tree_unflatten,
)
# this register are for torch under version 1.13.1, maybe removed in the future
def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
return list(d.values()), list(d.keys())
def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
return OrderedDict((key, value) for key, value in zip(context, values))
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
def tree_map_hf(fn: Any, pytree: Any):
flat_args, spec = tree_flatten_hf(pytree)
return tree_unflatten([fn(i) for i in flat_args], spec)
# use this flatten function to handle the ModelingOutput Class instance.
def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
"""Flattens a pytree into a list of values an a TreeSpec that can be used
to reconstruct the pytree.
"""
if isinstance(pytree, OrderedDict):
node_type = OrderedDict
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(pytree)
# Recursively flatten the children
result: List[Any] = []
children_specs: List['TreeSpec'] = []
for child in child_pytrees:
flat, child_spec = tree_flatten_hf(child)
result += flat
children_specs.append(child_spec)
return result, TreeSpec(node_type, context, children_specs)
else:
result, tree_spec = tree_flatten(pytree)
return result, tree_spec
def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
return x
def merge_batch(data: List[Any]) -> Any:
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
"""Merge micro batches into a batch.
Args:
@ -118,15 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
flattened_data = []
tree_spec = None
for d in data:
elems, tree_spec = tree_flatten(d)
# elems should be an instance of OrderedDict
elems, tree_spec = tree_flatten_hf(d)
flattened_data.append(elems)
merged_data = []
for elem_batch in zip(*flattened_data):
if isinstance(elem_batch[0], torch.Tensor):
if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
merged_data.append(None)
else:
merged_data.append(torch.cat(elem_batch, dim=0))
merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
else:
merged_data.append(list(elem_batch))
return tree_unflatten(merged_data, tree_spec)

View File

@ -6,12 +6,21 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from ._utils import (
detach,
get_batch_size,
get_micro_batch,
merge_batch,
model_forward,
retain_grad,
to_device,
tree_map_hf,
)
from .base import PipelineSchedule
@ -154,7 +163,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
outputs.append(tree_map_hf(detach, output_obj))
return loss
else:
return output_obj
@ -302,5 +311,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_backward(input_obj_grad)
if outputs is not None:
outputs = merge_batch(outputs)
if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
return {'loss': accum_loss, 'outputs': outputs}

View File

@ -41,6 +41,11 @@ class ChatGLMPolicy(Policy):
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if self.pipeline_stage_manager is not None:
# the batch_size_dim is bounded to Model
bsz_dim = 1
setattr(self.model, 'batch_size_dim', bsz_dim)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

View File

@ -191,15 +191,10 @@ def check_output_hidden_state(org_output: Tensor,
org_hidden_state = org_output.last_hidden_state
if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
pipeline_output = sharded_output['outputs']
if isinstance(pipeline_output, List):
sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim)
else:
sharded_hidden_state = pipeline_output.last_hidden_state
sharded_hidden_state = sharded_output['outputs']['last_hidden_state']
else:
sharded_hidden_state = sharded_output.last_hidden_state
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"

View File

@ -179,6 +179,7 @@ def run_bert_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()