[shardformer] Pytree fix (#4533)

* pytree test

* test bert

* test bert

* test bert

* revise

* add register

* add register
Jianghai 2023-09-04 17:52:23 +08:00 committed by GitHub
parent 508ca36fe3
commit 24c0768795
5 changed files with 81 additions and 17 deletions

Changed file 1 of 5

@@ -1,9 +1,59 @@
-from typing import Any, List, Optional
+from collections import OrderedDict
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.cuda
 from torch.nn import Module
-from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import (
+    SUPPORTED_NODES,
+    LeafSpec,
+    TreeSpec,
+    _is_leaf,
+    _register_pytree_node,
+    tree_flatten,
+    tree_map,
+    tree_unflatten,
+)
+
+
+# This registration is for torch versions under 1.13.1 and may be removed in the future.
+def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
+    return list(d.values()), list(d.keys())
+
+
+def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
+    return OrderedDict((key, value) for key, value in zip(context, values))
+
+
+_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+
+
+def tree_map_hf(fn: Any, pytree: Any):
+    flat_args, spec = tree_flatten_hf(pytree)
+    return tree_unflatten([fn(i) for i in flat_args], spec)
+
+
+# Use this flatten function to handle ModelingOutput class instances.
+def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
+    """Flattens a pytree into a list of values and a TreeSpec that can be used
+    to reconstruct the pytree.
+    """
+    if isinstance(pytree, OrderedDict):
+        node_type = OrderedDict
+        flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
+        child_pytrees, context = flatten_fn(pytree)
+
+        # Recursively flatten the children
+        result: List[Any] = []
+        children_specs: List['TreeSpec'] = []
+        for child in child_pytrees:
+            flat, child_spec = tree_flatten_hf(child)
+            result += flat
+            children_specs.append(child_spec)
+        return result, TreeSpec(node_type, context, children_specs)
+    else:
+        result, tree_spec = tree_flatten(pytree)
+    return result, tree_spec
 
 
 def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
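Why both the registration and the custom flatten are needed: HuggingFace modeling-output objects subclass OrderedDict, and torch's `_pytree` dispatches on the exact type, so a modeling output is treated as a single leaf by the stock `tree_flatten`; `tree_flatten_hf` routes any OrderedDict subclass through the OrderedDict flatten function, and the registration lets `tree_unflatten` rebuild the resulting OrderedDict spec on older torch. The snippet below is a minimal sketch, not part of the commit; `FakeModelOutput` is a hypothetical stand-in for a transformers ModelOutput, and the import path for the new helpers is assumed from the relative `._utils` import in the schedule hunk further down.

from collections import OrderedDict

import torch
from torch.utils._pytree import tree_flatten

from colossalai.pipeline.schedule._utils import tree_flatten_hf, tree_map_hf  # assumed module path


class FakeModelOutput(OrderedDict):
    # hypothetical stand-in for transformers' ModelOutput, which also subclasses OrderedDict
    pass


output = FakeModelOutput(last_hidden_state=torch.ones(2, 4), pooler_output=torch.zeros(2, 4))

# The exact type FakeModelOutput is not in SUPPORTED_NODES, so the stock
# tree_flatten keeps the whole object as a single leaf.
leaves, _ = tree_flatten(output)
print(len(leaves))        # 1

# tree_flatten_hf dispatches any OrderedDict subclass through the OrderedDict
# flatten function, so the tensors inside become separate leaves.
hf_leaves, spec = tree_flatten_hf(output)
print(len(hf_leaves))     # 2

# tree_map_hf applies a function to every leaf and rebuilds the dict layout.
# Note: the rebuilt object is a plain OrderedDict, not FakeModelOutput.
detached = tree_map_hf(lambda t: t.detach(), output)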
@@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
     return x
 
 
-def merge_batch(data: List[Any]) -> Any:
+def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
 
     Args:
@@ -118,15 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
     flattened_data = []
     tree_spec = None
     for d in data:
-        elems, tree_spec = tree_flatten(d)
+        # elems should be an instance of OrderedDict
+        elems, tree_spec = tree_flatten_hf(d)
         flattened_data.append(elems)
     merged_data = []
     for elem_batch in zip(*flattened_data):
         if isinstance(elem_batch[0], torch.Tensor):
             if len(elem_batch[0].shape) == 0:  # set loss to None in pipeline outputs
                 merged_data.append(None)
             else:
-                merged_data.append(torch.cat(elem_batch, dim=0))
+                merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
         else:
             merged_data.append(list(elem_batch))
     return tree_unflatten(merged_data, tree_spec)
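A quick, hypothetical illustration of the new behaviour (not part of the commit): zero-dimensional losses are replaced by None, and each tensor field is concatenated across micro batches along `batch_size_dim` (default 0, i.e. batch-first). The module path for `merge_batch` is an assumption, as above.

from collections import OrderedDict

import torch

from colossalai.pipeline.schedule._utils import merge_batch  # assumed module path

micro1 = OrderedDict(loss=torch.tensor(0.7), logits=torch.randn(2, 10))
micro2 = OrderedDict(loss=torch.tensor(0.4), logits=torch.randn(2, 10))

merged = merge_batch([micro1, micro2])   # batch_size_dim defaults to 0
print(merged['loss'])                    # None: zero-dim tensors are dropped
print(merged['logits'].shape)            # torch.Size([4, 10]): micro batches concatenated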

Changed file 2 of 5

@@ -6,12 +6,21 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
 
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.cuda import get_current_device
 
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    retain_grad,
+    to_device,
+    tree_map_hf,
+)
 from .base import PipelineSchedule
 
@@ -154,7 +163,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
             if outputs is not None:
-                outputs.append(tree_map(detach, output_obj))
+                outputs.append(tree_map_hf(detach, output_obj))
             return loss
         else:
             return output_obj
@@ -302,5 +311,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 self.send_backward(input_obj_grad)
 
         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
         return {'loss': accum_loss, 'outputs': outputs}
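The schedule no longer hard-codes dim 0: it unwraps a ModelWrapper if present and reads an optional `batch_size_dim` attribute from the bare model, falling back to 0 when no policy has attached one. A small sketch of that fallback with a toy module (nn.Linear stands in for a real sharded model):

import torch.nn as nn

seq_first_model = nn.Linear(4, 4)
seq_first_model.batch_size_dim = 1                       # what ChatGLMPolicy.preprocess() attaches
print(getattr(seq_first_model, 'batch_size_dim', 0))     # 1 -> concatenate micro batches on dim 1

batch_first_model = nn.Linear(4, 4)                      # no attribute attached by any policy
print(getattr(batch_first_model, 'batch_size_dim', 0))   # 0 -> previous batch-first default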

Changed file 3 of 5

@@ -41,6 +41,11 @@ class ChatGLMPolicy(Policy):
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
 
+        if self.pipeline_stage_manager is not None:
+            # the batch_size_dim is bound to the model
+            bsz_dim = 1
+            setattr(self.model, 'batch_size_dim', bsz_dim)
+
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
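ChatGLM's transformer keeps its hidden states sequence-first, roughly [seq_len, batch, hidden], which is why the policy pins `batch_size_dim` to 1: concatenating micro batches on dim 0 would grow the sequence axis instead of the batch axis. A shape-only illustration with hypothetical sizes:

import torch

# hypothetical micro-batch outputs shaped [seq_len, micro_batch, hidden]
a = torch.randn(8, 2, 16)
b = torch.randn(8, 2, 16)

print(torch.cat([a, b], dim=0).shape)   # torch.Size([16, 2, 16]) -> wrongly doubles the sequence
print(torch.cat([a, b], dim=1).shape)   # torch.Size([8, 4, 16])  -> merges the batch axis as intended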

Changed file 4 of 5

@@ -191,15 +191,10 @@ def check_output_hidden_state(org_output: Tensor,
 
     org_hidden_state = org_output.last_hidden_state
 
-    if stage_manager is None:
-        sharded_hidden_state = sharded_output.last_hidden_state
-
     if stage_manager and stage_manager.is_last_stage():
-        pipeline_output = sharded_output['outputs']
-        if isinstance(pipeline_output, List):
-            sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim)
-        else:
-            sharded_hidden_state = pipeline_output.last_hidden_state
+        sharded_hidden_state = sharded_output['outputs']['last_hidden_state']
+    else:
+        sharded_hidden_state = sharded_output.last_hidden_state
 
     assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
         f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"

Changed file 5 of 5

@@ -179,6 +179,7 @@ def run_bert_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()