[autoparallel] add sequential order to communication actions (#1735)

Author: YuliangLiu0306
Date: 2022-10-20 18:48:18 +08:00
Committed by: GitHub
Parent: b893342f95
Commit: a4ce180e85

7 changed files with 293 additions and 90 deletions


@@ -1,8 +1,9 @@
-import torch
-from enum import Enum
-import torch.distributed as dist
-from functools import reduce
 import operator
+from enum import Enum
+from functools import reduce
+
+import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp
 
 __all__ = [
@@ -238,7 +239,7 @@ class CommSpec:
     1. Compute the communication cost which will be used in auto parallel solver.
     2. Convert the communication spec to real action which will be used in runtime.
     It contains comm_pattern to determine the
-    communication method, sharding_spec to determine the communication size, gather_dim and shard_dim
+    communication method, sharding_spec to determine the communication size, gather_dim and shard_dim
     to determine the buffer shape, and logical_process_axis
 
     Argument:
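
The docstring above enumerates the fields that drive a CommSpec: comm_pattern picks the collective, sharding_spec fixes the message size, gather_dim and shard_dim shape the buffers, and logical_process_axis selects the device-mesh axis. As a minimal sketch of how those fields fit together (the class name, enum members, and defaults below are illustrative assumptions, not the library's actual definitions):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class CollectiveCommPattern(Enum):
    # Hypothetical members mirroring the patterns named in this commit.
    ALLGATHER = 'all_gather'    # gather shards along gather_dim
    ALL2ALL = 'all_to_all'      # gather along gather_dim, reshard along shard_dim
    ALLREDUCE = 'all_reduce'    # reduce a partial tensor across devices
    SHARD = 'shard'             # slice locally along shard_dim (on-chip)


@dataclass
class CommSpecSketch:
    comm_pattern: CollectiveCommPattern          # communication method
    sharding_spec: Any                           # determines communication size
    gather_dim: Optional[int] = None             # buffer dim to gather along
    shard_dim: Optional[int] = None              # buffer dim to shard along
    logical_process_axis: Optional[int] = None   # device-mesh axis to communicate over
```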
@@ -296,7 +297,7 @@ class CommSpec:
         '''
         For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
         compute the communication cost.
-        For shard operation, it is an on-chip operation, so the communication cost is zero.
+        For shard operation, it is an on-chip operation, so the communication cost is zero.
         '''
         comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
         cost_dict = {}
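
A rough sketch of the alpha-beta cost rule this hunk describes: the cost is a latency term plus a per-element bandwidth term times the message size, except for shard, which is on-chip and therefore free. ALPHA, BETA, and the function name below are hypothetical stand-ins; the real terms come from DeviceMesh:

```python
import operator
from functools import reduce

ALPHA = 1e-5    # hypothetical per-message latency (seconds)
BETA = 1e-9     # hypothetical per-element transfer cost (seconds)


def comm_cost_sketch(pattern: str, sharded_shape_per_device) -> float:
    # Message size = product of the sharded per-device shape, exactly
    # as reduce(operator.mul, ..., 1) computes it in the hunk above.
    comm_size = reduce(operator.mul, sharded_shape_per_device, 1)
    if pattern == 'shard':
        return 0.0    # on-chip slicing, no inter-device traffic
    # all_gather / all2all / all_reduce: latency plus bandwidth term
    return ALPHA + BETA * comm_size


# e.g. comm_cost_sketch('all_gather', (1024, 1024)) == ALPHA + BETA * 1048576
```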
@@ -347,6 +348,7 @@ class CommSpec:
             tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
         else:
             tensor.data = tensor
+        return tensor
 
 
 pattern_to_func_dict = {
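
This hunk adds the return so the runtime helper hands the tensor back to its caller, which is what allows communication actions to be chained one after another in the sequential order the commit title refers to. A hedged sketch of the dispatch-table pattern, assuming hypothetical stub handlers (the real pattern_to_func_dict maps each pattern to a wrapper around a torch.distributed collective):

```python
import torch


def _all_gather_stub(tensor, comm_spec):
    return tensor    # a real handler would call dist.all_gather here


def _all_reduce_stub(tensor, comm_spec):
    return tensor    # a real handler would call dist.all_reduce here


pattern_to_func_sketch = {
    'all_gather': _all_gather_stub,
    'all_reduce': _all_reduce_stub,
}


def convert_spec_to_action_sketch(tensor: torch.Tensor, comm_spec) -> torch.Tensor:
    # Look up the handler for this communication pattern, if any.
    handler = pattern_to_func_sketch.get(getattr(comm_spec, 'comm_pattern', None))
    if handler is not None:
        tensor.data = handler(tensor, comm_spec)
    # Returning the tensor (the line this hunk adds) lets callers
    # run communication actions back to back in a fixed order.
    return tensor
```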