[autoparallel] adapt solver with gpt (#1653)

YuliangLiu0306
2022-09-28 11:17:26 +08:00
committed by GitHub
parent c638bec028
commit 1e7816a460
6 changed files with 164 additions and 18 deletions

View File

@@ -24,7 +24,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
"""
if isinstance(input_, Node):
assert hasattr(input_, '_meta_data'), f'The given node has not attribte _meta_data'
assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data'
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
@@ -47,7 +47,8 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
def generate_resharding_costs(nodes: List[Node],
sharding_specs: List[ShardingSpec],
count_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None):
dtype: Optional[torch.dtype] = None,
index=None):
'''
Compute the resharding costs with this specific strategy.
@@ -68,6 +69,9 @@ def generate_resharding_costs(nodes: List[Node],
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
input_sharding_spec = input_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensors.'
try:
# compute the resharding cost
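The new `index` argument matters when a predecessor node returns a tuple of tensors: its strategies then carry a list of output sharding specs, and `index` selects the element that actually feeds the current node. A minimal sketch of that selection step, with hypothetical helper names rather than the library code:

```python
from typing import List, Optional, Union

def select_input_spec(output_sharding_spec: Union[object, List[object]],
                      index: Optional[int] = None):
    """Pick the sharding spec of the operand the current node consumes.

    A predecessor that returns a tuple of tensors stores one spec per
    output, so `index` identifies the relevant element.
    """
    if isinstance(output_sharding_spec, list):
        assert index is not None, 'index is required for multi-output predecessors'
        return output_sharding_spec[index]
    return output_sharding_spec

# e.g. the spec for the second output of a multi-output predecessor:
# input_sharding_spec = select_input_spec(strategy.output_sharding_spec, index=1)
```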

View File

@@ -9,12 +9,22 @@ __all__ = [
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
torch.nn.functional.dropout, torch.flatten
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
# softmax should not be here
torch.nn.functional.softmax
]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
# TODO: contiguous may need some extra processing.
torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_METHOD_OP = [
@@ -26,7 +36,7 @@ RESHAPE_METHOD_OP = [
]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv, torch.matmul
operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
@@ -41,6 +51,34 @@ LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
NON_PARAM_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
torch.where,
operator.pow,
torch.pow,
torch.tanh,
torch.add,
torch.sub,
torch.mul,
torch.div,
torch.floor_divide,
torch.true_divide,
operator.add,
operator.sub,
operator.mul,
operator.floordiv,
operator.truediv,
# softmax should not be here
torch.nn.functional.softmax
]
INFINITY_COST = 1e13
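These tables are typically consulted when a traced `call_function` node is dispatched to its handler. A rough sketch of that lookup, using abridged copies of the lists above (the dispatch function itself is hypothetical, not the library API):

```python
import operator

import torch

# abridged copies of the tables above, for illustration only
ELEMENTWISE_FUNC_OP = [
    torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply,
    torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten,
    torch.nn.functional.softmax,
]
BCAST_FUNC_OP = [
    torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide,
    torch.true_divide, operator.add, operator.sub, operator.mul,
    operator.floordiv, operator.truediv, torch.matmul, torch.where,
    operator.pow, torch.pow, torch.tanh,
]

def classify_call_function(target):
    """Return the handler category a call_function target falls into."""
    if target in BCAST_FUNC_OP:
        return 'broadcast'
    if target in ELEMENTWISE_FUNC_OP:
        return 'elementwise'
    return 'unhandled'

assert classify_call_function(torch.where) == 'broadcast'
assert classify_call_function(torch.nn.functional.softmax) == 'elementwise'
```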

View File

@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute computation cost
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim]
@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim]
@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
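Every hunk in this handler makes the same change: the weight's sharding spec is now passed alongside the input's, so each strategy also accounts for resharding the weight, which matters when a weight node is shared by several consumers. A toy sketch of why one spec per operand is needed, with made-up spec strings and cost function:

```python
def generate_resharding_costs_sketch(operands, target_specs, cost_fn):
    """operands: list of (producer_node, current_spec); target_specs gives the
    spec each operand must be resharded into for this strategy. One entry is
    recorded per producer, so omitting the weight spec would silently drop
    the weight's resharding cost."""
    costs = {}
    for (node, current_spec), target_spec in zip(operands, target_specs):
        costs[node] = 0.0 if current_spec == target_spec else cost_fn(current_spec, target_spec)
    return costs

# toy usage: the input is already laid out correctly, the weight is not
costs = generate_resharding_costs_sketch(
    operands=[('input_node', 'S0R'), ('weight_node', 'RR')],
    target_specs=['S0R', 'RS1'],
    cost_fn=lambda cur, tgt: 1.0,
)
assert costs == {'input_node': 0.0, 'weight_node': 1.0}
```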

View File

@@ -121,7 +121,13 @@ class OperatorHandler(ABC):
def _generate_resharding_costs(self, sharding_specs):
# The resharding cost of the weight is counted because the weight may be shared across multiple nodes.
dtype = self.node._meta_data.dtype
if hasattr(self.node._meta_data, 'dtype'):
dtype = self.node._meta_data.dtype
else:
assert isinstance(self.node._meta_data,
tuple), f'Only torch.Tensor, torch.fx.Node and a tuple of torch.Tensor are expected'
dtype = self.node._meta_data[0].dtype
nodes = self.predecessor_node
return generate_resharding_costs(nodes=nodes,
sharding_specs=sharding_specs,
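The dtype lookup now tolerates nodes whose `_meta_data` is a tuple of tensors (e.g. ops with multiple outputs), falling back to the first element's dtype. A small standalone sketch of that resolution, assuming the default float32 dtype in the usage check:

```python
import torch

def resolve_meta_dtype(meta_data):
    """Return the dtype of a node's meta data, which may be a single tensor
    or a tuple of tensors; for tuples the first element's dtype is used."""
    if hasattr(meta_data, 'dtype'):
        return meta_data.dtype
    assert isinstance(meta_data, tuple), \
        'only torch.Tensor or a tuple of torch.Tensor is expected'
    return meta_data[0].dtype

assert resolve_meta_dtype(torch.empty(2, 2)) == torch.float32
assert resolve_meta_dtype((torch.empty(1, dtype=torch.int64),)) == torch.int64
```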

View File

@@ -1,9 +1,14 @@
import colorsys
from .operator_handler import OperatorHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from copy import deepcopy
import math
from colossalai.auto_parallel.solver._utils import exception_handler
import warnings
import torch
from ..constants import INFINITY_COST
class ReshapeHandler(OperatorHandler):
@@ -19,6 +24,7 @@ class ReshapeHandler(OperatorHandler):
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@exception_handler
def register_strategy(self):
# TODO: add strategies with output sharding specs other than fully replicated.
input_node = self.strategies_vector.predecessor_nodes[0]
@@ -37,11 +43,23 @@ class ReshapeHandler(OperatorHandler):
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict_for_output = {}
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
if isinstance(self.output_data, tuple):
dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
try:
if isinstance(self.output_data, tuple):
output_sharding_spec = []
for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
else:
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
except AssertionError as e:
warnings.warn(f'{e}')
continue
name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = 0
memory_cost = self.node._meta_data.numel()
# node._meta_data may be a tuple, so skip the numel-based memory cost for now
memory_cost = 0
# compute the communication cost; in a reshape op, the communication happens when casting the input sharding spec to the fully replicated one.
dim_partition_dict_for_replicate_input = {}
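For operators that return a tuple of tensors, the handler now builds one fully replicated spec per output tensor instead of a single spec. A minimal sketch of that branch, with a hypothetical stand-in for `_generate_sharding_spec` that only records shapes:

```python
import torch

def make_replicated_spec(tensor, dim_partition_dict):
    # hypothetical stand-in for OperatorHandler._generate_sharding_spec
    return {'shape': tuple(tensor.shape), 'dim_partition_dict': dict(dim_partition_dict)}

def replicate_output_specs(output_data):
    """Build a fully replicated spec per output tensor (a list for tuples)."""
    if isinstance(output_data, tuple):
        return [make_replicated_spec(t, {}) for t in output_data]
    return make_replicated_spec(output_data, {})

specs = replicate_output_specs((torch.empty(2, 3), torch.empty(4)))
assert [s['shape'] for s in specs] == [(2, 3), (4,)]
```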
@@ -56,7 +74,7 @@ class ReshapeHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_node] = [0 if cost == 0 else math.inf for cost in resharding_costs[input_node]]
resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
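The last hunk replaces `math.inf` with the shared `INFINITY_COST` constant; the intent stays the same (make any real reshard before the reshape prohibitively expensive) while, presumably, keeping the coefficients handed to the solver finite. A minimal sketch of the clamping logic:

```python
INFINITY_COST = 1e13  # value taken from the constants hunk above

def forbid_resharding(costs):
    """Keep zero-cost entries; replace every other resharding cost with a
    finite but prohibitively large penalty."""
    return [0 if cost == 0 else INFINITY_COST for cost in costs]

assert forbid_resharding([0, 2.5, 0.0, 7]) == [0, 1e13, 0, 1e13]
```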