[autoparallel] update CommSpec (#1667)

Author: YuliangLiu0306 (committed by GitHub), 2022-09-29 11:20:59 +08:00
parent 247a9dbca9
commit 3f068d1409
7 changed files with 413 additions and 390 deletions


@@ -79,7 +79,7 @@ def generate_resharding_costs(nodes: List[Node],
                    input_sharding_spec, input_spec)
                # we need multiply the size of elem dtype to get correct communication cost
-               resharding_cost = total_resharding_cost * size_per_elem_bytes
+               resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
            except AssertionError as e:
                warnings.warn(f'{e}')
                resharding_cost = INFINITY_COST
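Both resharding-cost hunks in this commit (here and in BcastOpHandler below) make the same change: the shape-consistency manager now returns a per-phase cost dict, so only its "total" entry is scaled into bytes. A minimal sketch of that pattern, with an assumed INFINITY_COST sentinel and a simplified fallback (not the repository's exact handler code):

import warnings

INFINITY_COST = 1e13    # assumed sentinel for an impossible resharding

def resharding_cost_in_bytes(cost_dict, size_per_elem_bytes):
    # scale the 'total' entry of the per-phase cost dict into bytes
    try:
        return cost_dict["total"] * size_per_elem_bytes
    except (KeyError, TypeError) as e:
        warnings.warn(f'{e}')
        return INFINITY_COST

print(resharding_cost_in_bytes({"forward": 64, "backward": 0, "total": 64}, 4))    # 256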


@@ -93,7 +93,7 @@ class BcastOpHandler(OperatorHandler):
                    input_sharding_spec, input_spec)
                # we need multiply the size of elem dtype to get correct communication cost
-               resharding_cost = total_resharding_cost * size_per_elem_bytes
+               resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
                resharding_costs[input_node].append(resharding_cost)
        return resharding_costs


@@ -91,18 +91,11 @@ class StrategyGenerator_V2(ABC):
            num_ele_in_comm = comm_spec.get_comm_cost()
            dtype = operand.data.dtype
            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-           cost = size_per_elem_bytes * num_ele_in_comm
-
-           # compute the fwd
-           # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
-           # it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
-           # so total cost is either for fwd or bwd.
-           if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
-               comm_cost.fwd += cost
-           elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
-               comm_cost.fwd += cost
-           else:
-               raise ValueError(f"Found unknown CommunicationType {comm_spec.comm_pattern}")
+           for phase, cost in num_ele_in_comm.items():
+               num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
+           comm_cost.fwd += num_ele_in_comm['forward']
+           comm_cost.bwd += num_ele_in_comm['backward']
+           comm_cost.total += num_ele_in_comm['total']

        # check if communication action exists
        # if so, loop over each action and compute the cost of each action
@@ -110,9 +103,6 @@ class StrategyGenerator_V2(ABC):
        for operand, comm_spec in strategy.communication_actions.items():
            _compute_and_add(operand, comm_spec)

-       # update the total cost
-       comm_cost.total = comm_cost.fwd + comm_cost.bwd
-
        # update the communication cost attribute in-place
        strategy.communication_cost = comm_cost

        return strategy
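With this change get_comm_cost() returns element counts per phase, and the generator scales each phase into bytes before accumulating into fwd/bwd/total. A small self-contained sketch of that scaling step (the helper name is illustrative, not part of the codebase):

import torch

def scale_comm_cost_to_bytes(cost_dict, dtype=torch.float32):
    # convert per-phase element counts into bytes using the operand's dtype size
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    return {phase: num_elems * size_per_elem_bytes for phase, num_elems in cost_dict.items()}

print(scale_comm_cost_to_bytes({'forward': 1024, 'backward': 1024, 'total': 2048}))
# {'forward': 4096, 'backward': 4096, 'total': 8192}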


@@ -9,10 +9,11 @@ from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
 from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .comm_spec import CollectiveCommPattern, CommSpec
 from . import distspec

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec'
+    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
 ]
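With the re-export above, downstream code can pull the comm spec types straight from the tensor package, e.g. (assuming colossalai is installed at this revision):

from colossalai.tensor import CommSpec, CollectiveCommPattern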


@@ -0,0 +1,358 @@
import torch
from enum import Enum
import torch.distributed as dist
from functools import reduce
import operator
from torch.distributed import ReduceOp
__all__ = [
'CollectiveCommPattern',
'CommSpec',
]
def _all_gather(tensor, comm_spec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
]
tensor = tensor
group = process_group
dist.all_gather(tensor_list, tensor, group=group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output
def _split(tensor, comm_spec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, _ in process_groups_list:
if dist.get_rank() in rank_list:
tensor = tensor
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
output = torch.narrow(tensor, dim, start, length)
return output
def _all_to_all(tensor, comm_spec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
return output
def _all_reduce(tensor, comm_spec):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
return tensor
class _ReduceGrad(torch.autograd.Function):
"""
A customized communication operation which forward is an identity operation,
backward is all_reduce operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return input_
@staticmethod
def backward(ctx, grad_output):
return _all_reduce(grad_output, ctx.comm_spec), None
class _ReduceInput(torch.autograd.Function):
"""
A customized communication operation which forward is all_reduce operation,
backward is an identity operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_reduce(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
return _all_reduce(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
A customized communication operation which forward is split operation,
backward is an all gather operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return _split(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return _all_gather(grad_output, ctx.comm_spec), None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""
A customized communication operation which forward is an all gather operation,
backward is split operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_gather(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return _all_gather(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.comm_spec), None
class _AllToAll(torch.autograd.Function):
"""
A customized communication operation which forward is an all to all operation,
backward is an all to all operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_to_all(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
sharding_spec=comm_spec.sharding_spec,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)
ctx.comm_spec = comm_spec_for_backward
return output
@staticmethod
def backward(ctx, grad_outputs):
return _all_to_all(grad_outputs, ctx.comm_spec), None
def reduce_grad(input_, comm_spec):
return _ReduceGrad.apply(input_, comm_spec)
def reduce_input(input_, comm_spec):
return _ReduceInput.apply(input_, comm_spec)
def split_forward_gather_backward(input_, comm_spec):
return _SplitForwardGatherBackward.apply(input_, comm_spec)
def gather_forward_split_backward(input_, comm_spec):
return _GatherForwardSplitBackward.apply(input_, comm_spec)
def all_to_all(input_, comm_spec):
return _AllToAll.apply(input_, comm_spec)
class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
class CommSpec:
'''
Communication spec is used to record the communication action. It has two main functions:
1. Compute the communication cost which will be used in auto parallel solver.
2. Convert the communication spec to real action which will be used in runtime.
It contains comm_pattern to determine the communication method, sharding_spec to determine the communication size,
gather_dim and shard_dim to determine the buffer shape, and logical_process_axis to determine the process groups
on which the communication is performed.
Argument:
comm_pattern(CollectiveCommPattern): describes the communication method used in this spec.
sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
'''
def __init__(self,
comm_pattern,
sharding_spec,
gather_dim=None,
shard_dim=None,
logical_process_axis=None,
forward_only=False):
self.comm_pattern = comm_pattern
self.sharding_spec = sharding_spec
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
self.forward_only = forward_only
if isinstance(self.logical_process_axis, list):
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.logical_process_axis = 0
else:
self.device_mesh = self.sharding_spec.device_mesh
def __repr__(self):
res_list = ["CommSpec:("]
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
return ''.join(res_list)
def get_comm_cost(self):
'''
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
compute the communication cost.
For shard operation, it is an on-chip operation, so only a tiny placeholder cost is charged.
'''
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
cost_dict = {}
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
# give a tiny cost to shard
backward_communication_cost = 10
if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
# grad should have same shape as input tensor
# all to all operation has same logical process axis as forward.
backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
backward_communication_cost = 0
if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
forward_communication_cost = 0
backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
# give a tiny cost to shard
forward_communication_cost = 10
backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.forward_only:
cost_dict["forward"] = forward_communication_cost
cost_dict["backward"] = 0
cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
else:
cost_dict["forward"] = forward_communication_cost
cost_dict["backward"] = backward_communication_cost
cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
return cost_dict
def covert_spec_to_action(self, tensor):
'''
Convert CommSpec into runtime action, implement real collection communication to target tensor.
The collection communication action is directed by the CommSpec.
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
if self.comm_pattern in pattern_to_func_dict:
tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
else:
tensor.data = tensor
pattern_to_func_dict = {
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward,
CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all,
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
}
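The behavioural contract of the new get_comm_cost is that it always returns a dict with 'forward', 'backward' and 'total' entries, and forward_only=True zeroes the backward share. A stand-in illustration of that contract (this helper is not part of the codebase):

def make_cost_dict(forward_cost, backward_cost, forward_only=False):
    # mirror the forward_only branch of CommSpec.get_comm_cost
    cost_dict = {'forward': forward_cost, 'backward': 0 if forward_only else backward_cost}
    cost_dict['total'] = cost_dict['forward'] + cost_dict['backward']
    return cost_dict

print(make_cost_dict(1000.0, 1000.0))                       # {'forward': 1000.0, 'backward': 1000.0, 'total': 2000.0}
print(make_cost_dict(1000.0, 1000.0, forward_only=True))    # {'forward': 1000.0, 'backward': 0, 'total': 1000.0}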


@@ -11,355 +11,9 @@ import math
 from functools import reduce
 import operator
 from torch.distributed import ReduceOp
+from .comm_spec import *

-__all__ = [
-    'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions',
-    'set_shape_consistency_options'
-]
+__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
[The rest of this deletion removes, verbatim, the collective communication helpers (_all_gather, _split, _all_to_all, _all_reduce), the autograd Functions, CollectiveCommPattern, and the CommSpec constructor and __repr__ from shape_consistency.py; they are identical to the new comm_spec.py shown above. Only the old get_comm_cost differed, returning a single scalar cost instead of a per-phase dict:]
def get_comm_cost(self):
'''
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
compute the communication cost.
For shard operation, it is an on-chip operation, so the communication cost is zero.
'''
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
# give a tiny cost to shard
backward_communication_cost = 10
if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
# grad should have same shape as input tensor
# all to all operation has same logical process axis as forward.
backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
backward_communication_cost = 0
if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
forward_communication_cost = 0
backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
# give a tiny cost to shard
forward_communication_cost = 10
backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
try:
if self.forward_only:
total_communication_cost = forward_communication_cost
else:
total_communication_cost = forward_communication_cost + backward_communication_cost
except:
raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
return total_communication_cost
[The duplicate covert_spec_to_action method and pattern_to_func_dict, also moved verbatim to comm_spec.py, are removed here as well.]
@dataclass
@@ -406,7 +60,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        assert isinstance(value, bool)
        self._forward_only = value

-    def get_all_all_gather_spec(self, source_spec, orig_cost):
+    def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
        '''
        Get all valid sharding specs from source_spec with single all-gather operation, and
        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -463,16 +117,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                                     forward_only=self.forward_only)

                # compute the communication cost with CommSpec
-               cost = comm_spec.get_comm_cost()
+               cost_dict = comm_spec.get_comm_cost()

                # generate new sharding spec
                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                                 source_spec.entire_shape,
                                                 dim_partition_dict=new_dim_partition_dict)
-               valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
+               for phase, cost in cost_dict.items():
+                   cost_dict[phase] = cost + orig_cost_dict[phase]
+               valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
        return valid_spec_dict

-    def get_all_all_to_all_spec(self, source_spec, orig_cost):
+    def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
        '''
        Get all valid sharding specs from source_spec with single all-to-all operation, and
        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -552,7 +208,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                                     forward_only=self.forward_only)

                # compute the communication cost with CommSpec
-               cost = comm_spec.get_comm_cost()
+               cost_dict = comm_spec.get_comm_cost()

                new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)

                # We won't add empty list into dim_partition_dict
@@ -570,10 +226,12 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                                 source_spec.entire_shape,
                                                 dim_partition_dict=new_dim_partition_dict)
-               valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
+               for phase, cost in cost_dict.items():
+                   cost_dict[phase] = cost + orig_cost_dict[phase]
+               valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
        return valid_spec_dict

-    def get_all_shard_spec(self, source_spec, orig_cost):
+    def get_all_shard_spec(self, source_spec, orig_cost_dict):
        '''
        Get all valid sharding specs from source_spec with single shard operation, and
        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -639,16 +297,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                                     forward_only=self.forward_only)

                # compute the communication cost with CommSpec
-               cost = comm_spec.get_comm_cost()
+               cost_dict = comm_spec.get_comm_cost()

                # generate new sharding spec
                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                                 source_spec.entire_shape,
                                                 dim_partition_dict=new_dim_partition_dict)
-               valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
+               for phase, cost in cost_dict.items():
+                   cost_dict[phase] = cost + orig_cost_dict[phase]
+               valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
        return valid_spec_dict

-    def get_all_one_step_transform_spec(self, source_spec, orig_cost):
+    def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
        '''
        Get all valid sharding specs from source_spec with one step transform, and
        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@@ -665,9 +325,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
        '''
        valid_spec_dict = {}
-       valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost))
-       valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost))
-       valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost))
+       valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict))
+       valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict))
+       valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
        return valid_spec_dict

    def shape_consistency(self, source_spec, target_spec):
@@ -730,7 +390,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            total_cost: 12294.402000000002
        '''
        MAX_TRANSFORM_STEPS = 20
-       total_cost = 0
+       total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0}
        total_steps = 0
        transform_path = []
        comm_action_sequence = []
@@ -740,35 +400,37 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        # We do nothing if the sharding spec is all the same.
        if source_spec.sharding_sequence_difference(target_spec) == 0:
            self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
-           return (transform_path, comm_action_sequence, total_cost)
+           return (transform_path, comm_action_sequence, total_cost_dict)

        temp_sharding_spec = source_spec
        transform_path.append(temp_sharding_spec)

        # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
        while total_steps <= MAX_TRANSFORM_STEPS:
-           valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost)
+           valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost_dict)
            best_difference_score = math.inf

            for sharding_spec, info_pairs in valid_transform_spec_dict.items():
-               comm_spec, cost = info_pairs
+               comm_spec, cost_dict = info_pairs
                spec_difference = sharding_spec.sharding_sequence_difference(target_spec)

                if spec_difference == 0:
-                   total_cost += cost
+                   for phase, cost in total_cost_dict.items():
+                       total_cost_dict[phase] = cost + cost_dict[phase]
                    transform_path.append(sharding_spec)
                    comm_action_sequence.append(comm_spec)
                    self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
-                   return (transform_path, comm_action_sequence, total_cost)
+                   return (transform_path, comm_action_sequence, total_cost_dict)

                if spec_difference < best_difference_score:
                    temp_sharding_spec = sharding_spec
-                   temp_cost = cost
+                   temp_cost_dict = cost_dict
                    temp_comm_spec = comm_spec
                    best_difference_score = spec_difference

            transform_path.append(temp_sharding_spec)
            comm_action_sequence.append(temp_comm_spec)
-           total_cost += temp_cost
+           for phase, cost in total_cost_dict.items():
+               total_cost_dict[phase] = cost + temp_cost_dict[phase]
            total_steps += 1

        raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
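Because every transform step now yields a per-phase dict, shape_consistency accumulates costs phase by phase instead of summing scalars. A minimal sketch of that accumulation (hypothetical helper, not the method itself):

def accumulate_cost(total_cost_dict, step_cost_dict):
    # add one step's per-phase cost into the running total, in place
    for phase in total_cost_dict:
        total_cost_dict[phase] += step_cost_dict[phase]
    return total_cost_dict

total = {'forward': 0, 'backward': 0, 'total': 0}
for step in ({'forward': 100, 'backward': 0, 'total': 100},
             {'forward': 0, 'backward': 200, 'total': 200}):
    accumulate_cost(total, step)
print(total)    # {'forward': 100, 'backward': 200, 'total': 300}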


@@ -27,7 +27,11 @@ def test_one_step_transform():
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec:
    #     shard_sequence: S0,R,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
-   rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
+   rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {
+       "forward": 0,
+       "backward": 0,
+       "total": 0
+   })

    assert '[R, S1, R]' in [
        str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
@@ -48,7 +52,11 @@ def test_one_step_transform():
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec:
    #     shard_sequence: S0,R,S1
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
-   rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0)
+   rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, {
+       "forward": 0,
+       "backward": 0,
+       "total": 0
+   })

    assert '[S01, R, R]' in [
        str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
@@ -72,7 +80,11 @@ def test_one_step_transform():
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec:
    #     shard_sequence: S0,R,S1
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
-   rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0)
+   rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, {
+       "forward": 0,
+       "backward": 0,
+       "total": 0
+   })

    assert '[S01, R, R]' in [
        str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()