[autoparallel] add getattr handler (#1767)

* [autoparallel] add getattr handler

* polish code

* add extra processing for Parameters

* add unit test for param resharding cost

* add docstring and polish test
YuliangLiu0306 authored on 2022-11-03 12:31:33 +08:00 · committed by GitHub
parent c6a1a62636 · commit 2c4c7b3618
11 changed files with 306 additions and 37 deletions
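
One of the commit-message bullets above adds a unit test for parameter resharding cost, but neither that test nor Colossal-AI's cost utilities appear in this excerpt. The snippet below is only an illustrative sketch of the quantity such a test checks: the communication volume paid when a `Parameter` sharded along dim 0 over N devices has to be turned back into a fully replicated one. Every name and the cost model itself are made up for the illustration and are not Colossal-AI APIs.

```python
# Illustrative only: a rough estimate of the communication volume ("resharding
# cost") for converting a dim-0-sharded Parameter back to a replicated one.
# The function name and the cost model here are hypothetical, not Colossal-AI APIs.
import torch


def allgather_bytes(param: torch.nn.Parameter, num_devices: int) -> int:
    """Bytes each device must receive to rebuild the full tensor from 1/num_devices shards."""
    total_bytes = param.numel() * param.element_size()
    # Each device already holds 1/num_devices of the data and gathers the rest.
    return total_bytes - total_bytes // num_devices


weight = torch.nn.Parameter(torch.empty(1024, 1024))  # fp32 -> 4 MiB total
cost = allgather_bytes(weight, num_devices=4)
assert cost == weight.numel() * 4 * 3 // 4  # 3/4 of the tensor moves to every device
print(f"estimated resharding volume per device: {cost / 2**20:.2f} MiB")
```

A real test would compare against the values produced by the library's own resharding-cost computation rather than this toy formula.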


@@ -6,9 +6,10 @@ from typing import Dict, List
 import torch
 from torch.fx import Graph, Node
-from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry)
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector)
-from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec)
+from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry
+from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -71,25 +72,8 @@ class StrategiesConstructor:
             # get_attr node
             if node.op == 'get_attr':
-                # Same as placeholder nodes, if solver_options.fast is True, we just let them in
-                # fully replicate status, then strategies of following node will be treated equally due
-                # to replicate status has no resharding cost to other status. At the same time, the searching
-                # space is smaller than enumerating all the possible sharding spec for the get_attr node.
-                # Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
-                if self.solver_options.fast:
-                    # create sharding strategy for get_attr
-                    name = 'Replica Attribute'
-                    dim_partition_dict = {}
-                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
-                    # TODO: use meta_info_prop to profile memory cost
-                    memory_cost = 0
-                    sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
-                    strategies_vector.append(sharding_strategy_attribute)
-            # # get_attr node
-            # elif node.op == 'get_attr':
-            #     # TODO: implement getattr node handler
-            #     pass
+                getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
+                getattr_handler.register_strategy()
             # call_module node
             elif node.op == 'call_module':
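
The `getatrr_handler` module added by this PR is not part of the excerpt above, so the following is only a rough sketch, assuming a handler with the interface the constructor now relies on (a constructor taking `node`, `device_mesh` and `strategies_vector`, plus a `register_strategy()` method) that reproduces the fully replicated strategy the removed inline branch built. The real `GetattrHandler` may well do more, e.g. the extra processing for Parameters and the resharding costs mentioned in the commit message.

```python
# Hypothetical sketch only -- not the GetattrHandler shipped in this PR.
# It simply re-packages the removed fast-mode branch behind the handler
# interface used above, reusing the helpers visible in the diff.
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy
from colossalai.auto_parallel.tensor_shard.utils import generate_sharding_spec


class ReplicaGetattrHandlerSketch:
    """Register a single fully-replicated strategy for a get_attr node."""

    def __init__(self, node, device_mesh, strategies_vector):
        self.node = node
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector

    def register_strategy(self):
        # An empty dim_partition_dict means "replicate on every mesh dimension",
        # exactly what the removed inline code generated.
        dim_partition_dict = {}
        output_sharding_spec = generate_sharding_spec(self.node, self.device_mesh, dim_partition_dict)
        # TODO carried over from the original code: profile memory cost via meta_info_prop.
        strategy = ShardingStrategy('Replica Attribute', output_sharding_spec, memory_cost=0)
        self.strategies_vector.append(strategy)
        return self.strategies_vector
```

With a handler of this shape, get_attr nodes go through the same `register_strategy()` path as every other node type, which is what the two added lines in the hunk above switch the constructor to.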