ColossalAI/colossalai/auto_parallel/solver/operator_handler.py
2022-08-22 10:32:17 +08:00

66 lines
3.3 KiB
Python

from abc import ABC, abstractmethod
from torch.fx.node import Node
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from .sharding_strategy import StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
class OperatorHanlder(ABC):
    '''
    The OperatorHanlder is an abstract class used to generate every possible strategy for an operator node.

    Argument:
        input_node(Node): the input node in node argument list.
        input_index(int): the index of input node in the node argument list.
        weight(torch.Tensor): Weight of the node.
        output_node(Node): Output_node is the output of the node.
        device_mesh(DeviceMesh): A logical view of a physical mesh.
        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
    '''

    def __init__(self, input_node: Node, input_index: int, weight: nn.Parameter, output_node: Node,
                 device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 shape_consistency_manager: ShapeConsistencyManager):
        self.input_node = input_node
        # `_meta_data` is read eagerly, so a node that lacks the attribute fails here
        # rather than later during strategy generation.
        self.input_data = self.input_node._meta_data
        self.weight = weight
        self.input_index = input_index
        self.output_node = output_node
        self.output = self.output_node._meta_data
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        self.shape_consistency_manager = shape_consistency_manager

    @abstractmethod
    def register_strategy_into_strategies_vector(self):
        '''
        Generate the strategies of this operator and record them into self.strategies_vector.
        Every concrete handler must implement this method.
        '''
        pass

    def _generate_sharding_spec(self, tensor, dim_partition_dict):
        '''
        Build a ShardingSpec on self.device_mesh for the given tensor.

        Argument:
            tensor: any object exposing a ``shape`` attribute; the shape is used as the entire shape of the spec.
            dim_partition_dict: forwarded unchanged to ShardingSpec as ``dim_partition_dict``.

        Return:
            The newly created ShardingSpec.
        '''
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=tensor.shape,
                                     dim_partition_dict=dim_partition_dict)
        return sharding_spec

    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
        '''
        Compute the resharding costs with this specific strategy.

        Note: Only the argument at self.input_index is handled here; no cost is computed for self.weight.
        TODO(review): earlier comments disagreed on whether the weight cost should be counted
        (shared-weight cases) - confirm the intended behaviour.

        Argument:
            resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be stored into this dictionary.
                                                      Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
                                                      with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
                                                      strategy. Any existing entry for self.input_index is overwritten.
            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.

        Return:
            The same ``resharding_costs`` dictionary, updated in place. The entry for
            self.input_index is an empty list when the input node has no strategies.
        '''
        resharding_costs[self.input_index] = []
        # NOTE(review): each strategy is passed to shape_consistency as-is; confirm it is
        # the source sharding spec the manager expects.
        for strategy in self.input_node.strategies_vector.strategies:
            # shape_consistency yields a 3-tuple; only its last element (the cost) is kept.
            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, sharding_spec_for_input)
            resharding_costs[self.input_index].append(resharding_cost)
        # Return the whole mapping rather than the last per-strategy cost, which is
        # unbound when the loop body never runs.
        return resharding_costs