[autoparallel] add getattr handler (#1767)

* [autoparallel] add getattr handler
* polish code
* add extra processes for Parameters
* add unit test for param resharding cost
* add docstring and polish test
@@ -2,6 +2,7 @@ from .batch_norm_handler import BatchNormModuleHandler
 from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
+from .getatrr_handler import GetattrHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
@@ -0,0 +1,34 @@
+from typing import Dict, List
+
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import NodeHandler
+from .strategy import GetattrGenerator, StrategyGenerator
+
+__all__ = ['GetattrHandler']
+
+
+class GetattrHandler(NodeHandler):
+    """
+    A GetattrHandler which deals with the sharding strategies for a Getattr node.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(GetattrGenerator(op_data_mapping, self.device_mesh))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # use the transposed shape for strategies; they will be transformed
+        # back to their original shape in self.post_process
+
+        # There are only two possible types for a get_attr node:
+        # 1. torch.Tensor (a torch.nn.Parameter or buffer)
+        # 2. torch.nn.Module
+        # For now, the Tracer only supports the first case, so we don't have
+        # to worry about the type of node._meta_data.
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"output": physical_output}
+
+        return mapping
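For orientation, this new handler is driven like every other node handler: the strategies constructor wraps each get_attr node in a GetattrHandler and asks it to register its strategies (see the strategies_constructor hunk below). A minimal sketch of that call sequence, assuming a traced GraphModule `gm` and a `device_mesh` are already in scope (both names are illustrative):

    # hypothetical driver loop; `gm` and `device_mesh` are assumed to exist
    from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
    from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector

    for node in gm.graph.nodes:
        if node.op == 'get_attr':
            strategies_vector = StrategiesVector(node)
            handler = GetattrHandler(node, device_mesh, strategies_vector)
            # register_strategy() is inherited from NodeHandler; it invokes
            # get_strategy_generator() above and collects the results
            handler.register_strategy()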
@@ -6,6 +6,7 @@ from torch.fx.node import Node

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
+    OperationDataType,
     ShardingStrategy,
     StrategiesVector,
     TrainCycleItem,
@@ -49,6 +50,9 @@ class NodeHandler(ABC):

         for node in self.predecessor_node:
             node_name = str(node)
+            # get the current sharding spec generated by this node handler
+            op_data = strategy.get_op_data_by_name(node_name)
+            current_sharding_spec = strategy.sharding_specs[op_data]

             # get the sharding specs for this node generated
             # in its own node handler
@@ -59,10 +63,6 @@ class NodeHandler(ABC):
                 prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
             ]

-            # get the current sharding spec generated by this node handler
-            op_data = strategy.get_op_data_by_name(node_name)
-            current_sharding_spec = strategy.sharding_specs[op_data]
-
             # create data structure to store costs
             if node not in resharding_costs:
                 resharding_costs[node] = []
@@ -71,11 +71,14 @@ class NodeHandler(ABC):
             # compute the resharding cost to switch to the sharding spec generated
             # by the current node handler
             for prev_sharding_spec in prev_sharding_specs:
-                _, _, resharding_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec,
-                                                                                    current_sharding_spec)
-                resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
-                                                 bwd=resharding_cost["backward"],
-                                                 total=resharding_cost["total"])
+                if op_data.type == OperationDataType.PARAM:
+                    resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
+                else:
+                    _, _, resharding_cost = shape_consistency_manager.shape_consistency(
+                        prev_sharding_spec, current_sharding_spec)
+                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
+                                                     bwd=resharding_cost["backward"],
+                                                     total=resharding_cost["total"])
                 resharding_costs[node].append(resharding_cost)
         strategy.resharding_costs = resharding_costs
         return strategy
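To make the repacking above concrete: shape_consistency(...) returns, as its third value, a dict of communication costs keyed by "forward", "backward" and "total", which is then wrapped into a TrainCycleItem. The new PARAM branch pins the cost to zero, on the reasoning that a parameter can be materialized directly in whatever sharding spec a strategy requests, so no runtime resharding should be charged. An illustration with made-up numbers:

    # made-up cost values, only to show the dict -> TrainCycleItem repacking
    cost_dict = {"forward": 128.0, "backward": 128.0, "total": 256.0}
    resharding_cost = TrainCycleItem(fwd=cost_dict["forward"],
                                     bwd=cost_dict["backward"],
                                     total=cost_dict["total"])
    # for OperationDataType.PARAM operands the cost is pinned to zero instead
    param_resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)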
@@ -13,6 +13,7 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.reshape)
 @operator_registry.register(torch.flatten)
+@operator_registry.register(torch.Tensor.permute)
 @operator_registry.register(torch.Tensor.view)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
 class ReshapeHandler(NodeHandler):
     """
@@ -1,6 +1,7 @@
 from .batch_norm_generator import BatchNormStrategyGenerator
 from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator
 from .conv_strategy_generator import ConvStrategyGenerator
+from .getattr_generator import GetattrGenerator
 from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
 from .layer_norm_generator import LayerNormGenerator
 from .matmul_strategy_generator import (
@@ -22,5 +23,5 @@ __all__ = [
     'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
     'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
     'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
-    'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator'
+    'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator'
 ]
@@ -0,0 +1,53 @@
+from typing import List
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+
+from .strategy_generator import StrategyGenerator
+
+__all__ = ['GetattrGenerator']
+
+
+class GetattrGenerator(StrategyGenerator):
+    """
+    GetattrGenerator is a generic class to generate strategies for a get_attr node.
+    """
+
+    def validate(self) -> bool:
+        return super().validate()
+
+    def update_compute_cost(self, strategy: ShardingStrategy):
+        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
+        strategy.compute_cost = compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy):
+        '''
+        Compute the memory cost per device with this specific strategy.
+        '''
+        forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+
+        # compute the fwd cost incurred
+        # fwd_cost = output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
+
+        bwd_mem_cost = MemoryCost(activation=0, parameter=0)
+
+        # compute the total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        dim_partition_dict_mapping = {
+            "output": {},
+        }
+        communication_action_mapping = {}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        name = 'Replica Attribute'
+
+        strategy = self.get_sharding_strategy(name=name,
+                                              sharding_spec_mapping=sharding_spec_mapping,
+                                              communication_action_mapping=communication_action_mapping)
+
+        return [strategy]
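The empty dim_partition_dict for "output" is what makes this the 'Replica Attribute' strategy: no tensor dimension is partitioned along any device-mesh axis, so the attribute stays fully replicated. A sketch of the resulting spec when built by hand (constructor signatures as I read them from colossalai at the time; treat the exact arguments as an assumption):

    import torch
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.sharding_spec import ShardingSpec

    physical_mesh_id = torch.arange(4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))
    # empty dim_partition_dict -> every dim replicated over the 2x2 mesh
    spec = ShardingSpec(device_mesh, entire_shape=torch.Size((4, 8)), dim_partition_dict={})
    print(spec.sharding_sequence)    # expected: [R, R]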
@@ -6,9 +6,10 @@ from typing import Dict, List
 import torch
 from torch.fx import Graph, Node

-from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry)
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector)
-from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec)
+from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry
+from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -71,25 +72,8 @@ class StrategiesConstructor:

             # get_attr node
             if node.op == 'get_attr':
-                # Same as placeholder nodes, if solver_options.fast is True, we just let them in
-                # fully replicate status, then strategies of following node will be treated equally due
-                # to replicate status has no resharding cost to other status. At the same time, the searching
-                # space is smaller than enumerating all the possible sharding spec for the get_attr node.
-                # Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
-                if self.solver_options.fast:
-                    # create sharding strategy for get_attr
-                    name = 'Replica Attribute'
-                    dim_partition_dict = {}
-                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
-                    # TODO: use meta_info_prop to profile memory cost
-                    memory_cost = 0
-                    sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
-                    strategies_vector.append(sharding_strategy_attribute)
-
-                # # get_attr node
-                # elif node.op == 'get_attr':
-                #     # TODO: implement getattr node handler
-                #     pass
+                getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
+                getattr_handler.register_strategy()

             # call_module node
             elif node.op == 'call_module':
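For background on what this branch now covers: tracing a module that reads a registered parameter or buffer emits a node with op == 'get_attr' in the fx graph, and that node is what the new GetattrHandler consumes. A small illustration with vanilla torch.fx (ColossalAI's own tracer produces the same node kind for this case):

    import torch
    import torch.fx

    class Scale(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(4))

        def forward(self, x):
            # the read of self.weight is traced as get_attr('weight')
            return x * self.weight

    gm = torch.fx.symbolic_trace(Scale())
    print([(n.op, str(n.target)) for n in gm.graph.nodes])
    # [('placeholder', 'x'), ('get_attr', 'weight'),
    #  ('call_function', '<built-in function mul>'), ('output', 'output')]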