Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] introduced baseclass for op handler and reduced code redundancy (#1471)
* [autoparallel] introduced baseclass for op handler and reduced code redundancy
* polish code
colossalai/auto_parallel/solver/operator_handler.py (new file, +45 lines)
@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from torch.fx.node import Node
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from .sharding_strategy import StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec


class OperatorHandler(ABC):
    '''
    The OperatorHandler is an abstract base class used to generate all possible sharding strategies for an operator node.

    Argument:
        input_node(Node): the input node in the node argument list.
        input_index(int): the index of the input node in the node argument list.
        weight(nn.Parameter): the weight of the node.
        output_node(Node): the output of the node.
        device_mesh(DeviceMesh): a logical view of a physical mesh.
        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
    '''

    def __init__(self, input_node: Node, input_index: int, weight: nn.Parameter, output_node: Node,
                 device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 shape_consistency_manager: ShapeConsistencyManager):
        self.input_node = input_node
        # _meta_data holds the example tensor recorded for this fx node during tracing
        self.input_data = self.input_node._meta_data
        self.weight = weight
        self.input_index = input_index
        self.output_node = output_node
        self.output = self.output_node._meta_data
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        self.shape_consistency_manager = shape_consistency_manager

    @abstractmethod
    def register_strategy_into_strategies_vector(self):
        pass

    def _generate_sharding_spec(self, tensor, dim_partition_dict):
        # dim_partition_dict maps a tensor dimension to the list of mesh dimensions it is sharded over
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=tensor.shape,
                                     dim_partition_dict=dim_partition_dict)
        return sharding_spec
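
To see how this base class removes per-operator redundancy, here is a minimal sketch of a concrete subclass. Everything below that does not appear in the file above is hypothetical: MatMulHandler, the single data-parallel strategy it builds, and the tuple appended to strategies_vector (the real payload type depends on what StrategiesVector stores).

from colossalai.auto_parallel.solver.operator_handler import OperatorHandler


class MatMulHandler(OperatorHandler):
    '''Hypothetical handler for a 2D matmul node, for illustration only.'''

    def register_strategy_into_strategies_vector(self):
        # One data-parallel candidate: shard the input rows over mesh dim 0
        # and replicate the weight, so the output rows end up sharded the
        # same way. dim_partition_dict has the form {tensor_dim: [mesh_dims]}.
        input_spec = self._generate_sharding_spec(self.input_data, {0: [0]})
        weight_spec = self._generate_sharding_spec(self.weight, {})
        output_spec = self._generate_sharding_spec(self.output, {0: [0]})
        # Illustrative payload; the real strategy object is defined elsewhere.
        self.strategies_vector.append((input_spec, weight_spec, output_spec))

Each concrete operator now only enumerates its own sharding candidates, while construction, metadata extraction, and sharding-spec generation stay in the base class. That shared scaffolding is the redundancy this commit removes.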