mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-03 22:18:23 +00:00
[shardformer] Pytree fix (#4533)
* pytree test * test bert * test bert * test bert * revise * add register * add register
This commit is contained in:
parent
508ca36fe3
commit
24c0768795
colossalai
tests/test_shardformer/test_model
@ -1,9 +1,59 @@
|
||||
from typing import Any, List, Optional
|
||||
from collections import OrderedDict
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
from torch.nn import Module
|
||||
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
||||
from torch.utils._pytree import (
|
||||
SUPPORTED_NODES,
|
||||
LeafSpec,
|
||||
TreeSpec,
|
||||
_is_leaf,
|
||||
_register_pytree_node,
|
||||
tree_flatten,
|
||||
tree_map,
|
||||
tree_unflatten,
|
||||
)
|
||||
|
||||
|
||||
# this register are for torch under version 1.13.1, maybe removed in the future
|
||||
def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
|
||||
return list(d.values()), list(d.keys())
|
||||
|
||||
|
||||
def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
|
||||
return OrderedDict((key, value) for key, value in zip(context, values))
|
||||
|
||||
|
||||
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
|
||||
|
||||
|
||||
def tree_map_hf(fn: Any, pytree: Any):
|
||||
flat_args, spec = tree_flatten_hf(pytree)
|
||||
return tree_unflatten([fn(i) for i in flat_args], spec)
|
||||
|
||||
|
||||
# use this flatten function to handle the ModelingOutput Class instance.
|
||||
def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
|
||||
"""Flattens a pytree into a list of values an a TreeSpec that can be used
|
||||
to reconstruct the pytree.
|
||||
"""
|
||||
if isinstance(pytree, OrderedDict):
|
||||
node_type = OrderedDict
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
child_pytrees, context = flatten_fn(pytree)
|
||||
|
||||
# Recursively flatten the children
|
||||
result: List[Any] = []
|
||||
children_specs: List['TreeSpec'] = []
|
||||
for child in child_pytrees:
|
||||
flat, child_spec = tree_flatten_hf(child)
|
||||
result += flat
|
||||
children_specs.append(child_spec)
|
||||
return result, TreeSpec(node_type, context, children_specs)
|
||||
else:
|
||||
result, tree_spec = tree_flatten(pytree)
|
||||
return result, tree_spec
|
||||
|
||||
|
||||
def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
|
||||
@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
|
||||
return x
|
||||
|
||||
|
||||
def merge_batch(data: List[Any]) -> Any:
|
||||
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
|
||||
"""Merge micro batches into a batch.
|
||||
|
||||
Args:
|
||||
@ -118,15 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
|
||||
flattened_data = []
|
||||
tree_spec = None
|
||||
for d in data:
|
||||
elems, tree_spec = tree_flatten(d)
|
||||
# elems should be an instance of OrderedDict
|
||||
elems, tree_spec = tree_flatten_hf(d)
|
||||
flattened_data.append(elems)
|
||||
merged_data = []
|
||||
|
||||
for elem_batch in zip(*flattened_data):
|
||||
if isinstance(elem_batch[0], torch.Tensor):
|
||||
if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
|
||||
merged_data.append(None)
|
||||
else:
|
||||
merged_data.append(torch.cat(elem_batch, dim=0))
|
||||
merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
|
||||
else:
|
||||
merged_data.append(list(elem_batch))
|
||||
return tree_unflatten(merged_data, tree_spec)
|
||||
|
@ -6,12 +6,21 @@ import torch.cuda
|
||||
from torch.nn import Module
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||
from ._utils import (
|
||||
detach,
|
||||
get_batch_size,
|
||||
get_micro_batch,
|
||||
merge_batch,
|
||||
model_forward,
|
||||
retain_grad,
|
||||
to_device,
|
||||
tree_map_hf,
|
||||
)
|
||||
from .base import PipelineSchedule
|
||||
|
||||
|
||||
@ -154,7 +163,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
outputs.append(tree_map_hf(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
@ -302,5 +311,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
self.send_backward(input_obj_grad)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
if isinstance(model, ModelWrapper):
|
||||
model = model.unwrap()
|
||||
outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
|
||||
return {'loss': accum_loss, 'outputs': outputs}
|
||||
|
@ -41,6 +41,11 @@ class ChatGLMPolicy(Policy):
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
# the batch_size_dim is bounded to Model
|
||||
bsz_dim = 1
|
||||
setattr(self.model, 'batch_size_dim', bsz_dim)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
|
@ -191,15 +191,10 @@ def check_output_hidden_state(org_output: Tensor,
|
||||
|
||||
org_hidden_state = org_output.last_hidden_state
|
||||
|
||||
if stage_manager is None:
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
|
||||
if stage_manager and stage_manager.is_last_stage():
|
||||
pipeline_output = sharded_output['outputs']
|
||||
if isinstance(pipeline_output, List):
|
||||
sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim)
|
||||
else:
|
||||
sharded_hidden_state = pipeline_output.last_hidden_state
|
||||
sharded_hidden_state = sharded_output['outputs']['last_hidden_state']
|
||||
else:
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
|
||||
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
|
||||
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
|
||||
|
@ -179,6 +179,7 @@ def run_bert_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
|
Loading…
Reference in New Issue
Block a user