Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[autoparallel] refactored shape consistency to remove redundancy (#1591)
* [autoparallel] refactored shape consistency to remove redundancy
* polish code
* polish code
* polish code
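In short, the redundancy being removed is the explicit `shape_consistency_manager` argument that every operator handler and the `StrategiesConstructor` used to carry around: resharding-cost computation now goes through a shared helper in `_utils.py`, which creates the `ShapeConsistencyManager` itself. A minimal before/after sketch of the affected call sites (the surrounding `node`, `graph`, `device_mesh`, `strategies_vector` and `solver_options` objects are assumed to exist elsewhere and are not defined here):

```python
# Hedged sketch of the API change in this commit; not runnable on its own.

# Before this commit, the manager had to be threaded through every constructor:
#   conv_handler = ConvHandler(node, device_mesh, strategies_vector, shape_consistency_manager)
#   constructor  = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)

# After this commit, the argument is gone; the shared cost helper builds the
# manager internally:
#   conv_handler = ConvHandler(node, device_mesh, strategies_vector)
#   constructor  = StrategiesConstructor(graph, device_mesh, solver_options)

from colossalai.tensor.shape_consistency import ShapeConsistencyManager

# Two "constructions" return the same shared instance (the manager is a
# singleton, per the comment added in _utils.py), so dropping the explicit
# argument does not change which manager computes the resharding costs.
assert ShapeConsistencyManager() is ShapeConsistencyManager()
```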
@@ -1,8 +1,9 @@
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
-from typing import Union, Dict, List
+from typing import Union, Dict, List, Optional


def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@@ -31,3 +32,45 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic

    sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
    return sharding_spec
+
+
+def generate_resharding_costs(nodes: List[Node],
+                              sharding_specs: List[ShardingSpec],
+                              count_backward: Optional[bool] = True,
+                              dtype: Optional[torch.dtype] = None):
+    '''
+    Compute the resharding costs with this specific strategy.
+
+    Argument:
+        nodes (List[Node]): a list of nodes
+        sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
+        count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
+        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
+    '''
+    # The resharding_cost of weight is counted due to sharing weight cases.
+    resharding_costs = {}
+    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+    # shape consistency manager is a singleton class
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    for input_node, input_spec in zip(nodes, sharding_specs):
+        resharding_costs[input_node] = []
+        for strategy in input_node.strategies_vector:
+            input_sharding_spec = strategy.output_sharding_spec
+            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
+            # compute the resharding cost during forward phase
+            _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)
+
+            if count_backward:
+                # In backward phase, we should convert grad with target_spec into input_sharding_spec
+                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
+                    input_spec, input_sharding_spec)
+                total_resharding_cost = resharding_cost_forward + resharding_cost_backward
+            else:
+                total_resharding_cost = resharding_cost_forward
+
+            # we need multiply the size of elem dtype to get correct communication cost
+            resharding_cost = total_resharding_cost * size_per_elem_bytes
+            resharding_costs[input_node].append(resharding_cost)
+    return resharding_costs
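The new helper above is what both the operator handlers and the strategies constructor delegate to. A minimal usage sketch, assuming the helper lives in `colossalai/auto_parallel/solver/_utils.py` (as the relative `.._utils` imports below suggest) and that `node_a`/`node_b` are placeholder `torch.fx` nodes whose `strategies_vector` is already populated, with `spec_a`/`spec_b` their target `ShardingSpec`s:

```python
# Hedged sketch: calling the shared utility directly. `node_a`, `node_b`,
# `spec_a` and `spec_b` are hypothetical placeholders, not defined here.
import torch
from colossalai.auto_parallel.solver._utils import generate_resharding_costs

costs = generate_resharding_costs(
    nodes=[node_a, node_b],            # predecessor nodes of the op being handled
    sharding_specs=[spec_a, spec_b],   # target spec of each predecessor under this strategy
    count_backward=False,              # drop the gradient resharding term, e.g. for inference
    dtype=torch.float32)               # element size converts element counts into bytes

# costs[node_a][j] is the cost of converting the output of node_a's j-th
# strategy into spec_a (plus the backward term when count_backward=True).
```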
@@ -4,7 +4,6 @@ import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
-from .._utils import generate_sharding_spec

__all__ = ['BatchNormHandler']

@@ -115,15 +114,13 @@ class BatchNormHandler(OperatorHandler):
        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -156,8 +153,7 @@ class BatchNormHandler(OperatorHandler):
            new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

            dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
-            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                                  dim_partition_dict_for_output)
+            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
            # the computation cost is all the same
            new_compute_cost = compute_cost

@@ -192,15 +188,13 @@ class BatchNormHandler(OperatorHandler):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -234,15 +228,13 @@ class BatchNormHandler(OperatorHandler):
        name = f'RR = RR x R'

        dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -273,8 +265,7 @@ class BatchNormHandler(OperatorHandler):

        def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
            dim_partition_dict_for_output = {0: mesh_dim_list}
-            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                                  dim_partition_dict_for_output)
+            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

            # the computation cost is all the same
            new_compute_cost = compute_cost
@@ -332,15 +323,13 @@ class BatchNormHandler(OperatorHandler):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -374,15 +363,13 @@ class BatchNormHandler(OperatorHandler):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

        dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -416,15 +403,13 @@ class BatchNormHandler(OperatorHandler):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -459,7 +444,7 @@ class BatchNormHandler(OperatorHandler):
        Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.

        Example:
-            norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
+            norm_handler = BatchNormHandler(node, strategies_vector,
                                             self.shape_consistency_manager)
            norm_handler.register_strategy()
            for strategy in norm_handler.strategies_vector:
@@ -4,7 +4,6 @@ import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
-from .._utils import generate_sharding_spec

__all__ = ['ConvHandler']

@@ -109,15 +108,13 @@ class ConvHandler(OperatorHandler):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -158,15 +155,13 @@ class ConvHandler(OperatorHandler):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -205,15 +200,13 @@ class ConvHandler(OperatorHandler):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -252,15 +245,13 @@ class ConvHandler(OperatorHandler):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -296,15 +287,13 @@ class ConvHandler(OperatorHandler):
        name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -340,15 +329,13 @@ class ConvHandler(OperatorHandler):
        name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

        dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -384,15 +371,13 @@ class ConvHandler(OperatorHandler):
        name = f'RR = RR x RR'

        dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -426,15 +411,13 @@ class ConvHandler(OperatorHandler):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -475,15 +458,13 @@ class ConvHandler(OperatorHandler):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -3,7 +3,6 @@ import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from functools import reduce
-from .._utils import generate_sharding_spec

__all__ = ['DotHandler']

@@ -29,16 +28,14 @@ class DotHandler(OperatorHandler):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # linear layer weight is transposed during init
        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -69,17 +66,15 @@ class DotHandler(OperatorHandler):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # since weight of the linear layer is transposed
        # the actual dim to be sharded is 1
        dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -106,15 +101,13 @@ class DotHandler(OperatorHandler):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -141,15 +134,13 @@ class DotHandler(OperatorHandler):
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

        dim_partition_dict_for_input = {1: [mesh_dim]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -176,15 +167,13 @@ class DotHandler(OperatorHandler):
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

        dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -211,15 +200,13 @@ class DotHandler(OperatorHandler):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -246,15 +233,13 @@ class DotHandler(OperatorHandler):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -281,15 +266,13 @@ class DotHandler(OperatorHandler):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                         dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -7,6 +7,7 @@ from typing import Dict, List
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
+from .._utils import generate_resharding_costs, generate_sharding_spec

from ..sharding_strategy import StrategiesVector

@@ -17,24 +18,24 @@ class OperatorHandler(ABC):
    '''
    The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.

-    Argument:
-        input_node(Node): the input node in node argument list.
-        input_index(int): the index of input node in the node argument list.
-        weight(torch.Tensor): Weight of the node.
-        output_node(Node): Output_node is the output of the node.
-        device_mesh(DeviceMesh): A logical view of a physical mesh.
-        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
-        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
+    Args:
+        node (Node): the input node in node argument list.
+        device_mesh (DeviceMesh): A logical view of a physical mesh.
+        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
+        handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
    '''

-    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
-                 shape_consistency_manager: ShapeConsistencyManager):
+    def __init__(self,
+                 node: Node,
+                 device_mesh: DeviceMesh,
+                 strategies_vector: StrategiesVector,
+                 handle_backward: bool = True):
        self.node = node
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
-        self.shape_consistency_manager = shape_consistency_manager
+        self.handle_backward = handle_backward

        # find the module and its parameters associated with this node
        # this can be used to compute the compute/communication/sharding cost
@@ -102,35 +103,23 @@ class OperatorHandler(ABC):

        return total_memory_cost, activation_memory_cost, weight_memory_cost

-    def _generate_resharding_costs(self, sharding_spec_for_input):
-        '''
-        Compute the resharding costs with this specific strategy.
-
-        Note: The resharding_cost of weight is NOT counted.
-
-        Argument:
-            resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
-                        Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
-                        with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
-                        strategy.
-            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
-        '''
+    def _generate_resharding_costs(self, sharding_specs):
        # The resharding_cost of weight is counted due to sharing weight cases.
-        resharding_costs = {}
        dtype = self.node._meta_data.dtype
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
-            resharding_costs[input_node] = []
-            for strategy in input_node.strategies_vector:
-                input_sharding_spec = strategy.output_sharding_spec
-                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                # compute the resharding cost during forward phase
-                _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
-                    input_sharding_spec, input_spec)
-                # In backward phase, we should convert grad with target_spec into input_sharding_spec
-                _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
-                    input_spec, input_sharding_spec)
-                # we need multiply the size of elem dtype to get correct communication cost
-                resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes
-                resharding_costs[input_node].append(resharding_cost)
-        return resharding_costs
+        nodes = self.predecessor_node
+        return generate_resharding_costs(nodes=nodes,
+                                         sharding_specs=sharding_specs,
+                                         count_backward=self.handle_backward,
+                                         dtype=dtype)
+
+    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+        return generate_sharding_spec(input_=input_,
+                                      device_mesh=self.device_mesh,
+                                      dim_partition_dict=dim_partition_dict)

    @abstractmethod
    def _generate_compute_cost(self, *args, **kwargs):
        """
        Compute the flops involved in the node.
        """
        pass
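With the `OperatorHandler` changes above, subclasses keep calling `self._generate_resharding_costs(...)` and the new `self._generate_sharding_spec(...)` while the base class forwards to the shared `_utils` helpers. A hedged sketch of what a concrete handler now relies on; `MyHandler`, the module path, and the availability of `self.input_data` on the base class are assumptions for illustration, not part of this diff:

```python
# Hedged sketch of the new base-class contract; placeholder names throughout.
# The import path is inferred from the relative imports in the diff above.
from colossalai.auto_parallel.solver.op_handler.operator_handler import OperatorHandler


class MyHandler(OperatorHandler):
    def register_strategy(self):
        # Sharding specs no longer need self.device_mesh passed explicitly;
        # the base class injects it when delegating to generate_sharding_spec().
        # (self.input_data is assumed to be populated by the base class.)
        spec_in = self._generate_sharding_spec(self.input_data, {0: [0]})

        # Resharding costs are computed by _utils.generate_resharding_costs,
        # honouring handle_backward (True by default, False e.g. for inference).
        return self._generate_resharding_costs([spec_in])

    def _generate_compute_cost(self, *args, **kwargs):
        # _generate_compute_cost is abstract, so a subclass must implement it.
        return 0


# handler = MyHandler(node, device_mesh, strategies_vector, handle_backward=False)
```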
@@ -11,7 +11,7 @@ import math
import torch
import operator
from typing import Dict, List
-from ._utils import generate_sharding_spec
+from ._utils import generate_sharding_spec, generate_resharding_costs


class StrategiesConstructor:
@@ -21,12 +21,10 @@ class StrategiesConstructor:
    Args:
        graph (Graph): a Graph object used for analysis and strategy generation.
        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
-        shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
        solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
    """

-    def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
-                 solver_options: SolverOptions):
+    def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
        self.graph = graph
        assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
        self.root_module = self.graph.owning_module
@@ -34,27 +32,8 @@ class StrategiesConstructor:
        self.device_mesh = device_mesh
        self.leaf_strategies = []
        self.strategy_map = {}
-        self.shape_consistency_manager = shape_consistency_manager
        self.solver_options = solver_options

-    def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
-        '''
-        Compute the resharding costs with this specific strategy.
-
-        Argument:
-            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
-        '''
-        resharding_costs = {}
-        for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
-            resharding_costs[input_node] = []
-            for strategy in input_node.strategies_vector:
-                input_sharding_spec = strategy.output_sharding_spec
-                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
-                    input_sharding_spec, target_sharding_spec)
-                resharding_costs[input_node].append(resharding_cost)
-        return resharding_costs
-
    def remove_duplicated_strategy(self, strategies_vector):
        '''
        In build_strategies_and_cost method, we may produce some duplicated strategies.
@@ -120,14 +99,13 @@ class StrategiesConstructor:
            # conv module
            if submod_type in CONV_MODULE_OP:
                # use ConvHandler to create sharding strategies for conv module node
-                conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
-                                           self.shape_consistency_manager)
+                conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
                conv_handler.register_strategy()

            # linear module
            elif submod_type in LINEAR_MODULE_OP:
                # use DotHandler to create sharding strategies for linear module node
-                dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
+                dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
                dot_handler.register_strategy()

            # element-wise module
@@ -158,8 +136,8 @@ class StrategiesConstructor:
                # TODO: use meta_info_prop to profile memory cost and compute cost
                compute_cost = node._meta_data.numel()
                memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])

                # to prevent the resharding happening, set their resharding cost to inf.
                resharding_costs[input_node] = [
@@ -214,8 +192,8 @@ class StrategiesConstructor:
                # TODO: use meta_info_prop to profile memory cost and compute cost
                compute_cost = node._meta_data.numel()
                memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])

                sharding_strategy = ShardingStrategy(name,
                                                     output_sharding_spec,
@@ -275,8 +253,8 @@ class StrategiesConstructor:
                compute_cost = node._meta_data.numel()
                memory_cost = 0

-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])

                # to prevent the resharding happening, set their resharding cost to inf.
                resharding_costs[input_node] = [
@@ -317,8 +295,8 @@ class StrategiesConstructor:
                # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                compute_cost = 0
                memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [new_input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [new_input_sharding_spec])
                sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
                                                     compute_cost=compute_cost,
                                                     memory_cost=memory_cost,
@@ -335,8 +313,8 @@ class StrategiesConstructor:
                # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                compute_cost = 0
                memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])
                sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
                                                     compute_cost=compute_cost,
                                                     memory_cost=memory_cost,
@@ -360,8 +338,8 @@ class StrategiesConstructor:
                # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                compute_cost = 0
                memory_cost = 0
-                resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                   [input_sharding_spec])
+                resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                             [input_sharding_spec])
                # to prevent the resharding happening, set their resharding cost to inf.
                resharding_costs[input_tensor_node] = [
                    cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
@@ -397,8 +375,8 @@ class StrategiesConstructor:
            output_sharding_spec = input_sharding_specs
            # TODO: use meta_info_prop to profile memory cost
            memory_cost = 0
-            resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                               input_sharding_specs)
+            resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                         input_sharding_specs)

            # clear the resharding cost for the output node
            # TODO: we may remove this in final version
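Taken together, the `StrategiesConstructor` changes above mean the constructor and the node handlers it creates are built without a `ShapeConsistencyManager` argument, and module-free nodes reuse the shared `generate_resharding_costs` helper. A hedged sketch of the resulting setup code; the import path is inferred from the package layout, and `graph`, `device_mesh` and `solver_options` are assumed to be prepared elsewhere (by tracing the model and describing the cluster):

```python
# Hedged sketch of the simplified setup after this commit; `graph`,
# `device_mesh` and `solver_options` are placeholders built elsewhere.
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor

# Before: StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
# After:
constructor = StrategiesConstructor(graph, device_mesh, solver_options)

# Assumption: build_strategies_and_cost(), referenced in the docstring above,
# fills constructor.leaf_strategies / constructor.strategy_map for the solver.
constructor.build_strategies_and_cost()
```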