[autoparallel] integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
YuliangLiu0306 authored 2023-04-04 17:40:45 +08:00, committed by GitHub
parent 573af84184
commit ffcdbf0f65
46 changed files with 396 additions and 470 deletions

@@ -4,7 +4,7 @@ from typing import Dict, List
 import torch
 from torch.fx.node import Node
-from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai._analyzer.fx.node_util import MetaInfo
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     CommAction,
     CommType,
@@ -128,9 +128,10 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                                runtime_apply,
                                                                args=(node, origin_dict_node, input_dict_node,
                                                                      node_to_index_dict[node], user_node_index))
-            if 'activation_checkpoint' in user_node.meta:
-                shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
+            if hasattr(user_node.meta['info'], 'activation_checkpoint'):
+                MetaInfo(shape_consistency_node,
+                         mod_dir=user_node.meta['info'].mod_dir,
+                         activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
             # the origin node may be a positional argument or key word argument of user node
@@ -210,9 +211,10 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                         # substitute the origin node with comm_spec_apply_node
                         new_kwargs[str(node)] = comm_spec_apply_node
                         user.kwargs = new_kwargs
-        if 'activation_checkpoint' in node.meta:
-            comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+        if hasattr(node.meta['info'], 'activation_checkpoint'):
+            MetaInfo(comm_spec_apply_node,
+                     mod_dir=node.meta['info'].mod_dir,
+                     activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
     return gm
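
Both hunks above apply the same pattern: instead of copying a raw 'activation_checkpoint' key between node.meta dicts, the pass now reads the analyzer's MetaInfo object stored at node.meta['info'] and wraps the freshly inserted runtime node in a new MetaInfo. The sketch below is a minimal, self-contained illustration of that flow, not ColossalAI code: the MetaInfo class here is a hypothetical mock of colossalai._analyzer.fx.node_util.MetaInfo, and the assumption that its constructor registers itself under node.meta['info'] is inferred from this diff rather than confirmed against the library.

# Minimal, self-contained sketch of the metadata-propagation pattern shown in the diff.
# "MetaInfo" is a hypothetical stand-in for colossalai._analyzer.fx.node_util.MetaInfo;
# the assumption (inferred from the diff, not verified) is that constructing it
# registers the object under node.meta['info'].
import torch
from torch.fx import symbolic_trace


class MetaInfo:    # mock, not the real colossalai class
    def __init__(self, node, mod_dir='', activation_checkpoint=None):
        self.mod_dir = mod_dir
        if activation_checkpoint is not None:
            self.activation_checkpoint = activation_checkpoint
        node.meta['info'] = self    # assumed registration behavior


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = symbolic_trace(Toy())
user_node = next(n for n in gm.graph.nodes if n.op == 'call_function')

# pretend the traced node already carries checkpoint info from the analyzer
MetaInfo(user_node, mod_dir='Toy', activation_checkpoint=[0])

# insert a runtime node before user_node, as the pass does for shape consistency / comm actions
with gm.graph.inserting_before(user_node):
    runtime_node = gm.graph.call_function(torch.clone, args=(user_node.args[0],))

# propagate the annotation the same way the new code path does:
# read it from node.meta['info'] and wrap the inserted node with a fresh MetaInfo
if hasattr(user_node.meta['info'], 'activation_checkpoint'):
    MetaInfo(runtime_node,
             mod_dir=user_node.meta['info'].mod_dir,
             activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))

print(runtime_node.meta['info'].activation_checkpoint)    # (0,)

Read back this way, the hasattr guard plays the role the old "'activation_checkpoint' in node.meta" check did: the annotation is only propagated to the inserted runtime node when the source node actually carries one.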