Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] added sharding spec conversion for linear handler (#1687)
commit 4973157ad7, parent af718e83f2
@@ -1,10 +1,13 @@
 import torch
 import torch.nn.functional as F
+from colossalai.tensor.sharding_spec import ShardingException
 from .node_handler import ModuleHandler, NodeHandler
 from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
 from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator
-from typing import List, Dict
+from typing import List, Dict, Union
 from .registry import operator_registry
+from copy import deepcopy
+from .utils import switch_partition_dim, update_partition_dim

 __all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']

@@ -24,14 +27,22 @@ class LinearModuleHandler(ModuleHandler):
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
+        input_meta_data = self.node.args[0]._meta_data
+        input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
         physical_input_operand = OperationData(name=str(self.node.args[0]),
                                                type=OperationDataType.ARG,
-                                               data=self.node.args[0]._meta_data)
+                                               data=input_meta_data,
+                                               logical_shape=input_logical_shape)
         physical_other_operand = OperationData(name="weight",
                                                type=OperationDataType.PARAM,
                                                data=self.named_parameters['weight'],
                                                logical_shape=self.named_parameters['weight'].shape[::-1])
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        output_meta_data = self.node._meta_data
+        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
+        physical_output = OperationData(name=str(self.node),
+                                        type=OperationDataType.OUTPUT,
+                                        data=output_meta_data,
+                                        logical_shape=output_logical_shape)

         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

@@ -42,28 +53,46 @@ class LinearModuleHandler(ModuleHandler):
             mapping['bias'] = physical_bias_operand
         return mapping

-    def post_process(self, strategy: ShardingStrategy_V2):
+    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
         """
-        Convert the sharding spec of the weight parameter back to its original shape.
+        Convert the sharding spec from the logical shape to the physical shape.
         """
+        # switch the dimensions of the transposed weight
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == "weight":
                 assert op_data.logical_shape != op_data.data.shape
-                dim_partition_dict = sharding_spec.dim_partition_dict
-
-                # switch first and last dim of the linear module weight
-                first_dim_partition = dim_partition_dict.pop(-1, None)
-                last_dim_partition = dim_partition_dict.pop(0, None)
-
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-
-                if last_dim_partition:
-                    dim_partition_dict[-1] = last_dim_partition
-
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
-        return strategy
+                switch_partition_dim(sharding_spec, 0, -1)
+
+        # create multiple sharding strategies for the inputs
+        # as input can be multi-dimensional and the partition dim is only 2D,
+        # we need to map the partition at dim 0 to one of the first few dimensions of the input
+        sharding_strategies = []
+        input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
+        output_op_data = strategy.get_op_data_by_name(str(self.node))
+        num_input_dims = input_op_data.data.dim()
+        input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
+
+        if 0 in input_sharding_spec.dim_partition_dict:
+            for i in range(num_input_dims - 1):
+                new_strategy = strategy.clone()
+                input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
+                output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
+                try:
+                    update_partition_dim(sharding_spec=input_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=input_op_data.data.shape,
+                                         inplace=True)
+                    update_partition_dim(sharding_spec=output_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=output_op_data.data.shape,
+                                         inplace=True)
+                    sharding_strategies.append(new_strategy)
+                except ShardingException:
+                    pass
+        else:
+            sharding_strategies.append(strategy)
+
+        return sharding_strategies


 @operator_registry.register(F.linear)
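The post_process above treats the N-dimensional input as a logical 2-D matmul operand, so a shard on logical dim 0 may be realised on any of the first (dim - 1) physical dimensions, one cloned strategy per candidate. A minimal standalone sketch of that idea (illustration only, not part of the diff; plain dicts stand in for ShardingSpec.dim_partition_dict):

    import torch

    def logical_shape_of(t: torch.Tensor) -> torch.Size:
        # same flattening used by get_operation_data_mapping: (..., K) -> (prod(...), K)
        return t.view(-1, t.shape[-1]).shape

    x = torch.rand(2, 2, 4, 16)              # physical 4-D input, as in the updated test
    print(logical_shape_of(x))               # torch.Size([16, 16]), the logical operand

    # a partition on logical dim 0 may map to any of the first (x.dim() - 1) physical dims
    logical_partition = {0: [0]}             # mesh dim 0 shards the logical row dimension
    candidates = [{i: logical_partition[0]} for i in range(x.dim() - 1)]
    print(candidates)                        # [{0: [0]}, {1: [0]}, {2: [0]}]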
@@ -118,20 +147,37 @@ class LinearFunctionHandler(NodeHandler):
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == str(self.node.args[1]):
                 assert op_data.logical_shape != op_data.data.shape
-                dim_partition_dict = sharding_spec.dim_partition_dict
-
-                # switch first and last dim of the linear module weight
-                first_dim_partition = dim_partition_dict.pop(-1, None)
-                last_dim_partition = dim_partition_dict.pop(0, None)
-
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-
-                if last_dim_partition:
-                    dim_partition_dict[-1] = last_dim_partition
-
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                switch_partition_dim(sharding_spec, 0, -1)
+
+        # create multiple sharding strategies for the inputs
+        # as input can be multi-dimensional and the partition dim is only 2D,
+        # we need to map the partition at dim 0 to one of the first few dimensions of the input
+        sharding_strategies = []
+        input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
+        output_op_data = strategy.get_op_data_by_name(str(self.node))
+        num_input_dims = input_op_data.data.dim()
+        input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
+
+        if 0 in input_sharding_spec.dim_partition_dict:
+            for i in range(num_input_dims - 1):
+                new_strategy = strategy.clone()
+                input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
+                output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
+                try:
+                    update_partition_dim(sharding_spec=input_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=input_op_data.data.shape,
+                                         inplace=True)
+                    update_partition_dim(sharding_spec=output_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=output_op_data.data.shape,
+                                         inplace=True)
+                    sharding_strategies.append(new_strategy)
+                except ShardingException:
+                    pass
+        else:
+            sharding_strategies.append(strategy)
+
         return strategy


@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from typing import Dict, List
+from typing import Dict, List, Union
 from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
 from ..strategy import StrategyGenerator_V2

@@ -72,17 +72,27 @@ class NodeHandler(ABC):
         for generator in strategy_generators:
             strategies = generator.generate()

+            # postprocess a strategy
+            # postprocess can produce one strategy or multiple strategies
+            post_processed_strategies_map = map(self.post_process, strategies)
+            post_processed_strategies = []
+
+            for strategy in post_processed_strategies_map:
+                if isinstance(strategy, (list, tuple)):
+                    post_processed_strategies.extend(strategy)
+                else:
+                    post_processed_strategies.append(strategy)
+
             # compute the resharding costs based on the previous node
             # strategies if specified
             if compute_resharding_cost:
-                strategies = list(map(self.update_resharding_cost, strategies))
-            self.strategies_vector.extend(strategies)
+                post_processed_strategies = list(map(self.update_resharding_cost, post_processed_strategies))
+
+            self.strategies_vector.extend(post_processed_strategies)

-        strategies_vector = map(self.post_process, self.strategies_vector)
-        self.strategies_vector = list(strategies_vector)
         return self.strategies_vector

-    def post_process(self, strategy: ShardingStrategy_V2):
+    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
         # transform the strategy generated
         # e.g. to process the sharding strategy for the transposed weights
         return strategy
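With this change, post_process may return either one strategy or a list of strategies, and register_strategy flattens whatever comes back before costing and registration. A standalone sketch of the flattening pattern (illustration only; fake_post_process is a hypothetical stand-in for NodeHandler.post_process):

    from typing import List, Union

    def fake_post_process(name: str) -> Union[str, List[str]]:
        # a post-process step may expand one logical strategy into several physical ones
        return [f"{name} @ physical dim {i}" for i in range(2)] if name.startswith("S0") else name

    generated = ["S0R = S0R x RR", "RR = RR x RR"]
    post_processed: List[str] = []
    for item in map(fake_post_process, generated):
        if isinstance(item, (list, tuple)):
            post_processed.extend(item)
        else:
            post_processed.append(item)
    print(post_processed)
    # ['S0R = S0R x RR @ physical dim 0', 'S0R = S0R x RR @ physical dim 1', 'RR = RR x RR']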
colossalai/auto_parallel/solver/op_handler/utils.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+import torch
+from typing import Dict
+from colossalai.tensor.sharding_spec import ShardingSpec
+from copy import deepcopy
+
+
+def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
+    """
+    Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
+
+    Args:
+        sharding_spec (ShardingSpec): the sharding spec for which partition dim are switched
+        dim1 (int): the tensor dimension to switch
+        dim2 (int): the tensor dimension to switch
+    """
+    assert len(sharding_spec.entire_shape) == 2
+    dim_partition_dict = sharding_spec.dim_partition_dict
+    dim1_partition = dim_partition_dict.pop(dim1, None)
+    dim2_partition = dim_partition_dict.pop(dim2, None)
+
+    if dim1_partition:
+        dim_partition_dict[dim2] = dim1_partition
+
+    if dim2_partition:
+        dim_partition_dict[dim1] = dim2_partition
+
+    # re-init the sharding spec
+    sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+    return sharding_spec
+
+
+def update_partition_dim(sharding_spec: ShardingSpec,
+                         dim_mapping: Dict[int, int],
+                         physical_shape: torch.Size,
+                         inplace: bool = False):
+    """
+    This method is used to update the partition dim dict from the logical one to the physical one.
+
+    Args:
+        sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated
+        dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension
+        physical_shape (torch.Size): the physical shape for the tensor
+    """
+
+    if inplace:
+        current_sharding_spec = sharding_spec
+    else:
+        current_sharding_spec = deepcopy(sharding_spec)
+
+    old_dim_partition_dict = current_sharding_spec.dim_partition_dict
+    new_dim_partition_dict = {}
+
+    # assign new dim
+    for old_dim, new_dim in dim_mapping.items():
+        mesh_dims = old_dim_partition_dict.pop(old_dim)
+        new_dim_partition_dict[new_dim] = mesh_dims
+
+    for tensor_dim, mesh_dims in old_dim_partition_dict.items():
+        if tensor_dim in new_dim_partition_dict:
+            raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}")
+        else:
+            new_dim_partition_dict[tensor_dim] = mesh_dims
+
+    # update sharding spec
+    current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
+                                   entire_shape=physical_shape,
+                                   dim_partition_dict=new_dim_partition_dict)
+    return current_sharding_spec
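update_partition_dim is the core of the logical-to-physical conversion: it moves the mesh dims recorded for a logical tensor dimension onto the chosen physical dimension and rebuilds the spec against the physical shape. A toy sketch of just the dict remapping it performs (illustration only; plain dicts stand in for ShardingSpec.dim_partition_dict, and the re-initialisation with the physical shape is omitted):

    from typing import Dict, List

    def remap_partition_dims(dim_partition_dict: Dict[int, List[int]],
                             dim_mapping: Dict[int, int]) -> Dict[int, List[int]]:
        old = dict(dim_partition_dict)
        new: Dict[int, List[int]] = {}
        # move the mesh dims of each remapped logical dim onto its physical dim
        for old_dim, new_dim in dim_mapping.items():
            new[new_dim] = old.pop(old_dim)
        # keep every other entry, refusing collisions just like the real helper
        for tensor_dim, mesh_dims in old.items():
            if tensor_dim in new:
                raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}")
            new[tensor_dim] = mesh_dims
        return new

    # logical spec: row dim sharded on mesh dim 0, column dim on mesh dim 1
    print(remap_partition_dims({0: [0], -1: [1]}, dim_mapping={0: 2}))   # {2: [0], -1: [1]}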
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -121,16 +122,12 @@ class ShardingStrategy_V2:
         communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
         memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
         input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
-        input_resharding_costs (Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
-                                                         with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
-                                                         strategy.(default to None)
     """
     name: str
     sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    input_resharding_costs: Dict[OperationData, List[float]] = None
     communication_actions: Dict[OperationData, CommSpec] = None
     resharding_costs: Dict[OperationData, Dict[ShardingSpec, TrainCycleItem]] = None

@@ -169,6 +166,26 @@ class ShardingStrategy_V2:
             return sharding_spec
         raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")

+    def clone(self):
+
+        def _deepcopy_dict_vals(data: Dict):
+            return {k: deepcopy(v) for k, v in data.items()}
+
+        sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None
+        communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None
+        resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None
+        compute_cost = deepcopy(self.compute_cost)
+        communication_cost = deepcopy(self.communication_cost)
+        memory_cost = deepcopy(self.memory_cost)
+
+        return ShardingStrategy_V2(name=self.name,
+                                   sharding_specs=sharding_specs,
+                                   compute_cost=compute_cost,
+                                   communication_cost=communication_cost,
+                                   memory_cost=memory_cost,
+                                   communication_actions=communication_actions,
+                                   resharding_costs=resharding_costs)
+

 class StrategiesVector(list):
     '''
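clone() deep-copies the values of the spec and cost dictionaries while keeping the OperationData keys shared, so post_process can mutate the cloned ShardingSpecs in place (update_partition_dim(..., inplace=True)) without disturbing the original strategy. A minimal sketch of those copy semantics (illustration only; plain dicts stand in for the OperationData -> ShardingSpec mapping):

    from copy import deepcopy

    original = {"input_1": {"dim_partition_dict": {0: [0]}}}
    cloned = {k: deepcopy(v) for k, v in original.items()}     # keys shared, values copied

    cloned["input_1"]["dim_partition_dict"] = {2: [0]}          # in-place edit on the clone
    print(original["input_1"]["dim_partition_dict"])            # {0: [0]}, original untouched
    print(cloned["input_1"]["dim_partition_dict"])              # {2: [0]}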
@@ -6,6 +6,8 @@ from enum import Enum
 from functools import reduce
 import operator

+__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
+
 ALLGATHER_COST = 20
 SHARD_COST = 5
 STEP_PENALTY = 6
@@ -136,6 +138,10 @@ class _DimSpec:
         return difference


+class ShardingException(Exception):
+    pass
+
+
 class ShardingSpec:
     '''
     Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
@@ -3,14 +3,15 @@ import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
 from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy_V2
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec


 def test_linear_module_handler():
     model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)

@@ -34,9 +35,9 @@ def test_linear_module_handler():

     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 16])
+    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 16])
+    assert mapping['input'].logical_shape == torch.Size([16, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
@@ -52,11 +53,14 @@ def test_linear_module_handler():

     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 32])
+    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
+    assert mapping['output'].logical_shape == torch.Size([16, 32])

     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]
+    # one strategy will be converted to different physical sharding spec
+    assert len(strategy_name_list) > 8

     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list
@@ -78,6 +82,19 @@ def test_linear_module_handler():
     assert 'RS0 = RR x RS0' in strategy_name_list
     assert 'RS1 = RR x RS1' in strategy_name_list

+    for strategy in strategies_vector:
+        strategy: ShardingStrategy_V2
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
+
+        # make sure the sharding matches across different operation data
+        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
+        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
+        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
+        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+

 def test_linear_function_handler():
     model = nn.Linear(16, 32).to('meta')
@@ -123,6 +140,8 @@ def test_linear_function_handler():

     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]
+    # one strategy will be converted to different physical sharding spec
+    assert len(strategy_name_list) > 8

     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list
@@ -144,6 +163,19 @@ def test_linear_function_handler():
     assert 'RS0 = RR x RS0' in strategy_name_list
     assert 'RS1 = RR x RS1' in strategy_name_list

+    for strategy in strategies_vector:
+        strategy: ShardingStrategy_V2
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
+
+        # make sure the sharding matches across different operation data
+        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
+        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
+        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
+        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+

 if __name__ == '__main__':
     test_linear_module_handler()