mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 22:19:38 +00:00
[autoparallel] update_getattr_handler (#2193)
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommAction,
|
||||
CommType,
|
||||
@@ -96,27 +97,23 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
|
||||
# to the same strategy of the user node.
|
||||
if node.op == 'get_attr':
|
||||
assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
|
||||
new_sharding_spec = target_sharding_specs[0]
|
||||
user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
|
||||
op_data_in_user = user_strategy.get_op_data_by_name(str(node))
|
||||
origin_node_sharding_spec_dict[index] = new_sharding_spec
|
||||
target_node = node.strategies_vector.successor_nodes[0]
|
||||
node_name = str(node)
|
||||
if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
|
||||
node_name = str(target_node)
|
||||
target_node = target_node.strategies_vector.successor_nodes[0]
|
||||
user_strategy = target_node.best_strategy
|
||||
op_data_in_user = user_strategy.get_op_data_by_name(node_name)
|
||||
origin_pending_strategy = node.best_strategy
|
||||
origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
|
||||
new_sharding_specs = origin_pending_strategy.sharding_specs
|
||||
new_sharding_specs[origin_op_data] = new_sharding_spec
|
||||
|
||||
new_communication_actions = {}
|
||||
if op_data_in_user in user_strategy.communication_actions:
|
||||
new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
|
||||
new_communication_action.arg_index = 0
|
||||
new_communication_actions[origin_op_data] = new_communication_action
|
||||
new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
|
||||
sharding_specs=new_sharding_specs,
|
||||
compute_cost=origin_pending_strategy.compute_cost,
|
||||
communication_cost=origin_pending_strategy.communication_cost,
|
||||
memory_cost=origin_pending_strategy.memory_cost,
|
||||
communication_actions=new_communication_actions)
|
||||
setattr(node, 'best_strategy', new_strategy)
|
||||
setattr(node, 'sharding_spec', new_sharding_spec)
|
||||
node.best_strategy.communication_actions = new_communication_actions
|
||||
|
||||
comm_action_dict = {}
|
||||
for op_data, comm_action in node.best_strategy.communication_actions.items():
|
||||
comm_action_dict[op_data.name] = comm_action
|
||||
|
@@ -86,12 +86,7 @@ class NodeHandler(ABC):
|
||||
if prev_sharding_spec is None:
|
||||
return TrainCycleItem(fwd=0, bwd=0, total=0)
|
||||
elif isinstance(prev_sharding_spec, ShardingSpec):
|
||||
if isinstance(data, torch.nn.parameter.Parameter):
|
||||
# we won't compute the resharding cost for the parameters,
|
||||
# since the parameters will be sharded before runtime and
|
||||
# not converted during runtime.
|
||||
return TrainCycleItem(fwd=0, bwd=0, total=0)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
if isinstance(data, torch.Tensor):
|
||||
dtype = data.dtype
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
|
||||
|
@@ -1,6 +1,12 @@
|
||||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (
|
||||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
ignore_sharding_exception,
|
||||
)
|
||||
from colossalai.tensor.sharding_spec import ShardingSpecException
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
||||
@@ -37,17 +43,47 @@ class GetattrGenerator(StrategyGenerator):
|
||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
strategy.memory_cost = memory_cost
|
||||
|
||||
@ignore_sharding_exception
|
||||
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
|
||||
# we check for the output logical shape to get the number of dimensions
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['output'].logical_shape)
|
||||
|
||||
# enumerate all the 2D sharding cases
|
||||
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
|
||||
dim_partition_list.extend(sharding_list_2d)
|
||||
|
||||
# enumerate all the 1D sharding cases
|
||||
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
|
||||
dim_partition_list.extend(sharding_list_1d_on_dim_0)
|
||||
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
|
||||
dim_partition_list.extend(sharding_list_1d_on_dim_1)
|
||||
|
||||
# add empty dict for fully replicated case
|
||||
dim_partition_list.append({})
|
||||
|
||||
# sharding strategy bookkeeping
|
||||
strategy_list = []
|
||||
|
||||
# convert these dim partition dict to sharding strategy
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
dim_partition_dict_mapping = dict(output=dim_partition_dict)
|
||||
|
||||
try:
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
communication_action_mapping = {}
|
||||
|
||||
# get name
|
||||
name = f"get_attr {sharding_spec_mapping['output'].sharding_sequence}"
|
||||
sharding_strategy = self.get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
strategy_list.append(sharding_strategy)
|
||||
except ShardingSpecException:
|
||||
continue
|
||||
|
||||
return strategy_list
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
dim_partition_dict_mapping = {
|
||||
"output": {},
|
||||
}
|
||||
communication_action_mapping = {}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
|
||||
name = 'Replica Attribute'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
return [strategy]
|
||||
return self.enumerate_all_possible_output(0, 1)
|
||||
|
Reference in New Issue
Block a user