mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[autoparallel]integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level * unify the profiling method * polish * fix no codegen bug * fix pass bug * fix liveness test * polish
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Dict, List
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
||||
from colossalai._analyzer.fx.node_util import MetaInfo
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommAction,
|
||||
CommType,
|
||||
@@ -128,9 +128,10 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
||||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
if 'activation_checkpoint' in user_node.meta:
|
||||
shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
|
||||
|
||||
if hasattr(user_node.meta['info'], 'activation_checkpoint'):
|
||||
MetaInfo(shape_consistency_node,
|
||||
mod_dir=user_node.meta['info'].mod_dir,
|
||||
activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
|
||||
new_args = list(user_node.args)
|
||||
new_kwargs = dict(user_node.kwargs)
|
||||
# the origin node may be a positional argument or key word argument of user node
|
||||
@@ -210,9 +211,10 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
|
||||
# substitute the origin node with comm_spec_apply_node
|
||||
new_kwargs[str(node)] = comm_spec_apply_node
|
||||
user.kwargs = new_kwargs
|
||||
|
||||
if 'activation_checkpoint' in node.meta:
|
||||
comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
|
||||
if hasattr(node.meta['info'], 'activation_checkpoint'):
|
||||
MetaInfo(comm_spec_apply_node,
|
||||
mod_dir=node.meta['info'].mod_dir,
|
||||
activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
|
||||
|
||||
return gm
|
||||
|
||||
|
Reference in New Issue
Block a user