Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)
[autoparallel] add getattr handler (#1767)
* [autoparallel] add getattr handler
* polish code
* add extra processes for Parameters
* add unit test for param resharding cost
* add docstring and polish test
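For background (not part of the commit itself): torch.fx represents direct accesses to module attributes such as nn.Parameter as 'get_attr' nodes, and those are the nodes the new GetattrHandler is responsible for. A minimal, self-contained illustration of where such nodes come from:

```python
# Illustration only: tracing a module whose forward reads a Parameter
# directly produces a 'get_attr' node in the torch.fx graph.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class TinyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(4, 8))

    def forward(self, x):
        # Reading self.weight shows up as a 'get_attr' node in the traced graph.
        return torch.matmul(x, self.weight)


traced = symbolic_trace(TinyModel())
for node in traced.graph.nodes:
    print(node.op, node.target)
# Expected ops: placeholder, get_attr, call_function, output
```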
@@ -6,9 +6,10 @@ from typing import Dict, List
 import torch
 from torch.fx import Graph, Node
 
-from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry)
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector)
-from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec)
+from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry
+from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -71,25 +72,8 @@ class StrategiesConstructor:
 
             # get_attr node
             if node.op == 'get_attr':
-                # Same as placeholder nodes, if solver_options.fast is True, we just let them in
-                # fully replicate status, then strategies of following node will be treated equally due
-                # to replicate status has no resharding cost to other status. At the same time, the searching
-                # space is smaller than enumerating all the possible sharding spec for the get_attr node.
-                # Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
-                if self.solver_options.fast:
-                    # create sharding strategy for get_attr
-                    name = 'Replica Attribute'
-                    dim_partition_dict = {}
-                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
-                    # TODO: use meta_info_prop to profile memory cost
-                    memory_cost = 0
-                    sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
-                    strategies_vector.append(sharding_strategy_attribute)
-
-            # # get_attr node
-            # elif node.op == 'get_attr':
-            #     # TODO: implement getattr node handler
-            #     pass
+                getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
+                getattr_handler.register_strategy()
 
             # call_module node
             elif node.op == 'call_module':
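The removed branch shows the old fast-mode behaviour for get_attr nodes: register a single fully-replicated sharding spec with zero memory cost, so that downstream strategies see no resharding cost from it. As a hypothetical simplification (the real GetattrHandler imported above may enumerate more strategies and add extra processing for Parameters), a handler that wraps that same logic behind a register_strategy() interface could look like the sketch below, using only the calls visible in the removed code:

```python
# Hypothetical simplification of a getattr handler; the real GetattrHandler
# in getatrr_handler.py may differ (e.g. extra processing for Parameters).
from torch.fx import Node

from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_sharding_spec
from colossalai.device.device_mesh import DeviceMesh


class NaiveGetattrHandler:

    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector):
        self.node = node
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector

    def register_strategy(self):
        # An empty dim_partition_dict means every dimension stays replicated,
        # so any consumer can reshard from this spec without communication.
        output_sharding_spec = generate_sharding_spec(self.node, self.device_mesh, {})
        # The original inline code used a constant 0 here; profiling the real
        # memory cost via meta_info_prop was left as a TODO.
        strategy = ShardingStrategy('Replica Attribute', output_sharding_spec, memory_cost=0)
        self.strategies_vector.append(strategy)
        return self.strategies_vector
```

In effect, the commit moves this inline logic out of StrategiesConstructor and into a dedicated node handler, matching how placeholder and output nodes are already handled by PlacehodlerHandler and OuputHandler.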
|
Reference in New Issue
Block a user